Mirror of https://github.com/arc53/DocsGPT.git, synced 2025-11-29 16:43:16 +00:00.
Commit: fix: add singleton, logging, connection handling.
This commit is contained in:
@@ -1,131 +1,93 @@
|
||||
import redis
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from threading import Lock
|
||||
from application.core.settings import settings
|
||||
from application.utils import get_hash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level singleton state for the shared Redis client.
# _redis_instance holds the lazily-created client (or None);
# _instance_lock guards its creation across threads.
_redis_instance = None
_instance_lock = Lock()


def make_redis():
    """
    Initialize a Redis client using the settings provided in the application settings.

    Creates a brand-new client on every call (no pooling or reuse); prefer
    get_redis_instance() when a shared, connection-checked client is wanted.

    Returns:
        redis.Redis: A Redis client instance.
    """
    return redis.Redis(
        host=settings.REDIS_HOST,
        port=settings.REDIS_PORT,
        db=settings.REDIS_DB,
    )
def get_redis_instance():
    """
    Return the shared Redis client, creating it on first use.

    Uses double-checked locking on _instance_lock so only one thread builds
    the client. The client is created from settings.CACHE_REDIS_URL with a
    short (2s) connect timeout so a down Redis does not hang callers.

    Returns:
        redis.Redis | None: The shared client, or None if the connection
        attempt raised redis.ConnectionError (the error is logged, not raised).
    """
    global _redis_instance
    if _redis_instance is None:
        with _instance_lock:
            # Re-check inside the lock: another thread may have won the race.
            if _redis_instance is None:
                try:
                    _redis_instance = redis.Redis.from_url(
                        settings.CACHE_REDIS_URL, socket_connect_timeout=2
                    )
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")
                    _redis_instance = None
    return _redis_instance
||||
def gen_cache_key(*messages, model="docgpt"):
    """
    Generate a unique cache key from the conversation messages and model.

    Serializes the full message list deterministically (json.dumps with
    sort_keys=True) and combines it with the model name, then hashes the
    result so equal conversations map to the same cache entry.

    Args:
        messages: Dictionaries representing the conversation messages.
        model (str, optional): The model name or identifier. Defaults to "docgpt".

    Raises:
        ValueError: If any element of `messages` is not a dictionary.

    Returns:
        str: A unique cache key produced by hashing the combined model name
        and serialized messages.
    """
    if not all(isinstance(msg, dict) for msg in messages):
        raise ValueError("All messages must be dictionaries.")

    # sort_keys makes the serialization order-independent w.r.t. dict keys,
    # so logically-equal messages always hash to the same key.
    messages_str = json.dumps(list(messages), sort_keys=True)
    combined = f"{model}_{messages_str}"
    cache_key = get_hash(combined)
    return cache_key
||||
def gen_cache(func):
    """
    Decorator to cache the response of a function that generates a response using an LLM.

    Checks Redis for a cached response keyed on the conversation messages.
    On a hit the cached text is returned; on a miss the wrapped function is
    called and its result is cached for 30 minutes. Redis connection errors
    are logged and treated as cache misses — the LLM call still happens.

    Args:
        func (function): The function to be decorated.

    Returns:
        function: The wrapped function that handles caching and LLM response generation.
    """

    def wrapper(self, model, messages, *args, **kwargs):
        try:
            cache_key = gen_cache_key(*messages)

            # redis_client may be None if the connection could not be made;
            # in that case we skip caching entirely and just call the LLM.
            redis_client = get_redis_instance()
            if redis_client:
                try:
                    cached_response = redis_client.get(cache_key)
                    if cached_response:
                        return cached_response.decode('utf-8')
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")

            result = func(self, model, messages, *args, **kwargs)

            if redis_client:
                try:
                    # Expire the cache after 30 minutes (ex is in seconds).
                    redis_client.set(cache_key, result, ex=1800)
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")

            return result
        except ValueError as e:
            # Raised by gen_cache_key on malformed messages.
            logger.error(e)
            return "Error: No user message found in the conversation to generate a cache key."

    return wrapper
||||
def stream_cache(func):
    """
    Decorator to cache the streamed response of an LLM function.

    On a cache hit the stored chunk list is replayed (with a short sleep per
    chunk to simulate streaming, since Redis returns instantly). On a miss
    the wrapped generator is streamed through, its chunks collected, and the
    whole list cached for 30 minutes. Redis connection errors are logged and
    treated as cache misses.

    Args:
        func (function): The function to be decorated.

    Returns:
        function: The wrapped generator that handles caching and streaming LLM responses.
    """

    def wrapper(self, model, messages, stream, *args, **kwargs):
        try:
            cache_key = gen_cache_key(*messages)
        except ValueError as e:
            # Malformed messages: surface the error as the only chunk.
            logger.error(e)
            yield "Error: No user message found in the conversation to generate a cache key."
            return
        logger.info(f"Stream cache key: {cache_key}")

        # redis_client may be None if the connection could not be made.
        redis_client = get_redis_instance()
        if redis_client:
            try:
                cached_response = redis_client.get(cache_key)
                if cached_response:
                    logger.info(f"Cache hit for stream key: {cache_key}")
                    cached_response = json.loads(cached_response.decode('utf-8'))
                    for chunk in cached_response:
                        yield chunk
                        # Slow down replay to simulate streaming — the cached
                        # response would otherwise arrive instantaneously.
                        time.sleep(0.03)
                    return
            except redis.ConnectionError as e:
                logger.error(f"Redis connection error: {e}")

        result = func(self, model, messages, stream, *args, **kwargs)
        stream_cache_data = []
        for chunk in result:
            stream_cache_data.append(chunk)
            yield chunk

        if redis_client:
            try:
                # Expire the cache after 30 minutes (ex is in seconds).
                redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
                logger.info(f"Stream cache saved for key: {cache_key}")
            except redis.ConnectionError as e:
                logger.error(f"Redis connection error: {e}")

    return wrapper
||||
Reference in New Issue
Block a user