diff --git a/application/cache.py b/application/cache.py
index 1cc7b3f0..33022e45 100644
--- a/application/cache.py
+++ b/application/cache.py
@@ -1,131 +1,93 @@
 import redis
 import time
 import json
+import logging
+from threading import Lock
 from application.core.settings import settings
 from application.utils import get_hash
 
+logger = logging.getLogger(__name__)
 
-def make_redis():
-    """
-    Initialize a Redis client using the settings provided in the application settings.
+_redis_instance = None
+_instance_lock = Lock()
 
-    Returns:
-        redis.Redis: A Redis client instance.
-    """
-    return redis.Redis(
-        host=settings.REDIS_HOST,
-        port=settings.REDIS_PORT,
-        db=settings.REDIS_DB,
-    )
+
+def get_redis_instance():
+    global _redis_instance
+    if _redis_instance is None:
+        with _instance_lock:
+            if _redis_instance is None:
+                try:
+                    _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
+                except redis.ConnectionError as e:
+                    logger.error(f"Redis connection error: {e}")
+                    _redis_instance = None
+    return _redis_instance
 
 
 def gen_cache_key(*messages, model="docgpt"):
-    """
-    Generate a unique cache key based on the latest user message and model.
-
-    This function extracts the content of the latest user message from the `messages`
-    list and combines it with the model name to generate a unique cache key using a hash function.
-    This key can be used for caching responses in the system.
-
-    Args:
-        messages (list): A list of dictionaries representing the conversation messages.
-            Each dictionary should contain at least a 'content' field and a 'role' field.
-        model (str, optional): The model name or identifier. Defaults to "docgpt".
-
-    Raises:
-        ValueError: If no messages are provided.
-        ValueError: If `messages` is not a list.
-        ValueError: If no user message is found in the conversation.
-
-    Returns:
-        str: A unique cache key generated by hashing the combined model name and latest user message.
-    """
     if not all(isinstance(msg, dict) for msg in messages):
         raise ValueError("All messages must be dictionaries.")
-
     messages_str = json.dumps(list(messages), sort_keys=True)
     combined = f"{model}_{messages_str}"
     cache_key = get_hash(combined)
     return cache_key
 
 
 def gen_cache(func):
-    """
-    Decorator to cache the response of a function that generates a response using an LLM.
-
-    This decorator first checks if a response is cached for the given input (model and messages).
-    If a cached response is found, it returns that. If not, it generates the response,
-    caches it, and returns the generated response.
-    Args:
-        func (function): The function to be decorated.
-    Returns:
-        function: The wrapped function that handles caching and LLM response generation.
-    """
-
     def wrapper(self, model, messages, *args, **kwargs):
         try:
             cache_key = gen_cache_key(*messages)
-            redis_client = make_redis()
-            cached_response = redis_client.get(cache_key)
+            redis_client = get_redis_instance()
+            if redis_client:
+                try:
+                    cached_response = redis_client.get(cache_key)
+                    if cached_response:
+                        return cached_response.decode('utf-8')
+                except redis.ConnectionError as e:
+                    logger.error(f"Redis connection error: {e}")
 
-            if cached_response:
-                return cached_response.decode('utf-8')
-
             result = func(self, model, messages, *args, **kwargs)
-
-            # expire the cache after 30 minutes
-            # set time in seconds
-            redis_client.set(cache_key, result, ex=1800)
+            if redis_client:
+                try:
+                    redis_client.set(cache_key, result, ex=1800)
+                except redis.ConnectionError as e:
+                    logger.error(f"Redis connection error: {e}")
 
             return result
         except ValueError as e:
-            print(e)
+            logger.error(e)
             return "Error: No user message found in the conversation to generate a cache key."
 
     return wrapper
 
 
 def stream_cache(func):
-    """
-    Decorator to cache the streamed response of an LLM function.
-
-    This decorator first checks if a streamed response is cached for the given input (model and messages).
-    If a cached response is found, it yields that. If not, it streams the response, caches it,
-    and then yields the response.
-
-    Args:
-        func (function): The function to be decorated.
-
-    Returns:
-        function: The wrapped function that handles caching and streaming LLM responses.
-        (self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs
-    """
     def wrapper(self, model, messages, stream, *args, **kwargs):
         cache_key = gen_cache_key(*messages)
+        logger.info(f"Stream cache key: {cache_key}")
+
+        redis_client = get_redis_instance()
+        if redis_client:
+            try:
+                cached_response = redis_client.get(cache_key)
+                if cached_response:
+                    logger.info(f"Cache hit for stream key: {cache_key}")
+                    cached_response = json.loads(cached_response.decode('utf-8'))
+                    for chunk in cached_response:
+                        yield chunk
+                        time.sleep(0.03)
+                    return
+            except redis.ConnectionError as e:
+                logger.error(f"Redis connection error: {e}")
 
-        try:
-            # we are using lrange and rpush to simulate streaming
-            redis_client = make_redis()
-            cached_response = redis_client.get(cache_key)
-            if cached_response:
-                print(f"Cache hit for stream key: {cache_key}")
-                cached_response = json.loads(cached_response.decode('utf-8'))
-                for chunk in cached_response:
-                    yield chunk
-                    # need to slow down the response to simulate streaming
-                    # because the cached response is instantaneous
-                    # and redis is using in-memory storage
-                    time.sleep(0.07)
-                return
-
-            result = func(self, model, messages, stream, *args, **kwargs)
-            stream_cache_data = []
-
-            for chunk in result:
-                stream_cache_data.append(chunk)
-                yield chunk
-
-            # expire the cache after 30 minutes
-            redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
-            print(f"Stream cache saved for key: {cache_key}")
-        except ValueError as e:
-            print(e)
-            yield "Error: No user message found in the conversation to generate a cache key."
+        result = func(self, model, messages, stream, *args, **kwargs)
+        stream_cache_data = []
+
+        for chunk in result:
+            stream_cache_data.append(chunk)
+            yield chunk
+
+        if redis_client:
+            try:
+                redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
+                logger.info(f"Stream cache saved for key: {cache_key}")
+            except redis.ConnectionError as e:
+                logger.error(f"Redis connection error: {e}")
 
     return wrapper
\ No newline at end of file
diff --git a/application/core/settings.py b/application/core/settings.py
index 0b81ac44..7346da08 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -22,9 +22,7 @@ class Settings(BaseSettings):
     RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search
 
     # LLM Cache
-    REDIS_HOST: str = "localhost"
-    REDIS_PORT: int = 6379
-    REDIS_DB: int = 0
+    CACHE_REDIS_URL: str = "redis://localhost:6379/2"
 
     API_URL: str = "http://localhost:7091"  # backend url for celery worker
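
Usage sketch (review note, not part of the patch): the removed stream_cache
docstring referenced self._raw_gen, so these decorators are intended to wrap
an LLM's raw generation methods. The ExampleLLM class and its method bodies
below are illustrative assumptions, not code from this repository.

    from application.cache import gen_cache, stream_cache

    class ExampleLLM:
        @gen_cache
        def gen(self, model, messages, *args, **kwargs):
            # Hypothetical stand-in for a real LLM call; on a cache hit the
            # decorator returns the Redis value and this body never runs.
            return "generated response"

        @stream_cache
        def gen_stream(self, model, messages, stream, *args, **kwargs):
            # Hypothetical stand-in for a real token stream; on a cache hit
            # the decorator replays the cached chunk list, sleeping 0.03s
            # between chunks to simulate streaming.
            yield "chunk-1"
            yield "chunk-2"

    llm = ExampleLLM()
    messages = [{"role": "user", "content": "Hi"}]
    llm.gen("docgpt", messages)                     # miss: generates, caches with ex=1800
    llm.gen("docgpt", messages)                     # hit: served from Redis
    list(llm.gen_stream("docgpt", messages, True))  # consume the generator to populate the cache

One behavioural note: redis.Redis.from_url builds the client without opening a
connection, so the try/except in get_redis_instance is unlikely to fire in
practice; connection failures surface on the first get/set call, which is why
both wrappers also guard those operations.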