resolved merge conflicts

Niharika Goulikar
2024-10-18 10:31:53 +00:00
parent 5854202f22
commit f4abed43ba
37 changed files with 1005 additions and 386 deletions


@@ -292,6 +292,7 @@ class Stream(Resource):
    def post(self):
        data = request.get_json()
        required_fields = ["question"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
@@ -422,7 +423,7 @@ class Answer(Resource):
    @api.doc(description="Provide an answer based on the question and retriever")
    def post(self):
        data = request.get_json()
-        required_fields = ["question"]
+        required_fields = ["question"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
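
Both endpoints lean on the same validation helper. A minimal sketch of check_required_fields, inferred from its call sites here and from its tail visible in the application/utils.py hunk at the end of this diff; the exact response message format is an assumption:

from flask import jsonify, make_response

def check_required_fields(data, required_fields):
    # Sketch: collect whichever required keys are missing from the JSON body.
    missing = [field for field in required_fields if field not in (data or {})]
    if missing:
        return make_response(
            jsonify({"success": False, "message": f"Missing fields: {missing}"}),
            400,
        )
    return None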


@@ -7,7 +7,7 @@ from bson.binary import Binary, UuidRepresentation
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, jsonify, make_response, request
-from flask_restx import fields, Namespace, Resource
+from flask_restx import inputs, fields, Namespace, Resource
from pymongo import MongoClient
from werkzeug.utils import secure_filename
@@ -802,7 +802,7 @@ class ShareConversation(Resource):
        if missing_fields:
            return missing_fields
-        is_promptable = request.args.get("isPromptable")
+        is_promptable = request.args.get("isPromptable", type=inputs.boolean)
        if is_promptable is None:
            return make_response(
                jsonify({"success": False, "message": "isPromptable is required"}), 400
@@ -831,7 +831,7 @@ class ShareConversation(Resource):
                uuid.uuid4(), UuidRepresentation.STANDARD
            )
-            if is_promptable.lower() == "true":
+            if is_promptable:
                prompt_id = data.get("prompt_id", "default")
                chunks = data.get("chunks", "2")
@@ -859,7 +859,7 @@ class ShareConversation(Resource):
                    "conversation_id": DBRef(
                        "conversations", ObjectId(conversation_id)
                    ),
-                    "isPromptable": is_promptable.lower() == "true",
+                    "isPromptable": is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,
                    "api_key": api_uuid,
@@ -883,7 +883,7 @@ class ShareConversation(Resource):
                        "$ref": "conversations",
                        "$id": ObjectId(conversation_id),
                    },
-                    "isPromptable": is_promptable.lower() == "true",
+                    "isPromptable": is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,
                    "api_key": api_uuid,
@@ -918,7 +918,7 @@ class ShareConversation(Resource):
                        "$ref": "conversations",
                        "$id": ObjectId(conversation_id),
                    },
-                    "isPromptable": is_promptable.lower() == "true",
+                    "isPromptable": is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,
                    "api_key": api_uuid,
@@ -939,7 +939,7 @@ class ShareConversation(Resource):
                    "conversation_id": DBRef(
                        "conversations", ObjectId(conversation_id)
                    ),
-                    "isPromptable": is_promptable.lower() == "false",
+                    "isPromptable": not is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,
                }
@@ -962,7 +962,7 @@ class ShareConversation(Resource):
                        "$ref": "conversations",
                        "$id": ObjectId(conversation_id),
                    },
-                    "isPromptable": is_promptable.lower() == "false",
+                    "isPromptable": not is_promptable,
                    "first_n_queries": current_n_queries,
                    "user": user,
                }
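
The switch to type=inputs.boolean is what makes the bare is_promptable checks above correct: the query parameter now arrives as a real bool (or None) instead of a string. A short illustration of the flask_restx/werkzeug semantics this relies on; the fallback-to-None behavior of args.get(type=...) is standard werkzeug, not part of this diff:

from flask_restx import inputs

# inputs.boolean maps the usual spellings to real booleans and raises
# ValueError otherwise; werkzeug's request.args.get(..., type=...) catches
# that ValueError and falls back to the default (None), which then trips
# the "isPromptable is required" 400 above.
assert inputs.boolean("true") is True
assert inputs.boolean("False") is False
assert inputs.boolean("1") is True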

application/cache.py (new file, 93 lines)

@@ -0,0 +1,93 @@
import redis
import time
import json
import logging

from threading import Lock

from application.core.settings import settings
from application.utils import get_hash

logger = logging.getLogger(__name__)

_redis_instance = None
_instance_lock = Lock()


def get_redis_instance():
    # Lazily create one shared Redis client; double-checked locking keeps
    # the common path lock-free once the client exists.
    global _redis_instance
    if _redis_instance is None:
        with _instance_lock:
            if _redis_instance is None:
                try:
                    _redis_instance = redis.Redis.from_url(
                        settings.CACHE_REDIS_URL, socket_connect_timeout=2
                    )
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")
                    _redis_instance = None
    return _redis_instance


def gen_cache_key(*messages, model="docgpt"):
    # Key = hash of the model name plus a deterministic (sort_keys) JSON
    # dump of the message dicts.
    if not all(isinstance(msg, dict) for msg in messages):
        raise ValueError("All messages must be dictionaries.")
    messages_str = json.dumps(list(messages), sort_keys=True)
    combined = f"{model}_{messages_str}"
    cache_key = get_hash(combined)
    return cache_key


def gen_cache(func):
    # Cache complete (non-streaming) responses for 30 minutes.
    def wrapper(self, model, messages, *args, **kwargs):
        try:
            cache_key = gen_cache_key(*messages)
            redis_client = get_redis_instance()
            if redis_client:
                try:
                    cached_response = redis_client.get(cache_key)
                    if cached_response:
                        return cached_response.decode('utf-8')
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")

            result = func(self, model, messages, *args, **kwargs)
            if redis_client:
                try:
                    redis_client.set(cache_key, result, ex=1800)
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")
            return result
        except ValueError as e:
            logger.error(e)
            return "Error: No user message found in the conversation to generate a cache key."

    return wrapper


def stream_cache(func):
    # Cache streamed responses chunk by chunk; on a hit, replay the chunks
    # with a small delay so the client still sees a stream.
    def wrapper(self, model, messages, stream, *args, **kwargs):
        cache_key = gen_cache_key(*messages)
        logger.info(f"Stream cache key: {cache_key}")
        redis_client = get_redis_instance()
        if redis_client:
            try:
                cached_response = redis_client.get(cache_key)
                if cached_response:
                    logger.info(f"Cache hit for stream key: {cache_key}")
                    cached_response = json.loads(cached_response.decode('utf-8'))
                    for chunk in cached_response:
                        yield chunk
                        time.sleep(0.03)
                    return
            except redis.ConnectionError as e:
                logger.error(f"Redis connection error: {e}")

        result = func(self, model, messages, stream, *args, **kwargs)
        stream_cache_data = []
        for chunk in result:
            stream_cache_data.append(chunk)
            yield chunk

        if redis_client:
            try:
                redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
                logger.info(f"Stream cache saved for key: {cache_key}")
            except redis.ConnectionError as e:
                logger.error(f"Redis connection error: {e}")

    return wrapper
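
Two properties of the key scheme are worth noting. Serializing with sort_keys=True makes the key independent of dict insertion order, so identical conversations always hit the same entry; and both decorators call gen_cache_key(*messages) without forwarding the model, so all models currently share the default "docgpt" namespace. A standalone check of the first property, with get_hash inlined for self-containment:

import hashlib
import json

def get_hash(data):
    return hashlib.md5(data.encode()).hexdigest()

def gen_cache_key(*messages, model="docgpt"):
    messages_str = json.dumps(list(messages), sort_keys=True)
    return get_hash(f"{model}_{messages_str}")

# Same message, different key order -> same cache key.
assert gen_cache_key({"role": "user", "content": "hi"}) == \
    gen_cache_key({"content": "hi", "role": "user"})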

application/core/settings.py

@@ -21,6 +21,9 @@ class Settings(BaseSettings):
    VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus"
    RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search

+    # LLM Cache
+    CACHE_REDIS_URL: str = "redis://localhost:6379/2"
+
    API_URL: str = "http://localhost:7091"  # backend url for celery worker
    API_KEY: Optional[str] = None  # LLM api key
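
Since Settings extends pydantic's BaseSettings, the new default should be overridable per deployment without code changes, e.g. to point at a Docker service name instead of localhost. A sketch assuming standard BaseSettings behavior; an environment variable named CACHE_REDIS_URL would take effect the same way:

from application.core.settings import Settings

settings = Settings(CACHE_REDIS_URL="redis://redis:6379/2")
assert settings.CACHE_REDIS_URL == "redis://redis:6379/2"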


@@ -1,28 +1,29 @@
from abc import ABC, abstractmethod

from application.usage import gen_token_usage, stream_token_usage
+from application.cache import stream_cache, gen_cache


class BaseLLM(ABC):
    def __init__(self):
        self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

-    def _apply_decorator(self, method, decorator, *args, **kwargs):
-        return decorator(method, *args, **kwargs)
+    def _apply_decorator(self, method, decorators, *args, **kwargs):
+        for decorator in decorators:
+            method = decorator(method)
+        return method(self, *args, **kwargs)

    @abstractmethod
    def _raw_gen(self, model, messages, stream, *args, **kwargs):
        pass

    def gen(self, model, messages, stream=False, *args, **kwargs):
-        return self._apply_decorator(self._raw_gen, gen_token_usage)(
-            self, model=model, messages=messages, stream=stream, *args, **kwargs
-        )
+        decorators = [gen_token_usage, gen_cache]
+        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)

    @abstractmethod
    def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
        pass

    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
-        return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
-            self, model=model, messages=messages, stream=stream, *args, **kwargs
-        )
+        decorators = [stream_cache, stream_token_usage]
+        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
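
The order inside each decorator list matters: _apply_decorator applies them left to right, so the last entry ends up outermost. For gen that puts gen_cache outside gen_token_usage, meaning a cache hit returns before any token accounting runs; for gen_stream, stream_token_usage sits outside stream_cache, so even replayed chunks pass through the usage counter. A toy reproduction of the wrapping logic:

def marker(name):
    def decorate(fn):
        def wrapped(*args, **kwargs):
            print(f"enter {name}")
            return fn(*args, **kwargs)
        return wrapped
    return decorate

def apply(method, decorators):
    # Mirrors _apply_decorator: the last decorator applied becomes outermost.
    for decorator in decorators:
        method = decorator(method)
    return method

f = apply(lambda: "result", [marker("usage"), marker("cache")])
f()  # prints "enter cache", then "enter usage"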


@@ -4,7 +4,7 @@ beautifulsoup4==4.12.3
celery==5.3.6
dataclasses-json==0.6.7
docx2txt==0.8
-duckduckgo-search==6.2.6
+duckduckgo-search==6.3.0
ebooklib==0.18
elastic-transport==8.15.0
elasticsearch==8.15.1
@@ -54,7 +54,7 @@ pathable==0.4.3
pillow==10.4.0
portalocker==2.10.1
prance==23.6.21.0
-primp==0.6.2
+primp==0.6.3
prompt-toolkit==3.0.47
protobuf==5.28.2
py==1.11.0

application/utils.py

@@ -1,6 +1,8 @@
import tiktoken
+import hashlib

from flask import jsonify, make_response

_encoding = None
@@ -39,3 +41,8 @@ def check_required_fields(data, required_fields):
            400,
        )
    return None
+
+
+def get_hash(data):
+    return hashlib.md5(data.encode()).hexdigest()
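
MD5 serves here as a fast, deterministic key hash, not a security measure; identical inputs always map to the same 32-character hex digest:

import hashlib

def get_hash(data):
    return hashlib.md5(data.encode()).hexdigest()

assert get_hash("hello") == "5d41402abc4b2a76b9719d911017c592"
assert len(get_hash("anything")) == 32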