diff --git a/application/cache.py b/application/cache.py index 7239abac..76b594c9 100644 --- a/application/cache.py +++ b/application/cache.py @@ -1,29 +1,34 @@ -import redis -import time import json import logging +import time from threading import Lock + +import redis + from application.core.settings import settings from application.utils import get_hash -import sys logger = logging.getLogger(__name__) _redis_instance = None _instance_lock = Lock() + def get_redis_instance(): global _redis_instance if _redis_instance is None: with _instance_lock: if _redis_instance is None: try: - _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2) + _redis_instance = redis.Redis.from_url( + settings.CACHE_REDIS_URL, socket_connect_timeout=2 + ) except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") _redis_instance = None return _redis_instance + def gen_cache_key(messages, model="docgpt", tools=None): if not all(isinstance(msg, dict) for msg in messages): raise ValueError("All messages must be dictionaries.") @@ -33,6 +38,7 @@ def gen_cache_key(messages, model="docgpt", tools=None): cache_key = get_hash(combined) return cache_key + def gen_cache(func): def wrapper(self, model, messages, stream, tools=None, *args, **kwargs): try: @@ -42,7 +48,7 @@ def gen_cache(func): try: cached_response = redis_client.get(cache_key) if cached_response: - return cached_response.decode('utf-8') + return cached_response.decode("utf-8") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") @@ -57,20 +63,22 @@ def gen_cache(func): except ValueError as e: logger.error(e) return "Error: No user message found in the conversation to generate a cache key." + return wrapper + def stream_cache(func): def wrapper(self, model, messages, stream, *args, **kwargs): cache_key = gen_cache_key(messages) logger.info(f"Stream cache key: {cache_key}") - + redis_client = get_redis_instance() if redis_client: try: cached_response = redis_client.get(cache_key) if cached_response: logger.info(f"Cache hit for stream key: {cache_key}") - cached_response = json.loads(cached_response.decode('utf-8')) + cached_response = json.loads(cached_response.decode("utf-8")) for chunk in cached_response: yield chunk time.sleep(0.03) @@ -80,16 +88,16 @@ def stream_cache(func): result = func(self, model, messages, stream, *args, **kwargs) stream_cache_data = [] - + for chunk in result: stream_cache_data.append(chunk) yield chunk - + if redis_client: try: redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800) logger.info(f"Stream cache saved for key: {cache_key}") except redis.ConnectionError as e: logger.error(f"Redis connection error: {e}") - - return wrapper \ No newline at end of file + + return wrapper diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py index 4ac52bc5..81a5985b 100644 --- a/application/retriever/classic_rag.py +++ b/application/retriever/classic_rag.py @@ -1,10 +1,9 @@ -from application.retriever.base import BaseRetriever from application.core.settings import settings -from application.vectorstore.vector_creator import VectorCreator -from application.llm.llm_creator import LLMCreator +from application.retriever.base import BaseRetriever from application.tools.agent import Agent from application.utils import num_tokens_from_string +from application.vectorstore.vector_creator import VectorCreator class ClassicRAG(BaseRetriever): @@ -21,7 +20,7 @@ class ClassicRAG(BaseRetriever): user_api_key=None, ): self.question = question - self.vectorstore = source['active_docs'] if 'active_docs' in source else None + self.vectorstore = source["active_docs"] if "active_docs" in source else None self.chat_history = chat_history self.prompt = prompt self.chunks = chunks @@ -78,9 +77,9 @@ class ClassicRAG(BaseRetriever): # count tokens in history for i in self.chat_history: if "prompt" in i and "response" in i: - tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string( - i["response"] - ) + tokens_batch = num_tokens_from_string( + i["prompt"] + ) + num_tokens_from_string(i["response"]) if tokens_current_history + tokens_batch < self.token_limit: tokens_current_history += tokens_batch messages_combine.append( @@ -95,14 +94,19 @@ class ClassicRAG(BaseRetriever): # settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key # ) # completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) - agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key) + agent = Agent( + llm_name=settings.LLM_NAME, + gpt_model=self.gpt_model, + api_key=settings.API_KEY, + user_api_key=self.user_api_key, + ) completion = agent.gen(messages_combine) for line in completion: yield {"answer": str(line)} def search(self): return self._get_data() - + def get_params(self): return { "question": self.question, @@ -112,5 +116,5 @@ class ClassicRAG(BaseRetriever): "chunks": self.chunks, "token_limit": self.token_limit, "gpt_model": self.gpt_model, - "user_api_key": self.user_api_key + "user_api_key": self.user_api_key, } diff --git a/application/tools/agent.py b/application/tools/agent.py index ffd14770..e02c40f7 100644 --- a/application/tools/agent.py +++ b/application/tools/agent.py @@ -1,7 +1,6 @@ import json from application.core.mongo_db import MongoDB -from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.tools.tool_manager import ToolManager diff --git a/application/tools/tool_manager.py b/application/tools/tool_manager.py index cc9a055a..3e0766cf 100644 --- a/application/tools/tool_manager.py +++ b/application/tools/tool_manager.py @@ -27,7 +27,6 @@ class ToolManager: def load_tool(self, tool_name, tool_config): self.config[tool_name] = tool_config - tools_dir = os.path.join(os.path.dirname(__file__), "implementations") module = importlib.import_module( f"application.tools.implementations.{tool_name}" )