diff --git a/application/cache.py b/application/cache.py
new file mode 100644
index 00000000..7ea363e8
--- /dev/null
+++ b/application/cache.py
@@ -0,0 +1,105 @@
+import redis
+from application.core.settings import settings
+from application.utils import get_hash
+
+
+# Initialize Redis client
+redis_client = redis.Redis(
+    host=settings.REDIS_HOST,
+    port=settings.REDIS_PORT,
+    db=settings.REDIS_DB,
+)
+
+def gen_cache_key(model, *args):
+    """
+    Generate a unique cache key from the model name and input arguments.
+
+    Args:
+        model (str): The name or identifier of the LLM model being used.
+        *args: Additional arguments that should contribute to the uniqueness of the cache key.
+
+    Returns:
+        str: A unique cache key generated by hashing the model name and arguments.
+    """
+    # Combine the model name and args into a single string to ensure uniqueness
+    key_base = f"{model}_" + "_".join([str(arg) for arg in args])
+
+    # Use the get_hash utility to hash the key for consistent length and uniqueness
+    cache_key = get_hash(key_base)
+
+    return cache_key
+
+
+def gen_cache(func):
+    """
+    Decorator to cache the full response of an LLM generation function.
+
+    This decorator first checks if a response is cached for the given input (model and messages).
+    If a cached response is found, it returns that. If not, it generates the response,
+    caches it, and returns the generated response.
+
+    Args:
+        func (function): The function to be decorated.
+
+    Returns:
+        function: The wrapped function that handles caching and LLM response generation.
+    """
+    def wrapper(self, model, messages, *args, **kwargs):
+        # Generate a cache key based on the model and message contents
+        cache_key = gen_cache_key(model, *[msg['content'] for msg in messages])
+
+        # Check for cached response
+        cached_response = redis_client.get(cache_key)
+        if cached_response:
+            print(f"Cache hit for key: {cache_key}")
+            return cached_response.decode('utf-8')  # Redis stores bytes, so decode to string
+
+        # No cached response, generate the LLM result
+        result = func(self, model, messages, *args, **kwargs)
+
+        # Cache the result for future use (expires in 3600 seconds = 1 hour)
+        redis_client.set(cache_key, result, ex=3600)
+        print(f"Cache saved for key: {cache_key}")
+
+        return result
+    return wrapper
+
+
+def stream_cache(func):
+    """
+    Decorator to cache the streamed response of an LLM function.
+
+    This decorator first checks if a streamed response is cached for the given input (model and messages).
+    If a cached response is found, it yields that. If not, it streams the response, yielding each chunk
+    to the caller, and caches the full response once streaming completes.
+
+    Args:
+        func (function): The function to be decorated.
+
+    Returns:
+        function: The wrapped function that handles caching and streaming LLM responses.
+ """ + def wrapper(self, model, messages, *args, **kwargs): + # Generate a cache key based on the model and message contents + cache_key = gen_cache_key(model, *[msg['content'] for msg in messages]) + + # Check for cached streamed response + cached_response = redis_client.get(cache_key) + if cached_response: + print(f"Cache hit for stream key: {cache_key}") + # Yield the cached response in chunks (split by a delimiter if necessary) + yield cached_response.decode('utf-8') + return + + # No cached response, proceed with streaming the response + batch = [] + result = func(self, model, messages, *args, **kwargs) + for chunk in result: + batch.append(chunk) + yield chunk # Yield each chunk of the response to the caller + + # After streaming is complete, save the full response to the cache + full_response = ''.join(batch) # Join chunks into a full response + redis_client.set(cache_key, full_response, ex=3600) + print(f"Stream cache saved for key: {cache_key}") + return wrapper \ No newline at end of file