Add a delay between chunks when replaying cached streaming responses

This commit is contained in:
fadingNA
2024-10-12 15:25:06 -04:00
parent d7fe1150dc
commit 3c43b87e9f

View File

@@ -1,5 +1,6 @@
import functools
import sys
import time
from datetime import datetime

import redis

from application.core.settings import settings
from application.utils import get_hash
@@ -12,23 +13,14 @@ redis_client = redis.Redis(
db=settings.REDIS_DB, db=settings.REDIS_DB,
) )
def gen_cache_key(messages, model=None):
    """
    Generate a cache key from the latest user prompt in a conversation.

    NOTE(review): keying on only the latest user prompt means two
    conversations whose final user message is identical share one cache
    entry even when their earlier history (or the model used) differs.
    Pass ``model`` to disambiguate per-model caches; existing callers
    that omit it keep the previous keying behavior unchanged.

    Args:
        messages (list[dict]): Conversation messages; each dict is
            expected to carry 'role' and 'content' keys.
        model (str, optional): When given, mixed into the hashed
            material so identical prompts do not collide across models.

    Returns:
        str: Hash of the latest user prompt (and model, when given).

    Raises:
        ValueError: If no message with role 'user' is present.
    """
    latest_user_prompt = next(
        (msg['content'] for msg in reversed(messages) if msg['role'] == 'user'),
        None,
    )
    if latest_user_prompt is None:
        raise ValueError("No user message found in the conversation to generate a cache key.")
    # Backward compatible: with model=None the key material is exactly
    # the latest user prompt, as before.
    key_material = latest_user_prompt if model is None else f"{model}_{latest_user_prompt}"
    cache_key = get_hash(key_material)
    return cache_key
def gen_cache(func):
    """
    Decorator that caches full (non-streamed) LLM responses in Redis.

    The latest user prompt is hashed into a cache key; a hit returns the
    cached text, a miss calls ``func``, stores the result with a 1-hour
    TTL, and returns it.

    Args:
        func (function): LLM call with signature
            (self, model, messages, *args, **kwargs) returning the
            response text.

    Returns:
        function: The wrapped function with caching applied.
    """
    @functools.wraps(func)  # keep func's __name__/__doc__ for debugging/introspection
    def wrapper(self, model, messages, *args, **kwargs):
        cache_key = gen_cache_key(messages=messages)

        cached_response = redis_client.get(cache_key)
        if cached_response:
            print(f"Cache hit for key: {cache_key}")
            # Redis returns bytes; decode back to str for callers.
            return cached_response.decode('utf-8')

        result = func(self, model, messages, *args, **kwargs)
        # NOTE(review): assumes result is a str/bytes value Redis can
        # store directly — confirm against the wrapped implementations.
        redis_client.set(cache_key, result, ex=3600)  # expire after 1 hour
        print(f"Cache saved for key: {cache_key}")
        return result
    return wrapper
def stream_cache(func):
    """
    Decorator that caches a streamed LLM response chunk-by-chunk in Redis.

    Chunks are appended to a Redis list (``rpush``) as they are yielded
    and replayed with ``lrange`` on a hit, so a cached answer can still
    be streamed to the client. A short sleep separates replayed chunks
    because replay from Redis' in-memory store would otherwise be
    instantaneous.

    Args:
        func (function): Generator-style LLM call with signature
            (self, model, messages, *args, **kwargs) yielding chunks.

    Returns:
        function: The wrapped generator with caching applied.
    """
    @functools.wraps(func)  # keep func's __name__/__doc__ for debugging/introspection
    def wrapper(self, model, messages, *args, **kwargs):
        # Namespace stream keys so they cannot collide with gen_cache's
        # string-valued keys for the same prompt: lrange on a key that
        # was written with SET raises WRONGTYPE.
        cache_key = f"stream:{gen_cache_key(messages=messages)}"

        cached_chunks = redis_client.lrange(cache_key, 0, -1)
        if cached_chunks:
            print(f"Cache hit for stream key: {cache_key}")
            for chunk in cached_chunks:
                print(f"Streaming cached chunk: {chunk.decode('utf-8')}")
                yield chunk.decode('utf-8')
                # Slow down the replay to simulate live streaming; the
                # cached response would otherwise arrive all at once.
                time.sleep(0.07)
            return

        result = func(self, model, messages, *args, **kwargs)
        try:
            for chunk in result:
                print(f"Streaming live chunk: {chunk}")
                # NOTE(review): assumes each chunk is a str/bytes value
                # Redis can store — confirm against the wrapped LLMs.
                redis_client.rpush(cache_key, chunk)
                yield chunk
        except BaseException:
            # Don't leave a half-written stream behind: a partial list
            # would replay as if it were a complete response and, since
            # expire() below never ran, it would have no TTL at all.
            # BaseException also covers GeneratorExit (consumer stopped
            # iterating early).
            redis_client.delete(cache_key)
            raise
        # Expire the cached stream after 30 minutes.
        redis_client.expire(cache_key, 1800)
        print(f"Stream cache saved for key: {cache_key}")
    return wrapper