mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-21 21:05:05 +00:00
Without health_check_interval, a half-open TCP socket (e.g. NAT state silently dropped, or an ELB idle-close) can leave pubsub.get_message hanging past the SSE generator's keepalive cadence — the kernel never surfaces the dead socket because no payload is in flight. Setting health_check_interval=10 makes redis-py ping the connection once it has been idle for 10 s, so the next get_message after the dead window raises and the SSE loop falls into its reconnect path instead of silently freezing for the user.
136 lines
5.2 KiB
Python
136 lines
5.2 KiB
Python
import hashlib
|
|
import json
|
|
import logging
|
|
import time
|
|
from threading import Lock
|
|
|
|
import redis
|
|
|
|
from application.core.settings import settings
|
|
from application.utils import get_hash
|
|
|
|
# Module-level logger; the Redis client factory and the caching decorators in
# this module report Redis and cache-key failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _cache_default(value):
    """Fallback serializer handed to ``json.dumps`` when building cache keys.

    Binary payloads (``bytes``, ``bytearray``, ``memoryview``) are replaced by a
    SHA-256 tag, which keeps the key small and makes identical content map to
    the same key. Any other object ``json`` cannot serialize is rendered with
    ``repr``.
    """
    # Image attachments arrive inline as bytes (see
    # GoogleLLM.prepare_messages_with_attachments), so hashing them matters.
    if not isinstance(value, (bytes, bytearray, memoryview)):
        return repr(value)
    digest = hashlib.sha256(bytes(value)).hexdigest()
    return f"<bytes:sha256:{digest}>"
|
|
|
|
# Shared Redis client handed out by get_redis_instance(); it stays None until
# creation succeeds. _redis_creation_failed flips to True after an invalid
# Redis URL so get_redis_instance() stops attempting creation.
_redis_instance, _redis_creation_failed = None, False
# Serializes client creation inside get_redis_instance() across threads.
_instance_lock = Lock()
|
|
|
|
def get_redis_instance():
    """Return the shared Redis client, creating it on first use.

    Creation is attempted while holding ``_instance_lock``. The unlocked check
    is only a fast path and is repeated once the lock is held. Returns ``None``
    when no client is available:

    * an invalid ``CACHE_REDIS_URL`` (``ValueError``) is logged and disables
      every later creation attempt;
    * a ``redis.ConnectionError`` is logged and a later call tries again.

    NOTE(review): redis-py usually opens connections lazily, so the
    ``ConnectionError`` branch may rarely fire here — confirm.
    """
    global _redis_instance, _redis_creation_failed

    # Fast path: a client already exists, or creation was permanently disabled.
    if _redis_instance is not None or _redis_creation_failed:
        return _redis_instance

    with _instance_lock:
        # Another thread may have finished (or given up on) creation while we
        # were waiting for the lock.
        if _redis_instance is not None or _redis_creation_failed:
            return _redis_instance

        client = None
        try:
            # ``health_check_interval`` makes redis-py ping the connection
            # every N seconds when otherwise idle. Without it, a half-open TCP
            # (NAT silently dropped state, ELB idle-close) can hang the SSE
            # generator in ``pubsub.get_message`` past its keepalive cadence —
            # the kernel never surfaces the dead socket because no payload is
            # in flight.
            client = redis.Redis.from_url(
                settings.CACHE_REDIS_URL,
                socket_connect_timeout=2,
                health_check_interval=10,
            )
        except ValueError as err:
            logger.error(f"Invalid Redis URL: {err}")
            _redis_creation_failed = True  # A malformed URL is never retried.
        except redis.ConnectionError as err:
            # Leave the failure flag untouched so a later call retries.
            logger.error(f"Redis connection error: {err}")
        _redis_instance = client

    return _redis_instance
|
|
|
|
|
|
def gen_cache_key(messages, model="docgpt", tools=None):
    """Build the Redis cache key for an LLM request.

    Args:
        messages: iterable of message dicts; they are JSON-encoded with
            ``_cache_default`` as the fallback serializer.
        model: model name, placed first in the hashed string.
        tools: optional; when truthy, ``str(tools)`` is JSON-encoded into the
            hashed string, otherwise an empty string is used.

    Returns:
        ``get_hash`` of ``"{model}_{messages_json}_{tools_json}"``.

    Raises:
        ValueError: if any element of ``messages`` is not a dict.

    NOTE(review): ``json.dumps`` runs without ``sort_keys``, so message dicts
    with the same content but a different key order yield different keys.
    """
    if any(not isinstance(message, dict) for message in messages):
        raise ValueError("All messages must be dictionaries.")

    # Keep the original evaluation order: messages are serialized before tools.
    serialized_messages = json.dumps(messages, default=_cache_default)
    serialized_tools = json.dumps(str(tools)) if tools else ""
    return get_hash(f"{model}_{serialized_messages}_{serialized_tools}")
|
|
|
|
|
|
def gen_cache(func):
    """Cache the complete (non-streaming) result of an LLM call in Redis.

    Wraps a method with signature ``(self, model, messages, stream,
    tools=None, *args, **kwargs)``:

    * calls that pass ``tools`` skip the cache entirely;
    * if ``gen_cache_key`` raises ``ValueError`` the error is logged and the
      wrapped method runs uncached;
    * on a cache hit the stored value is returned decoded as UTF-8;
    * on a miss the wrapped method runs and a ``str`` result is stored for
      1800 seconds.

    Errors while reading or writing the cache are logged and do not prevent
    the wrapped method's result from being returned.
    """
    from functools import wraps

    # ``wraps`` keeps the decorated method's __name__, __doc__ and __wrapped__
    # instead of exposing them as those of a generic "wrapper" function.
    @wraps(func)
    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
        if tools is not None:
            return func(self, model, messages, stream, tools, *args, **kwargs)

        try:
            cache_key = gen_cache_key(messages, model, tools)
        except ValueError as e:
            logger.error(f"Cache key generation failed: {e}")
            return func(self, model, messages, stream, tools, *args, **kwargs)

        redis_client = get_redis_instance()
        if redis_client:
            try:
                cached_response = redis_client.get(cache_key)
                if cached_response:
                    return cached_response.decode("utf-8")
            except Exception as e:
                # Best effort: a failed cache read falls through to a live call.
                logger.error(f"Error getting cached response: {e}", exc_info=True)

        result = func(self, model, messages, stream, tools, *args, **kwargs)
        # Only plain strings are stored; other result types are returned uncached.
        if redis_client and isinstance(result, str):
            try:
                redis_client.set(cache_key, result, ex=1800)
            except Exception as e:
                logger.error(f"Error setting cache: {e}", exc_info=True)

        return result

    return wrapper
|
|
|
|
|
|
def stream_cache(func):
    """Cache the chunks of a streaming LLM call in Redis and replay them on a hit.

    Wraps a generator method with signature ``(self, model, messages, stream,
    tools=None, *args, **kwargs)``:

    * calls that pass ``tools`` stream straight from the wrapped generator;
    * if ``gen_cache_key`` raises ``ValueError`` the error is logged and the
      wrapped generator streams uncached;
    * on a cache hit the stored chunks are yielded one by one, sleeping 0.03 s
      after each to mimic live streaming;
    * on a miss every chunk is forwarded immediately. Only if the wrapped
      generator runs to completion is the list of ``str(chunk)`` values
      stored as JSON for 1800 seconds.

    Errors while reading or writing the cache are logged and never stop the
    live stream.
    """
    from functools import wraps

    # ``wraps`` keeps the decorated method's __name__, __doc__ and __wrapped__.
    @wraps(func)
    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
        if tools is not None:
            yield from func(self, model, messages, stream, tools, *args, **kwargs)
            return

        try:
            cache_key = gen_cache_key(messages, model, tools)
        except ValueError as e:
            logger.error(f"Cache key generation failed: {e}")
            yield from func(self, model, messages, stream, tools, *args, **kwargs)
            return

        redis_client = get_redis_instance()

        # Only the Redis read and decode are guarded. The replay below runs
        # outside the try, so an exception thrown into this generator
        # mid-replay propagates instead of being swallowed and followed by a
        # second, live copy of the response.
        cached_chunks = None
        if redis_client:
            try:
                cached_response = redis_client.get(cache_key)
                if cached_response:
                    logger.info(f"Cache hit for stream key: {cache_key}")
                    decoded = json.loads(cached_response.decode("utf-8"))
                    if not isinstance(decoded, list):
                        # This decorator only ever stores a JSON list.
                        raise ValueError("cached stream payload is not a JSON list")
                    cached_chunks = decoded
            except Exception as e:
                logger.error(f"Error getting cached stream: {e}", exc_info=True)

        if cached_chunks is not None:
            for chunk in cached_chunks:
                yield chunk
                time.sleep(0.03)  # Simulate streaming delay
            return

        stream_cache_data = []
        for chunk in func(self, model, messages, stream, tools, *args, **kwargs):
            yield chunk
            stream_cache_data.append(str(chunk))

        # Reached only when the wrapped generator finished. An early close by
        # the consumer, or an exception, leaves the cache untouched.
        if redis_client:
            try:
                redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
                logger.info(f"Stream cache saved for key: {cache_key}")
            except Exception as e:
                logger.error(f"Error setting stream cache: {e}", exc_info=True)

    return wrapper
|