def gen_cache_key(messages, model=""):
    """
    Generate a deterministic cache key from the latest user prompt.

    Args:
        messages (list[dict]): Conversation messages; each dict is expected
            to carry 'role' and 'content' keys. Entries missing either key
            are skipped rather than raising KeyError.
        model (str, optional): Model identifier. When supplied it is mixed
            into the key so the same prompt sent to different models does
            not collide on a single cache entry. Defaults to "" so existing
            callers that pass only *messages* are unaffected.

    Returns:
        str: Hash (via get_hash) of the latest user prompt, prefixed with
        the model name when one is given.

    Raises:
        ValueError: If *messages* contains no user message with content.
    """
    latest_user_prompt = next(
        (msg.get("content") for msg in reversed(messages) if msg.get("role") == "user"),
        None,
    )
    if latest_user_prompt is None:
        raise ValueError("No user message found in the conversation to generate a cache key.")
    # Prefix with the model (when provided) so responses generated by
    # different models are cached under distinct keys.
    key_base = f"{model}_{latest_user_prompt}" if model else latest_user_prompt
    return get_hash(key_base)
def gen_cache(func):
    """
    Decorator that caches complete (non-streamed) LLM responses in Redis.

    Looks up a cached response keyed on the latest user prompt; on a hit the
    cached string is returned without invoking the wrapped function. On a
    miss the wrapped function runs and its result is cached for one hour.
    All Redis access is best-effort: if the cache is unreachable, the
    request falls through to the live LLM call instead of failing.

    Args:
        func (function): LLM call with signature
            (self, model, messages, *args, **kwargs) returning response text.

    Returns:
        function: Wrapped callable with the same signature as *func*.
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(func)
    def wrapper(self, model, messages, *args, **kwargs):
        cache_key = gen_cache_key(messages=messages)

        # Best-effort read: a Redis outage must degrade to a cache miss,
        # not break the user's request.
        try:
            cached_response = redis_client.get(cache_key)
        except redis.RedisError as err:
            print(f"Cache read failed for key {cache_key}: {err}")
            cached_response = None

        if cached_response:
            print(f"Cache hit for key: {cache_key}")
            return cached_response.decode('utf-8')  # Redis stores bytes

        result = func(self, model, messages, *args, **kwargs)

        # Best-effort write; entry expires after 1 hour (3600 s).
        try:
            redis_client.set(cache_key, result, ex=3600)
            print(f"Cache saved for key: {cache_key}")
        except redis.RedisError as err:
            print(f"Cache write failed for key {cache_key}: {err}")
        return result

    return wrapper
""" def wrapper(self, model, messages, *args, **kwargs): - # Generate a cache key based on the model and message contents - cache_key = gen_cache_key(model, *[msg['content'] for msg in messages]) + cache_key = gen_cache_key(messages=messages) - # Check for cached streamed response - cached_response = redis_client.get(cache_key) + # we are using lrange and rpush to simulate streaming + cached_response = redis_client.lrange(cache_key, 0, -1) if cached_response: print(f"Cache hit for stream key: {cache_key}") - # Yield the cached response in chunks (split by a delimiter if necessary) - yield cached_response.decode('utf-8') + for chunk in cached_response: + print(f"Streaming cached chunk: {chunk.decode('utf-8')}") + yield chunk.decode('utf-8') + # need to slow down the response to simulate streaming + # because the cached response is instantaneous + # and redis is using in-memory storage + time.sleep(0.07) return - # No cached response, proceed with streaming the response - batch = [] result = func(self, model, messages, *args, **kwargs) + for chunk in result: - batch.append(chunk) - yield chunk # Yield each chunk of the response to the caller - - # After streaming is complete, save the full response to the cache - full_response = ''.join(batch) # Join chunks into a full response - redis_client.set(cache_key, full_response, ex=3600) + print(f"Streaming live chunk: {chunk}") + redis_client.rpush(cache_key, chunk) + yield chunk + + # expire the cache after 30 minutes + redis_client.expire(cache_key, 1800) print(f"Stream cache saved for key: {cache_key}") - return wrapper \ No newline at end of file + + return wrapper +