Merge pull request #1308 from fadingNA/caching-docsgpt

Caching docsgpt
Authored by Alex, committed via GitHub on 2024-10-15 11:56:30 +01:00
10 changed files with 310 additions and 50 deletions


@@ -292,6 +292,7 @@ class Stream(Resource):
    def post(self):
        data = request.get_json()
        required_fields = ["question"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields

@@ -422,7 +423,7 @@ class Answer(Resource):
    @api.doc(description="Provide an answer based on the question and retriever")
    def post(self):
        data = request.get_json()
        required_fields = ["question"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
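For context: check_required_fields (its tail is visible in the application/utils.py hunk further down) returns a ready-made 400 response when fields are missing and None otherwise, which is why both handlers can return its result directly. The sketch below is a rough illustration of that contract, not the repository's exact implementation; the error payload shape is assumed.

from flask import jsonify, make_response

def check_required_fields(data, required_fields):
    # Hypothetical reconstruction for illustration only; the real helper lives
    # in application/utils.py and may word its error payload differently.
    missing_fields = [field for field in required_fields if field not in data]
    if missing_fields:
        return make_response(
            jsonify({"success": False, "message": f"Missing fields: {missing_fields}"}),
            400,
        )
    return None

Inside the API handlers this runs during a request, so jsonify and make_response already have the Flask context they need.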

application/cache.py (new file, 93 lines added)

@@ -0,0 +1,93 @@
import redis
import time
import json
import logging
from threading import Lock
from application.core.settings import settings
from application.utils import get_hash

logger = logging.getLogger(__name__)

_redis_instance = None
_instance_lock = Lock()


def get_redis_instance():
    global _redis_instance
    if _redis_instance is None:
        with _instance_lock:
            if _redis_instance is None:
                try:
                    _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")
                    _redis_instance = None
    return _redis_instance


def gen_cache_key(*messages, model="docgpt"):
    if not all(isinstance(msg, dict) for msg in messages):
        raise ValueError("All messages must be dictionaries.")
    messages_str = json.dumps(list(messages), sort_keys=True)
    combined = f"{model}_{messages_str}"
    cache_key = get_hash(combined)
    return cache_key


def gen_cache(func):
    def wrapper(self, model, messages, *args, **kwargs):
        try:
            cache_key = gen_cache_key(*messages)
            redis_client = get_redis_instance()
            if redis_client:
                try:
                    cached_response = redis_client.get(cache_key)
                    if cached_response:
                        return cached_response.decode('utf-8')
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")

            result = func(self, model, messages, *args, **kwargs)
            if redis_client:
                try:
                    redis_client.set(cache_key, result, ex=1800)
                except redis.ConnectionError as e:
                    logger.error(f"Redis connection error: {e}")
            return result
        except ValueError as e:
            logger.error(e)
            return "Error: No user message found in the conversation to generate a cache key."
    return wrapper


def stream_cache(func):
    def wrapper(self, model, messages, stream, *args, **kwargs):
        cache_key = gen_cache_key(*messages)
        logger.info(f"Stream cache key: {cache_key}")
        redis_client = get_redis_instance()
        if redis_client:
            try:
                cached_response = redis_client.get(cache_key)
                if cached_response:
                    logger.info(f"Cache hit for stream key: {cache_key}")
                    cached_response = json.loads(cached_response.decode('utf-8'))
                    for chunk in cached_response:
                        yield chunk
                        time.sleep(0.03)
                    return
            except redis.ConnectionError as e:
                logger.error(f"Redis connection error: {e}")

        result = func(self, model, messages, stream, *args, **kwargs)
        stream_cache_data = []
        for chunk in result:
            stream_cache_data.append(chunk)
            yield chunk
        if redis_client:
            try:
                redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
                logger.info(f"Stream cache saved for key: {cache_key}")
            except redis.ConnectionError as e:
                logger.error(f"Redis connection error: {e}")
    return wrapper
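For orientation, the two decorators above are designed to wrap LLM generation methods: gen_cache stores the full response string in Redis for 30 minutes (ex=1800), while stream_cache replays cached chunks on a hit and otherwise records the stream as it is yielded. In the PR they are attached dynamically through BaseLLM._apply_decorator (see the base LLM hunk below); the sketch here applies them with plain decorator syntax just to show the calling contract. The FakeLLM class and its canned strings are illustrative, not part of the change, and running it assumes the DocsGPT application package (and ideally a reachable Redis at CACHE_REDIS_URL) is available.

from application.cache import gen_cache, stream_cache

class FakeLLM:
    # Illustrative stand-in for a concrete LLM wrapper; not part of this PR.
    @gen_cache
    def answer(self, model, messages):
        # On a cache miss this body runs and the returned string is cached;
        # repeat calls with the same messages are served from Redis.
        return "expensive model output"

    @stream_cache
    def answer_stream(self, model, messages, stream=True):
        # Yielded chunks are collected into a JSON list and cached afterwards.
        for chunk in ["expensive ", "model ", "output"]:
            yield chunk

llm = FakeLLM()
messages = [{"role": "user", "content": "What is DocsGPT?"}]

print(llm.answer("docgpt", messages))  # served from Redis on repeat calls
print("".join(llm.answer_stream("docgpt", messages, stream=True)))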


@@ -21,6 +21,9 @@ class Settings(BaseSettings):
    VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus"
    RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search

    # LLM Cache
    CACHE_REDIS_URL: str = "redis://localhost:6379/2"

    API_URL: str = "http://localhost:7091"  # backend url for celery worker
    API_KEY: Optional[str] = None  # LLM api key
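Because Settings extends pydantic's BaseSettings, the new CACHE_REDIS_URL default should be overridable through an environment variable (or the project's .env file) like the neighbouring settings. A minimal sketch, assuming pydantic's standard environment loading and a fresh process in which application.core.settings has not been imported yet:

import os

# Hypothetical override: point the LLM cache at a different Redis database.
os.environ["CACHE_REDIS_URL"] = "redis://localhost:6379/5"

# Import after setting the variable; the settings object is built at import time.
from application.core.settings import settings

print(settings.CACHE_REDIS_URL)  # should print the overridden URL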


@@ -1,28 +1,29 @@
 from abc import ABC, abstractmethod
 from application.usage import gen_token_usage, stream_token_usage
+from application.cache import stream_cache, gen_cache


 class BaseLLM(ABC):
     def __init__(self):
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

-    def _apply_decorator(self, method, decorator, *args, **kwargs):
-        return decorator(method, *args, **kwargs)
+    def _apply_decorator(self, method, decorators, *args, **kwargs):
+        for decorator in decorators:
+            method = decorator(method)
+        return method(self, *args, **kwargs)

     @abstractmethod
     def _raw_gen(self, model, messages, stream, *args, **kwargs):
         pass

     def gen(self, model, messages, stream=False, *args, **kwargs):
-        return self._apply_decorator(self._raw_gen, gen_token_usage)(
-            self, model=model, messages=messages, stream=stream, *args, **kwargs
-        )
+        decorators = [gen_token_usage, gen_cache]
+        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)

     @abstractmethod
     def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass

     def gen_stream(self, model, messages, stream=True, *args, **kwargs):
-        return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
-            self, model=model, messages=messages, stream=stream, *args, **kwargs
-        )
+        decorators = [stream_cache, stream_token_usage]
+        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
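The reworked _apply_decorator folds a list of decorators over the raw method, so the first entry in the list ends up innermost and the last entry outermost; with decorators = [gen_token_usage, gen_cache], the cache lookup therefore runs before any token counting, and a cache hit short-circuits the rest. A standalone sketch of that folding behaviour with stand-in decorators (log_calls, upper_case, apply_decorators and raw_gen are illustrative names, not part of the PR):

def log_calls(func):
    def wrapper(*args, **kwargs):
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)
    return wrapper

def upper_case(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper

def apply_decorators(method, decorators, *args, **kwargs):
    # Same folding idea as the new BaseLLM._apply_decorator: wrap the method
    # with each decorator in order, then call the fully wrapped method.
    for decorator in decorators:
        method = decorator(method)
    return method(*args, **kwargs)

def raw_gen(model, messages):
    # Stand-in for an LLM's raw generation method.
    return messages[-1]["content"]

answer = apply_decorators(raw_gen, [log_calls, upper_case],
                          "docgpt", [{"role": "user", "content": "hello"}])
# The call above logs "calling raw_gen" (log_calls is first in the list, so it
# wraps raw_gen directly), and upper_case, being last, post-processes the result.
print(answer)  # HELLO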


@@ -1,6 +1,8 @@
import tiktoken
import hashlib
from flask import jsonify, make_response

_encoding = None

@@ -39,3 +41,8 @@ def check_required_fields(data, required_fields):
            400,
        )
    return None


def get_hash(data):
    return hashlib.md5(data.encode()).hexdigest()
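get_hash is the helper that gen_cache_key (in application/cache.py above) uses to turn the model name plus the JSON-serialised messages into a stable Redis key, so identical conversations always map to the same key. A small sketch of that determinism, assuming the DocsGPT package is importable:

from application.cache import gen_cache_key

messages = [{"role": "user", "content": "What is DocsGPT?"}]

key_a = gen_cache_key(*messages, model="docgpt")
key_b = gen_cache_key(*messages, model="docgpt")
key_c = gen_cache_key(*messages, model="some-other-model")

print(key_a == key_b)  # True: same model and messages hash to the same MD5 key
print(key_a == key_c)  # False: the model name is part of the hashed string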


@@ -20,6 +20,7 @@ services:
      - CELERY_BROKER_URL=redis://redis:6379/0
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - CACHE_REDIS_URL=redis://redis:6379/2
    ports:
      - "7091:7091"
    volumes:
@@ -41,6 +42,7 @@ services:
      - CELERY_RESULT_BACKEND=redis://redis:6379/1
      - MONGO_URI=mongodb://mongo:27017/docsgpt
      - API_URL=http://backend:7091
      - CACHE_REDIS_URL=redis://redis:6379/2
    depends_on:
      - redis
      - mongo
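Both containers now point the LLM cache at database 2 of the compose file's existing redis service. A quick sanity check of that connection from inside either container could look like the sketch below; it only relies on the redis client the application already uses, and the fallback URL mirrors the compose default.

import os
import redis

url = os.environ.get("CACHE_REDIS_URL", "redis://redis:6379/2")
client = redis.Redis.from_url(url, socket_connect_timeout=2)

try:
    print("LLM cache reachable:", client.ping())
except redis.ConnectionError as exc:
    print("LLM cache unreachable:", exc)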


@@ -1675,7 +1675,7 @@
       "version": "18.3.0",
       "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.0.tgz",
       "integrity": "sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==",
-      "devOptional": true,
+      "dev": true,
       "dependencies": {
         "@types/react": "*"
       }


@@ -22,17 +22,23 @@ class TestAnthropicLLM(unittest.TestCase):
        mock_response = Mock()
        mock_response.completion = "test completion"

-        with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
-            response = self.llm.gen("test_model", messages)
-            self.assertEqual(response, "test completion")
+        with patch("application.cache.get_redis_instance") as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+            mock_redis_instance.set = Mock()

-            prompt_expected = "### Context \n context \n ### Question \n question"
-            mock_create.assert_called_with(
-                model="test_model",
-                max_tokens_to_sample=300,
-                stream=False,
-                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
-            )
+            with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
+                response = self.llm.gen("test_model", messages)
+                self.assertEqual(response, "test completion")
+
+                prompt_expected = "### Context \n context \n ### Question \n question"
+                mock_create.assert_called_with(
+                    model="test_model",
+                    max_tokens_to_sample=300,
+                    stream=False,
+                    prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
+                )
+            mock_redis_instance.set.assert_called_once()

    def test_gen_stream(self):
        messages = [
@@ -41,17 +47,23 @@ class TestAnthropicLLM(unittest.TestCase):
        ]
        mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]

-        with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
-            responses = list(self.llm.gen_stream("test_model", messages))
-            self.assertListEqual(responses, ["response_1", "response_2"])
+        with patch("application.cache.get_redis_instance") as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+            mock_redis_instance.set = Mock()

-            prompt_expected = "### Context \n context \n ### Question \n question"
-            mock_create.assert_called_with(
-                model="test_model",
-                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
-                max_tokens_to_sample=300,
-                stream=True
-            )
+            with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
+                responses = list(self.llm.gen_stream("test_model", messages))
+                self.assertListEqual(responses, ["response_1", "response_2"])
+
+                prompt_expected = "### Context \n context \n ### Question \n question"
+                mock_create.assert_called_with(
+                    model="test_model",
+                    prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
+                    max_tokens_to_sample=300,
+                    stream=True
+                )
+            mock_redis_instance.set.assert_called_once()


if __name__ == "__main__":
    unittest.main()


@@ -52,28 +52,38 @@ class TestSagemakerAPILLM(unittest.TestCase):
        self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)

    def test_gen(self):
-        with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
-                          return_value=self.response) as mock_invoke_endpoint:
-            output = self.sagemaker.gen(None, self.messages)
-            mock_invoke_endpoint.assert_called_once_with(
-                EndpointName=self.sagemaker.endpoint,
-                ContentType='application/json',
-                Body=self.body_bytes
-            )
-            self.assertEqual(output,
-                             self.result[0]['generated_text'][len(self.prompt):])
+        with patch('application.cache.get_redis_instance') as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+
+            with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
+                              return_value=self.response) as mock_invoke_endpoint:
+                output = self.sagemaker.gen(None, self.messages)
+                mock_invoke_endpoint.assert_called_once_with(
+                    EndpointName=self.sagemaker.endpoint,
+                    ContentType='application/json',
+                    Body=self.body_bytes
+                )
+                self.assertEqual(output,
+                                 self.result[0]['generated_text'][len(self.prompt):])
+            mock_make_redis.assert_called_once()
+            mock_redis_instance.set.assert_called_once()

    def test_gen_stream(self):
-        with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
-                          return_value=self.response) as mock_invoke_endpoint:
-            output = list(self.sagemaker.gen_stream(None, self.messages))
-            mock_invoke_endpoint.assert_called_once_with(
-                EndpointName=self.sagemaker.endpoint,
-                ContentType='application/json',
-                Body=self.body_bytes_stream
-            )
-            self.assertEqual(output, [])
+        with patch('application.cache.get_redis_instance') as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+
+            with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
+                              return_value=self.response) as mock_invoke_endpoint:
+                output = list(self.sagemaker.gen_stream(None, self.messages))
+                mock_invoke_endpoint.assert_called_once_with(
+                    EndpointName=self.sagemaker.endpoint,
+                    ContentType='application/json',
+                    Body=self.body_bytes_stream
+                )
+                self.assertEqual(output, [])
+            mock_redis_instance.set.assert_called_once()


class TestLineIterator(unittest.TestCase):

    def setUp(self):

tests/test_cache.py (new file, 131 lines added)

@@ -0,0 +1,131 @@
import unittest
import json
from unittest.mock import patch, MagicMock
from application.cache import gen_cache_key, stream_cache, gen_cache
from application.utils import get_hash


# Test for gen_cache_key function
def test_make_gen_cache_key():
    messages = [
        {'role': 'user', 'content': 'test_user_message'},
        {'role': 'system', 'content': 'test_system_message'},
    ]
    model = "test_docgpt"

    # Manually calculate the expected hash
    expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}"
    expected_hash = get_hash(expected_combined)

    cache_key = gen_cache_key(*messages, model=model)
    assert cache_key == expected_hash


def test_gen_cache_key_invalid_message_format():
    # Test when messages is not a list
    with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
        gen_cache_key("This is not a list", model="docgpt")
    assert str(context.exception) == "All messages must be dictionaries."


# Test for gen_cache decorator
@patch('application.cache.get_redis_instance')  # Mock the Redis client
def test_gen_cache_hit(mock_make_redis):
    # Arrange
    mock_redis_instance = MagicMock()
    mock_make_redis.return_value = mock_redis_instance
    mock_redis_instance.get.return_value = b"cached_result"  # Simulate a cache hit

    @gen_cache
    def mock_function(self, model, messages):
        return "new_result"

    messages = [{'role': 'user', 'content': 'test_user_message'}]
    model = "test_docgpt"

    # Act
    result = mock_function(None, model, messages)

    # Assert
    assert result == "cached_result"  # Should return cached result
    mock_redis_instance.get.assert_called_once()  # Ensure Redis get was called
    mock_redis_instance.set.assert_not_called()  # Ensure the function result is not cached again


@patch('application.cache.get_redis_instance')  # Mock the Redis client
def test_gen_cache_miss(mock_make_redis):
    # Arrange
    mock_redis_instance = MagicMock()
    mock_make_redis.return_value = mock_redis_instance
    mock_redis_instance.get.return_value = None  # Simulate a cache miss

    @gen_cache
    def mock_function(self, model, messages):
        return "new_result"

    messages = [
        {'role': 'user', 'content': 'test_user_message'},
        {'role': 'system', 'content': 'test_system_message'},
    ]
    model = "test_docgpt"

    # Act
    result = mock_function(None, model, messages)

    # Assert
    assert result == "new_result"
    mock_redis_instance.get.assert_called_once()


@patch('application.cache.get_redis_instance')
def test_stream_cache_hit(mock_make_redis):
    # Arrange
    mock_redis_instance = MagicMock()
    mock_make_redis.return_value = mock_redis_instance
    cached_chunk = json.dumps(["chunk1", "chunk2"]).encode('utf-8')
    mock_redis_instance.get.return_value = cached_chunk

    @stream_cache
    def mock_function(self, model, messages, stream):
        yield "new_chunk"

    messages = [{'role': 'user', 'content': 'test_user_message'}]
    model = "test_docgpt"

    # Act
    result = list(mock_function(None, model, messages, stream=True))

    # Assert
    assert result == ["chunk1", "chunk2"]  # Should return cached chunks
    mock_redis_instance.get.assert_called_once()
    mock_redis_instance.set.assert_not_called()


@patch('application.cache.get_redis_instance')
def test_stream_cache_miss(mock_make_redis):
    # Arrange
    mock_redis_instance = MagicMock()
    mock_make_redis.return_value = mock_redis_instance
    mock_redis_instance.get.return_value = None  # Simulate a cache miss

    @stream_cache
    def mock_function(self, model, messages, stream):
        yield "new_chunk"

    messages = [
        {'role': 'user', 'content': 'This is the context'},
        {'role': 'system', 'content': 'Some other message'},
        {'role': 'user', 'content': 'What is the answer?'}
    ]
    model = "test_docgpt"

    # Act
    result = list(mock_function(None, model, messages, stream=True))

    # Assert
    assert result == ["new_chunk"]
    mock_redis_instance.get.assert_called_once()
    mock_redis_instance.set.assert_called_once()