import time

import redis

from application.core.settings import settings
from application.utils import get_hash


def make_redis():
    """
    Build a Redis client from the application settings.

    A fresh client is created per call so each caller gets its own
    connection handle.

    Returns:
        redis.Redis: A Redis client instance for the configured host,
        port, and database.
    """
    return redis.Redis(
        host=settings.REDIS_HOST,
        port=settings.REDIS_PORT,
        db=settings.REDIS_DB,
    )


def gen_cache_key(messages, model="docgpt"):
    """
    Generate a unique cache key from the latest user message and the model.

    The content of the most recent message with role 'user' is combined
    with the model name and hashed, so different models never share a
    cache entry for the same prompt.

    Args:
        messages (list): Conversation messages. Each dict should contain
            at least a 'content' field and a 'role' field.
        model (str, optional): Model name or identifier. Defaults to "docgpt".

    Raises:
        ValueError: If `messages` is not a list.
        ValueError: If no messages are provided.
        ValueError: If no user message is found in the conversation.

    Returns:
        str: A cache key produced by hashing the model name combined with
        the latest user message.
    """
    # Type-check first so a falsy non-list (e.g. {}) gets the accurate error.
    if not isinstance(messages, list):
        raise ValueError("Messages must be a list of dictionaries.")
    if not messages:
        raise ValueError("No messages found in the conversation to generate a cache key.")

    latest_user_prompt = next(
        (msg['content'] for msg in reversed(messages) if msg.get('role') == 'user'),
        None,
    )
    if latest_user_prompt is None:
        raise ValueError("No user message found in the conversation to generate a cache key.")

    combined = f"{model}_{latest_user_prompt}"
    return get_hash(combined)


def gen_cache(func):
    """
    Decorator that caches non-streaming LLM responses in Redis.

    On a cache hit the stored response is returned directly; on a miss the
    wrapped function is called and its result is cached for one hour.

    Args:
        func (function): The wrapped callable, invoked as
            func(self, model, messages, *args, **kwargs).

    Returns:
        function: The wrapped function that handles caching and LLM
        response generation.
    """
    def wrapper(self, model, messages, *args, **kwargs):
        # Keep the try minimal: only key generation raises ValueError here.
        # A broader try would swallow unrelated ValueErrors from func().
        try:
            # Pass the model through so different models get distinct keys.
            cache_key = gen_cache_key(messages=messages, model=model)
        except ValueError as e:
            print(e)
            return "Error: No user message found in the conversation to generate a cache key."

        redis_client = make_redis()
        cached_response = redis_client.get(cache_key)
        if cached_response:
            print(f"Cache hit for key: {cache_key}")
            return cached_response.decode('utf-8')

        result = func(self, model, messages, *args, **kwargs)
        redis_client.set(cache_key, result, ex=3600)
        print(f"Cache saved for key: {cache_key}")
        return result

    return wrapper


def stream_cache(func):
    """
    Decorator that caches streaming LLM responses in Redis.

    Chunks are stored in a Redis list (rpush) and replayed with lrange to
    simulate streaming on a cache hit; on a miss, live chunks are cached
    as they are yielded. The cache entry expires after 30 minutes.

    Args:
        func (function): The wrapped generator, invoked as
            func(self, model, messages, stream, *args, **kwargs).

    Returns:
        function: The wrapped generator that handles caching and streaming
        LLM responses.
    """
    def wrapper(self, model, messages, stream, *args, **kwargs):
        # Key generation must be inside the try so the ValueError handler
        # below can actually catch it.
        try:
            cache_key = gen_cache_key(messages=messages, model=model)
        except ValueError as e:
            print(e)
            yield "Error: No user message found in the conversation to generate a cache key."
            return

        redis_client = make_redis()
        # lrange/rpush on a Redis list simulate a chunked stream.
        cached_response = redis_client.lrange(cache_key, 0, -1)
        if cached_response:
            for chunk in cached_response:
                yield chunk.decode('utf-8')
                # Slow the replay down: Redis serves from memory, so the
                # cached response would otherwise arrive instantaneously.
                time.sleep(0.07)
            return

        result = func(self, model, messages, stream, *args, **kwargs)
        for chunk in result:
            redis_client.rpush(cache_key, chunk)
            yield chunk

        # Expire the stream cache after 30 minutes.
        redis_client.expire(cache_key, 1800)
        print(f"Stream cache saved for key: {cache_key}")

    return wrapper