diff --git a/application/agents/llm_handler.py b/application/agents/llm_handler.py
index 9267dc53..a70357f8 100644
--- a/application/agents/llm_handler.py
+++ b/application/agents/llm_handler.py
@@ -72,9 +72,9 @@ class OpenAILLMHandler(LLMHandler):
         while True:
             tool_calls = {}
             for chunk in resp:
-                if isinstance(chunk, str):
+                if isinstance(chunk, str) and len(chunk) > 0:
                     return
-                else:
+                elif hasattr(chunk, "delta"):
                     chunk_delta = chunk.delta

                     if (
@@ -113,6 +113,8 @@ class OpenAILLMHandler(LLMHandler):
                     tool_response, call_id = agent._execute_tool_action(
                         tools_dict, call
                     )
+                    if isinstance(call["function"]["arguments"], str):
+                        call["function"]["arguments"] = json.loads(call["function"]["arguments"])

                     function_call_dict = {
                         "function_call": {
@@ -156,6 +158,8 @@ class OpenAILLMHandler(LLMHandler):
                     and chunk.finish_reason == "stop"
                 ):
                     return
+                elif isinstance(chunk, str) and len(chunk) == 0:
+                    continue

             resp = agent.llm.gen_stream(
                 model=agent.gpt_model, messages=messages, tools=agent.tools
diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 5a221e8d..7f88ba0f 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -42,6 +42,8 @@ elif settings.LLM_NAME == "anthropic":
     gpt_model = "claude-2"
 elif settings.LLM_NAME == "groq":
     gpt_model = "llama3-8b-8192"
+elif settings.LLM_NAME == "novita":
+    gpt_model = "deepseek/deepseek-r1"

 if settings.MODEL_NAME:  # in case there is particular model name configured
     gpt_model = settings.MODEL_NAME
@@ -706,7 +708,6 @@ class Search(Resource):

         retriever = RetrieverCreator.create_retriever(
             retriever_name,
-            question=question,
             source=source,
             chat_history=[],
             prompt="default",
@@ -716,7 +717,7 @@
             user_api_key=user_api_key,
         )

-        docs = retriever.search()
+        docs = retriever.search(question)
         retriever_params = retriever.get_params()

         user_logs_collection.insert_one(
diff --git a/application/cache.py b/application/cache.py
index 80dee4f4..117b444a 100644
--- a/application/cache.py
+++ b/application/cache.py
@@ -11,21 +11,25 @@ from application.utils import get_hash
 logger = logging.getLogger(__name__)

 _redis_instance = None
+_redis_creation_failed = False
 _instance_lock = Lock()

-
 def get_redis_instance():
-    global _redis_instance
-    if _redis_instance is None:
+    global _redis_instance, _redis_creation_failed
+    if _redis_instance is None and not _redis_creation_failed:
         with _instance_lock:
-            if _redis_instance is None:
+            if _redis_instance is None and not _redis_creation_failed:
                 try:
                     _redis_instance = redis.Redis.from_url(
                         settings.CACHE_REDIS_URL, socket_connect_timeout=2
                     )
+                except ValueError as e:
+                    logger.error(f"Invalid Redis URL: {e}")
+                    _redis_creation_failed = True  # Stop future attempts
+                    _redis_instance = None
                 except redis.ConnectionError as e:
                     logger.error(f"Redis connection error: {e}")
-                    _redis_instance = None
+                    _redis_instance = None  # Keep trying for connection errors
     return _redis_instance


@@ -41,36 +45,48 @@ def gen_cache_key(messages, model="docgpt", tools=None):

 def gen_cache(func):
     def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
+        if tools is not None:
+            return func(self, model, messages, stream, tools, *args, **kwargs)
+
         try:
             cache_key = gen_cache_key(messages, model, tools)
-            redis_client = get_redis_instance()
-            if redis_client:
-                try:
-                    cached_response = redis_client.get(cache_key)
-                    if cached_response:
-                        return cached_response.decode("utf-8")
-                except redis.ConnectionError as e:
-                    logger.error(f"Redis connection error: {e}")
-
-            result = func(self, model, messages, stream, tools, *args, **kwargs)
-            if redis_client and isinstance(result, str):
-                try:
-                    redis_client.set(cache_key, result, ex=1800)
-                except redis.ConnectionError as e:
-                    logger.error(f"Redis connection error: {e}")
-
-            return result
         except ValueError as e:
-            logger.error(e)
-            return "Error: No user message found in the conversation to generate a cache key."
+            logger.error(f"Cache key generation failed: {e}")
+            return func(self, model, messages, stream, tools, *args, **kwargs)
+
+        redis_client = get_redis_instance()
+        if redis_client:
+            try:
+                cached_response = redis_client.get(cache_key)
+                if cached_response:
+                    return cached_response.decode("utf-8")
+            except Exception as e:
+                logger.error(f"Error getting cached response: {e}")
+
+        result = func(self, model, messages, stream, tools, *args, **kwargs)
+        if redis_client and isinstance(result, str):
+            try:
+                redis_client.set(cache_key, result, ex=1800)
+            except Exception as e:
+                logger.error(f"Error setting cache: {e}")
+
+        return result

     return wrapper


 def stream_cache(func):
     def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
-        cache_key = gen_cache_key(messages, model, tools)
-        logger.info(f"Stream cache key: {cache_key}")
+        if tools is not None:
+            yield from func(self, model, messages, stream, tools, *args, **kwargs)
+            return
+
+        try:
+            cache_key = gen_cache_key(messages, model, tools)
+        except ValueError as e:
+            logger.error(f"Cache key generation failed: {e}")
+            yield from func(self, model, messages, stream, tools, *args, **kwargs)
+            return

         redis_client = get_redis_instance()
         if redis_client:
@@ -81,23 +97,21 @@ def stream_cache(func):
                     cached_response = json.loads(cached_response.decode("utf-8"))
                     for chunk in cached_response:
                         yield chunk
-                        time.sleep(0.03)
+                        time.sleep(0.03)  # Simulate streaming delay
                     return
-            except redis.ConnectionError as e:
-                logger.error(f"Redis connection error: {e}")
+            except Exception as e:
+                logger.error(f"Error getting cached stream: {e}")

-        result = func(self, model, messages, stream, tools=tools, *args, **kwargs)
         stream_cache_data = []
-
-        for chunk in result:
-            stream_cache_data.append(chunk)
-            yield chunk
+        for chunk in func(self, model, messages, stream, tools, *args, **kwargs):
+            yield chunk
+            stream_cache_data.append(str(chunk))

         if redis_client:
             try:
                 redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
                 logger.info(f"Stream cache saved for key: {cache_key}")
-            except redis.ConnectionError as e:
-                logger.error(f"Redis connection error: {e}")
+            except Exception as e:
+                logger.error(f"Error setting stream cache: {e}")

     return wrapper
diff --git a/application/llm/docsgpt_provider.py b/application/llm/docsgpt_provider.py
index bb23d824..001035c4 100644
--- a/application/llm/docsgpt_provider.py
+++ b/application/llm/docsgpt_provider.py
@@ -1,34 +1,131 @@
-from application.llm.base import BaseLLM
 import json
-import requests
+
+from application.core.settings import settings
+from application.llm.base import BaseLLM


 class DocsGPTAPILLM(BaseLLM):

     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
+        from openai import OpenAI
+
         super().__init__(*args, **kwargs)
-        self.api_key = api_key
+        self.client = OpenAI(api_key="sk-docsgpt-public", base_url="https://oai.arc53.com")
         self.user_api_key = user_api_key
-        self.endpoint = "https://llm.arc53.com"
+        self.api_key = api_key

-    def _raw_gen(self, baseself, model, messages, stream=False, *args, **kwargs):
-        response = requests.post(
-            f"{self.endpoint}/answer", json={"messages": messages, "max_new_tokens": 30}
-        )
-        response_clean = response.json()["a"].replace("###", "")
+    def _clean_messages_openai(self, messages):
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")

-        return response_clean
+            if role == "model":
+                role = "assistant"

-    def _raw_gen_stream(self, baseself, model, messages, stream=True, *args, **kwargs):
-        response = requests.post(
-            f"{self.endpoint}/stream",
-            json={"messages": messages, "max_new_tokens": 256},
-            stream=True,
-        )
+            if role and content is not None:
+                if isinstance(content, str):
+                    cleaned_messages.append({"role": role, "content": content})
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            cleaned_messages.append(
+                                {"role": role, "content": item["text"]}
+                            )
+                        elif "function_call" in item:
+                            tool_call = {
+                                "id": item["function_call"]["call_id"],
+                                "type": "function",
+                                "function": {
+                                    "name": item["function_call"]["name"],
+                                    "arguments": json.dumps(
+                                        item["function_call"]["args"]
+                                    ),
+                                },
+                            }
+                            cleaned_messages.append(
+                                {
+                                    "role": "assistant",
+                                    "content": None,
+                                    "tool_calls": [tool_call],
+                                }
+                            )
+                        elif "function_response" in item:
+                            cleaned_messages.append(
+                                {
+                                    "role": "tool",
+                                    "tool_call_id": item["function_response"][
+                                        "call_id"
+                                    ],
+                                    "content": json.dumps(
+                                        item["function_response"]["response"]["result"]
+                                    ),
+                                }
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format: {item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")

-        for line in response.iter_lines():
-            if line:
-                data_str = line.decode("utf-8")
-                if data_str.startswith("data: "):
-                    data = json.loads(data_str[6:])
-                    yield data["a"]
+        return cleaned_messages
+
+    def _raw_gen(
+        self,
+        baseself,
+        model,
+        messages,
+        stream=False,
+        tools=None,
+        engine=settings.AZURE_DEPLOYMENT_NAME,
+        **kwargs,
+    ):
+        messages = self._clean_messages_openai(messages)
+        if tools:
+            response = self.client.chat.completions.create(
+                model="docsgpt",
+                messages=messages,
+                stream=stream,
+                tools=tools,
+                **kwargs,
+            )
+            return response.choices[0]
+        else:
+            response = self.client.chat.completions.create(
+                model="docsgpt", messages=messages, stream=stream, **kwargs
+            )
+            return response.choices[0].message.content
+
+    def _raw_gen_stream(
+        self,
+        baseself,
+        model,
+        messages,
+        stream=True,
+        tools=None,
+        engine=settings.AZURE_DEPLOYMENT_NAME,
+        **kwargs,
+    ):
+        messages = self._clean_messages_openai(messages)
+        if tools:
+            response = self.client.chat.completions.create(
+                model="docsgpt",
+                messages=messages,
+                stream=stream,
+                tools=tools,
+                **kwargs,
+            )
+        else:
+            response = self.client.chat.completions.create(
+                model="docsgpt", messages=messages, stream=stream, **kwargs
+            )
+
+        for line in response:
+            if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
+                yield line.choices[0].delta.content
+            elif len(line.choices) > 0:
+                yield line.choices[0]
+
+    def _supports_tools(self):
+        return True
\ No newline at end of file
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index f32089de..9f1305ba 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -7,7 +7,7 @@ from application.llm.anthropic import AnthropicLLM
 from application.llm.docsgpt_provider import DocsGPTAPILLM
 from application.llm.premai import PremAILLM
 from application.llm.google_ai import GoogleLLM
-
+from application.llm.novita import NovitaLLM

 class LLMCreator:
     llms = {
@@ -20,7 +20,8 @@ class LLMCreator:
         "docsgpt": DocsGPTAPILLM,
         "premai": PremAILLM,
         "groq": GroqLLM,
-        "google": GoogleLLM
+        "google": GoogleLLM,
+        "novita": NovitaLLM
     }

     @classmethod
diff --git a/application/llm/novita.py b/application/llm/novita.py
new file mode 100644
index 00000000..8d6ac042
--- /dev/null
+++ b/application/llm/novita.py
@@ -0,0 +1,32 @@
+from application.llm.base import BaseLLM
+from openai import OpenAI
+
+
+class NovitaLLM(BaseLLM):
+    def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.client = OpenAI(api_key=api_key, base_url="https://api.novita.ai/v3/openai")
+        self.api_key = api_key
+        self.user_api_key = user_api_key
+
+    def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
+        if tools:
+            response = self.client.chat.completions.create(
+                model=model, messages=messages, stream=stream, tools=tools, **kwargs
+            )
+            return response.choices[0]
+        else:
+            response = self.client.chat.completions.create(
+                model=model, messages=messages, stream=stream, **kwargs
+            )
+            return response.choices[0].message.content
+
+    def _raw_gen_stream(
+        self, baseself, model, messages, stream=True, tools=None, **kwargs
+    ):
+        response = self.client.chat.completions.create(
+            model=model, messages=messages, stream=stream, **kwargs
+        )
+        for line in response:
+            if line.choices[0].delta.content is not None:
+                yield line.choices[0].delta.content
diff --git a/application/llm/openai.py b/application/llm/openai.py
index 938de523..f8a38ed0 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -125,9 +125,9 @@ class OpenAILLM(BaseLLM):
         )

         for line in response:
-            if line.choices[0].delta.content is not None:
+            if len(line.choices) > 0 and line.choices[0].delta.content is not None and len(line.choices[0].delta.content) > 0:
                 yield line.choices[0].delta.content
-            else:
+            elif len(line.choices) > 0:
                 yield line.choices[0]

     def _supports_tools(self):
diff --git a/setup.sh b/setup.sh
index 7775e24e..31ed3e42 100755
--- a/setup.sh
+++ b/setup.sh
@@ -148,6 +148,7 @@ prompt_cloud_api_provider_options() {
     echo -e "${YELLOW}4) Groq${NC}"
    echo -e "${YELLOW}5) HuggingFace Inference API${NC}"
    echo -e "${YELLOW}6) Azure OpenAI${NC}"
+    echo -e "${YELLOW}7) Novita${NC}"
    echo -e "${YELLOW}b) Back to Main Menu${NC}"
    echo
    read -p "$(echo -e "${DEFAULT_FG}Choose option (1-6, or b): ${NC}")" provider_choice
@@ -428,6 +429,12 @@ connect_cloud_api_provider() {
             model_name="gpt-4o"
            get_api_key
            break ;;
+        7) # Novita
+            provider_name="Novita"
+            llm_name="novita"
+            model_name="deepseek/deepseek-r1"
+            get_api_key
+            break ;;
        b|B) clear; return ;; # Clear screen and Back to Main Menu
        *) echo -e "\n${RED}Invalid choice. Please choose 1-6, or b.${NC}" ; sleep 1 ;;
    esac
diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py
index 50ddbe29..867b6923 100644
--- a/tests/llm/test_anthropic.py
+++ b/tests/llm/test_anthropic.py
@@ -64,7 +64,5 @@ class TestAnthropicLLM(unittest.TestCase):
             max_tokens_to_sample=300,
            stream=True
        )
-        mock_redis_instance.set.assert_called_once()
-
 if __name__ == "__main__":
     unittest.main()
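Below the patch, a minimal usage sketch (not part of the diff) of the NovitaLLM provider added above. It assumes BaseLLM exposes a gen_stream wrapper over _raw_gen_stream, as implied by the agent.llm.gen_stream(...) call in llm_handler.py; the API key and prompt are placeholders, not values from the patch.

# Illustrative sketch, assuming BaseLLM exposes gen_stream()/gen() around the
# _raw_gen_stream()/_raw_gen() methods defined in application/llm/novita.py.
from application.llm.novita import NovitaLLM

llm = NovitaLLM(api_key="YOUR_NOVITA_API_KEY")  # hypothetical placeholder key

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does DocsGPT do?"},
]

# deepseek/deepseek-r1 is the default model wired up for Novita in routes.py and setup.sh.
for token in llm.gen_stream(model="deepseek/deepseek-r1", messages=messages):
    print(token, end="", flush=True)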