fix: python lint errors

This commit is contained in:
Siddhant Rai
2024-12-19 10:06:06 +05:30
parent c3f538c2f6
commit daa332aa20
4 changed files with 33 additions and 23 deletions

View File

@@ -1,29 +1,34 @@
import redis
import time
import json
import logging
import time
from threading import Lock
import redis
from application.core.settings import settings
from application.utils import get_hash
import sys
logger = logging.getLogger(__name__)
_redis_instance = None
_instance_lock = Lock()
def get_redis_instance():
global _redis_instance
if _redis_instance is None:
with _instance_lock:
if _redis_instance is None:
try:
_redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
_redis_instance = redis.Redis.from_url(
settings.CACHE_REDIS_URL, socket_connect_timeout=2
)
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
_redis_instance = None
return _redis_instance
def gen_cache_key(messages, model="docgpt", tools=None):
if not all(isinstance(msg, dict) for msg in messages):
raise ValueError("All messages must be dictionaries.")
@@ -33,6 +38,7 @@ def gen_cache_key(messages, model="docgpt", tools=None):
cache_key = get_hash(combined)
return cache_key
def gen_cache(func):
def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
try:
@@ -42,7 +48,7 @@ def gen_cache(func):
try:
cached_response = redis_client.get(cache_key)
if cached_response:
return cached_response.decode('utf-8')
return cached_response.decode("utf-8")
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
@@ -57,20 +63,22 @@ def gen_cache(func):
except ValueError as e:
logger.error(e)
return "Error: No user message found in the conversation to generate a cache key."
return wrapper
def stream_cache(func):
def wrapper(self, model, messages, stream, *args, **kwargs):
cache_key = gen_cache_key(messages)
logger.info(f"Stream cache key: {cache_key}")
redis_client = get_redis_instance()
if redis_client:
try:
cached_response = redis_client.get(cache_key)
if cached_response:
logger.info(f"Cache hit for stream key: {cache_key}")
cached_response = json.loads(cached_response.decode('utf-8'))
cached_response = json.loads(cached_response.decode("utf-8"))
for chunk in cached_response:
yield chunk
time.sleep(0.03)
@@ -80,16 +88,16 @@ def stream_cache(func):
result = func(self, model, messages, stream, *args, **kwargs)
stream_cache_data = []
for chunk in result:
stream_cache_data.append(chunk)
yield chunk
if redis_client:
try:
redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
logger.info(f"Stream cache saved for key: {cache_key}")
except redis.ConnectionError as e:
logger.error(f"Redis connection error: {e}")
return wrapper
return wrapper

View File

@@ -1,10 +1,9 @@
from application.retriever.base import BaseRetriever
from application.core.settings import settings
from application.vectorstore.vector_creator import VectorCreator
from application.llm.llm_creator import LLMCreator
from application.retriever.base import BaseRetriever
from application.tools.agent import Agent
from application.utils import num_tokens_from_string
from application.vectorstore.vector_creator import VectorCreator
class ClassicRAG(BaseRetriever):
@@ -21,7 +20,7 @@ class ClassicRAG(BaseRetriever):
user_api_key=None,
):
self.question = question
self.vectorstore = source['active_docs'] if 'active_docs' in source else None
self.vectorstore = source["active_docs"] if "active_docs" in source else None
self.chat_history = chat_history
self.prompt = prompt
self.chunks = chunks
@@ -78,9 +77,9 @@ class ClassicRAG(BaseRetriever):
# count tokens in history
for i in self.chat_history:
if "prompt" in i and "response" in i:
tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string(
i["response"]
)
tokens_batch = num_tokens_from_string(
i["prompt"]
) + num_tokens_from_string(i["response"])
if tokens_current_history + tokens_batch < self.token_limit:
tokens_current_history += tokens_batch
messages_combine.append(
@@ -95,14 +94,19 @@ class ClassicRAG(BaseRetriever):
# settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
# )
# completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine)
agent = Agent(llm_name=settings.LLM_NAME,gpt_model=self.gpt_model, api_key=settings.API_KEY, user_api_key=self.user_api_key)
agent = Agent(
llm_name=settings.LLM_NAME,
gpt_model=self.gpt_model,
api_key=settings.API_KEY,
user_api_key=self.user_api_key,
)
completion = agent.gen(messages_combine)
for line in completion:
yield {"answer": str(line)}
def search(self):
return self._get_data()
def get_params(self):
return {
"question": self.question,
@@ -112,5 +116,5 @@ class ClassicRAG(BaseRetriever):
"chunks": self.chunks,
"token_limit": self.token_limit,
"gpt_model": self.gpt_model,
"user_api_key": self.user_api_key
"user_api_key": self.user_api_key,
}

View File

@@ -1,7 +1,6 @@
import json
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
from application.tools.tool_manager import ToolManager

View File

@@ -27,7 +27,6 @@ class ToolManager:
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
tools_dir = os.path.join(os.path.dirname(__file__), "implementations")
module = importlib.import_module(
f"application.tools.implementations.{tool_name}"
)