diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index 9a22db84..17eb5cc3 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -292,6 +292,7 @@ class Stream(Resource):
     def post(self):
         data = request.get_json()
         required_fields = ["question"]
+
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields
@@ -422,7 +423,7 @@ class Answer(Resource):
     @api.doc(description="Provide an answer based on the question and retriever")
     def post(self):
         data = request.get_json()
-        required_fields = ["question"] 
+        required_fields = ["question"]
         missing_fields = check_required_fields(data, required_fields)
         if missing_fields:
             return missing_fields
diff --git a/application/cache.py b/application/cache.py
new file mode 100644
index 00000000..33022e45
--- /dev/null
+++ b/application/cache.py
@@ -0,0 +1,93 @@
+import redis
+import time
+import json
+import logging
+from threading import Lock
+from application.core.settings import settings
+from application.utils import get_hash
+
+logger = logging.getLogger(__name__)
+
+_redis_instance = None
+_instance_lock = Lock()
+
+def get_redis_instance():
+    global _redis_instance
+    if _redis_instance is None:
+        with _instance_lock:
+            if _redis_instance is None:
+                try:
+                    _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2)
+                except redis.ConnectionError as e:
+                    logger.error(f"Redis connection error: {e}")
+                    _redis_instance = None
+    return _redis_instance
+
+def gen_cache_key(*messages, model="docgpt"):
+    if not all(isinstance(msg, dict) for msg in messages):
+        raise ValueError("All messages must be dictionaries.")
+    messages_str = json.dumps(list(messages), sort_keys=True)
+    combined = f"{model}_{messages_str}"
+    cache_key = get_hash(combined)
+    return cache_key
+
+def gen_cache(func):
+    def wrapper(self, model, messages, *args, **kwargs):
+        try:
+            cache_key = gen_cache_key(*messages)
+            redis_client = get_redis_instance()
+            if redis_client:
+                try:
+                    cached_response = redis_client.get(cache_key)
+                    if cached_response:
+                        return cached_response.decode('utf-8')
+                except redis.ConnectionError as e:
+                    logger.error(f"Redis connection error: {e}")
+
+            result = func(self, model, messages, *args, **kwargs)
+            if redis_client:
+                try:
+                    redis_client.set(cache_key, result, ex=1800)
+                except redis.ConnectionError as e:
+                    logger.error(f"Redis connection error: {e}")
+
+            return result
+        except ValueError as e:
+            logger.error(e)
+            return "Error: No user message found in the conversation to generate a cache key."
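Reviewer sketch of how the module added above is expected to behave, assuming a Redis instance is reachable at settings.CACHE_REDIS_URL; `fake_llm_call` is a hypothetical stand-in for an LLM's `_raw_gen`, not part of the patch:

    from application.cache import gen_cache, gen_cache_key

    @gen_cache
    def fake_llm_call(self, model, messages):
        # stand-in for an LLM _raw_gen; only runs on a cache miss
        return "expensive completion"

    messages = [{"role": "user", "content": "What is DocsGPT?"}]
    key = gen_cache_key(*messages, model="docgpt")     # md5 of model + sorted-key JSON of the messages
    first = fake_llm_call(None, "docgpt", messages)    # miss: runs the function, then SET with ex=1800
    second = fake_llm_call(None, "docgpt", messages)   # hit: served from Redis, the function is skipped
    assert first == second == "expensive completion"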
+    return wrapper
+
+def stream_cache(func):
+    def wrapper(self, model, messages, stream, *args, **kwargs):
+        cache_key = gen_cache_key(*messages)
+        logger.info(f"Stream cache key: {cache_key}")
+
+        redis_client = get_redis_instance()
+        if redis_client:
+            try:
+                cached_response = redis_client.get(cache_key)
+                if cached_response:
+                    logger.info(f"Cache hit for stream key: {cache_key}")
+                    cached_response = json.loads(cached_response.decode('utf-8'))
+                    for chunk in cached_response:
+                        yield chunk
+                        time.sleep(0.03)
+                    return
+            except redis.ConnectionError as e:
+                logger.error(f"Redis connection error: {e}")
+
+        result = func(self, model, messages, stream, *args, **kwargs)
+        stream_cache_data = []
+
+        for chunk in result:
+            stream_cache_data.append(chunk)
+            yield chunk
+
+        if redis_client:
+            try:
+                redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800)
+                logger.info(f"Stream cache saved for key: {cache_key}")
+            except redis.ConnectionError as e:
+                logger.error(f"Redis connection error: {e}")
+
+    return wrapper
\ No newline at end of file
diff --git a/application/core/settings.py b/application/core/settings.py
index e6173be4..7346da08 100644
--- a/application/core/settings.py
+++ b/application/core/settings.py
@@ -21,6 +21,9 @@ class Settings(BaseSettings):
     VECTOR_STORE: str = "faiss"  # "faiss" or "elasticsearch" or "qdrant" or "milvus"
     RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"]  # also brave_search
 
+    # LLM Cache
+    CACHE_REDIS_URL: str = "redis://localhost:6379/2"
+
     API_URL: str = "http://localhost:7091"  # backend url for celery worker
 
     API_KEY: Optional[str] = None  # LLM api key
diff --git a/application/llm/base.py b/application/llm/base.py
index 475b7937..1caab5d3 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -1,28 +1,29 @@
 from abc import ABC, abstractmethod
 from application.usage import gen_token_usage, stream_token_usage
+from application.cache import stream_cache, gen_cache
 
 
 class BaseLLM(ABC):
     def __init__(self):
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
 
-    def _apply_decorator(self, method, decorator, *args, **kwargs):
-        return decorator(method, *args, **kwargs)
+    def _apply_decorator(self, method, decorators, *args, **kwargs):
+        for decorator in decorators:
+            method = decorator(method)
+        return method(self, *args, **kwargs)
 
     @abstractmethod
     def _raw_gen(self, model, messages, stream, *args, **kwargs):
         pass
 
     def gen(self, model, messages, stream=False, *args, **kwargs):
-        return self._apply_decorator(self._raw_gen, gen_token_usage)(
-            self, model=model, messages=messages, stream=stream, *args, **kwargs
-        )
+        decorators = [gen_token_usage, gen_cache]
+        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
 
     @abstractmethod
     def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
 
     def gen_stream(self, model, messages, stream=True, *args, **kwargs):
-        return self._apply_decorator(self._raw_gen_stream, stream_token_usage)(
-            self, model=model, messages=messages, stream=stream, *args, **kwargs
-        )
+        decorators = [stream_cache, stream_token_usage]
+        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
\ No newline at end of file
diff --git a/application/utils.py b/application/utils.py
index f0802c39..1fc9e329 100644
--- a/application/utils.py
+++ b/application/utils.py
@@ -1,6 +1,8 @@
 import tiktoken
+import hashlib
 
 from flask import jsonify, make_response
 
 
+
 _encoding = None
@@ -39,3 +41,8 @@ def check_required_fields(data, required_fields):
                 400,
             )
     return None
+
+
+def get_hash(data):
+    return hashlib.md5(data.encode()).hexdigest()
+
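Since `_apply_decorator` now builds the wrapper stack left to right, the last decorator in the list ends up outermost: `gen()` becomes `gen_cache(gen_token_usage(_raw_gen))`, so a cache hit returns before token accounting runs, while `gen_stream()` becomes `stream_token_usage(stream_cache(_raw_gen_stream))`. A self-contained sketch of that composition (the `apply`, `inner`, and `outer` names are illustrative only, not part of the patch):

    def apply(method, decorators):
        # mirrors BaseLLM._apply_decorator: wrap in list order, last one outermost
        for decorator in decorators:
            method = decorator(method)
        return method

    def inner(func):
        return lambda *args, **kwargs: f"inner({func(*args, **kwargs)})"

    def outer(func):
        return lambda *args, **kwargs: f"outer({func(*args, **kwargs)})"

    assert apply(lambda: "raw", [inner, outer])() == "outer(inner(raw))"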
diff --git a/docker-compose.yaml b/docker-compose.yaml
index f3b8a363..d3f3421a 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -20,6 +20,7 @@ services:
       - CELERY_BROKER_URL=redis://redis:6379/0
       - CELERY_RESULT_BACKEND=redis://redis:6379/1
       - MONGO_URI=mongodb://mongo:27017/docsgpt
+      - CACHE_REDIS_URL=redis://redis:6379/2
     ports:
       - "7091:7091"
     volumes:
@@ -41,6 +42,7 @@
       - CELERY_RESULT_BACKEND=redis://redis:6379/1
       - MONGO_URI=mongodb://mongo:27017/docsgpt
       - API_URL=http://backend:7091
+      - CACHE_REDIS_URL=redis://redis:6379/2
     depends_on:
       - redis
      - mongo
diff --git a/frontend/package-lock.json b/frontend/package-lock.json
index 4087e4f5..75f4ea8e 100644
--- a/frontend/package-lock.json
+++ b/frontend/package-lock.json
@@ -1675,7 +1675,7 @@
       "version": "18.3.0",
       "resolved": "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.3.0.tgz",
       "integrity": "sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==",
-      "devOptional": true,
+      "dev": true,
       "dependencies": {
         "@types/react": "*"
       }
diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py
index ee4ba15f..689013c0 100644
--- a/tests/llm/test_anthropic.py
+++ b/tests/llm/test_anthropic.py
@@ -22,17 +22,23 @@ class TestAnthropicLLM(unittest.TestCase):
         mock_response = Mock()
         mock_response.completion = "test completion"
 
-        with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
-            response = self.llm.gen("test_model", messages)
-            self.assertEqual(response, "test completion")
+        with patch("application.cache.get_redis_instance") as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+            mock_redis_instance.set = Mock()
 
-            prompt_expected = "### Context \n context \n ### Question \n question"
-            mock_create.assert_called_with(
-                model="test_model",
-                max_tokens_to_sample=300,
-                stream=False,
-                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
-            )
+            with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
+                response = self.llm.gen("test_model", messages)
+                self.assertEqual(response, "test completion")
+
+                prompt_expected = "### Context \n context \n ### Question \n question"
+                mock_create.assert_called_with(
+                    model="test_model",
+                    max_tokens_to_sample=300,
+                    stream=False,
+                    prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
+                )
+            mock_redis_instance.set.assert_called_once()
 
     def test_gen_stream(self):
         messages = [
@@ -41,17 +47,23 @@
         ]
         mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
 
-        with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
-            responses = list(self.llm.gen_stream("test_model", messages))
-            self.assertListEqual(responses, ["response_1", "response_2"])
+        with patch("application.cache.get_redis_instance") as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+            mock_redis_instance.set = Mock()
 
-            prompt_expected = "### Context \n context \n ### Question \n question"
-            mock_create.assert_called_with(
-                model="test_model",
-                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
-                max_tokens_to_sample=300,
-                stream=True
-            )
+            with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
+                responses = list(self.llm.gen_stream("test_model", messages))
+                self.assertListEqual(responses, ["response_1", "response_2"])
+
+                prompt_expected = "### Context \n context \n ### Question \n question"
+                mock_create.assert_called_with(
+                    model="test_model",
+                    prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
+                    max_tokens_to_sample=300,
+                    stream=True
+                )
+            mock_redis_instance.set.assert_called_once()
 
 if __name__ == "__main__":
     unittest.main()
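Both updated test modules isolate the new decorators from a live Redis the same way: patch `application.cache.get_redis_instance` and force the miss path. A condensed version of that pattern (`gen_without_redis` is a hypothetical helper, not part of the patch):

    from unittest.mock import patch

    def gen_without_redis(llm, messages):
        with patch("application.cache.get_redis_instance") as mock_redis_factory:
            mock_redis_factory.return_value.get.return_value = None   # force a cache miss
            result = llm.gen("test_model", messages)
            mock_redis_factory.return_value.set.assert_called_once()  # result was written back
        return result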
diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py
index 0602f597..d659d498 100644
--- a/tests/llm/test_sagemaker.py
+++ b/tests/llm/test_sagemaker.py
@@ -52,28 +52,38 @@ class TestSagemakerAPILLM(unittest.TestCase):
         self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)
 
     def test_gen(self):
-        with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
-                          return_value=self.response) as mock_invoke_endpoint:
-            output = self.sagemaker.gen(None, self.messages)
-            mock_invoke_endpoint.assert_called_once_with(
-                EndpointName=self.sagemaker.endpoint,
-                ContentType='application/json',
-                Body=self.body_bytes
-            )
-            self.assertEqual(output,
-                             self.result[0]['generated_text'][len(self.prompt):])
+        with patch('application.cache.get_redis_instance') as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+
+            with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
+                              return_value=self.response) as mock_invoke_endpoint:
+                output = self.sagemaker.gen(None, self.messages)
+                mock_invoke_endpoint.assert_called_once_with(
+                    EndpointName=self.sagemaker.endpoint,
+                    ContentType='application/json',
+                    Body=self.body_bytes
+                )
+                self.assertEqual(output,
+                                 self.result[0]['generated_text'][len(self.prompt):])
+            mock_make_redis.assert_called_once()
+            mock_redis_instance.set.assert_called_once()
 
     def test_gen_stream(self):
-        with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
-                          return_value=self.response) as mock_invoke_endpoint:
-            output = list(self.sagemaker.gen_stream(None, self.messages))
-            mock_invoke_endpoint.assert_called_once_with(
-                EndpointName=self.sagemaker.endpoint,
-                ContentType='application/json',
-                Body=self.body_bytes_stream
-            )
-            self.assertEqual(output, [])
-
+        with patch('application.cache.get_redis_instance') as mock_make_redis:
+            mock_redis_instance = mock_make_redis.return_value
+            mock_redis_instance.get.return_value = None
+
+            with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
+                              return_value=self.response) as mock_invoke_endpoint:
+                output = list(self.sagemaker.gen_stream(None, self.messages))
+                mock_invoke_endpoint.assert_called_once_with(
+                    EndpointName=self.sagemaker.endpoint,
+                    ContentType='application/json',
+                    Body=self.body_bytes_stream
+                )
+                self.assertEqual(output, [])
+            mock_redis_instance.set.assert_called_once()
 
 class TestLineIterator(unittest.TestCase):
     def setUp(self):
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 00000000..4270a181
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,131 @@
+import unittest
+import json
+from unittest.mock import patch, MagicMock
+from application.cache import gen_cache_key, stream_cache, gen_cache
+from application.utils import get_hash
+
+
+# Test for gen_cache_key function
+def test_make_gen_cache_key():
+    messages = [
+        {'role': 'user', 'content': 'test_user_message'},
+        {'role': 'system', 'content': 'test_system_message'},
+    ]
+    model = "test_docgpt"
+
+    # Manually calculate the expected hash
+    expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}"
+    expected_hash = get_hash(expected_combined)
+    cache_key = gen_cache_key(*messages, model=model)
+
+    assert cache_key == expected_hash
+
+def test_gen_cache_key_invalid_message_format():
+    # Test when messages is not a list
+    with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
+        gen_cache_key("This is not a list", model="docgpt")
+    assert str(context.exception) == "All messages must be dictionaries."
+
+# Test for gen_cache decorator
+@patch('application.cache.get_redis_instance')  # Mock the Redis client
+def test_gen_cache_hit(mock_make_redis):
+    # Arrange
+    mock_redis_instance = MagicMock()
+    mock_make_redis.return_value = mock_redis_instance
+    mock_redis_instance.get.return_value = b"cached_result"  # Simulate a cache hit
+
+    @gen_cache
+    def mock_function(self, model, messages):
+        return "new_result"
+
+    messages = [{'role': 'user', 'content': 'test_user_message'}]
+    model = "test_docgpt"
+
+    # Act
+    result = mock_function(None, model, messages)
+
+    # Assert
+    assert result == "cached_result"  # Should return cached result
+    mock_redis_instance.get.assert_called_once()  # Ensure Redis get was called
+    mock_redis_instance.set.assert_not_called()  # Ensure the function result is not cached again
+
+
+@patch('application.cache.get_redis_instance')  # Mock the Redis client
+def test_gen_cache_miss(mock_make_redis):
+    # Arrange
+    mock_redis_instance = MagicMock()
+    mock_make_redis.return_value = mock_redis_instance
+    mock_redis_instance.get.return_value = None  # Simulate a cache miss
+
+    @gen_cache
+    def mock_function(self, model, messages):
+        return "new_result"
+
+    messages = [
+        {'role': 'user', 'content': 'test_user_message'},
+        {'role': 'system', 'content': 'test_system_message'},
+    ]
+    model = "test_docgpt"
+    # Act
+    result = mock_function(None, model, messages)
+
+    # Assert
+    assert result == "new_result"
+    mock_redis_instance.get.assert_called_once()
+
+@patch('application.cache.get_redis_instance')
+def test_stream_cache_hit(mock_make_redis):
+    # Arrange
+    mock_redis_instance = MagicMock()
+    mock_make_redis.return_value = mock_redis_instance
+
+    cached_chunk = json.dumps(["chunk1", "chunk2"]).encode('utf-8')
+    mock_redis_instance.get.return_value = cached_chunk
+
+    @stream_cache
+    def mock_function(self, model, messages, stream):
+        yield "new_chunk"
+
+    messages = [{'role': 'user', 'content': 'test_user_message'}]
+    model = "test_docgpt"
+
+    # Act
+    result = list(mock_function(None, model, messages, stream=True))
+
+    # Assert
+    assert result == ["chunk1", "chunk2"]  # Should return cached chunks
+    mock_redis_instance.get.assert_called_once()
+    mock_redis_instance.set.assert_not_called()
+
+
+@patch('application.cache.get_redis_instance')
+def test_stream_cache_miss(mock_make_redis):
+    # Arrange
+    mock_redis_instance = MagicMock()
+    mock_make_redis.return_value = mock_redis_instance
+    mock_redis_instance.get.return_value = None  # Simulate a cache miss
+
+    @stream_cache
+    def mock_function(self, model, messages, stream):
+        yield "new_chunk"
+
+    messages = [
+        {'role': 'user', 'content': 'This is the context'},
+        {'role': 'system', 'content': 'Some other message'},
+        {'role': 'user', 'content': 'What is the answer?'}
+    ]
+    model = "test_docgpt"
+
+    # Act
+    result = list(mock_function(None, model, messages, stream=True))
+
+    # Assert
+    assert result == ["new_chunk"]
+    mock_redis_instance.get.assert_called_once()
+    mock_redis_instance.set.assert_called_once()
+
+
+
+
+
+
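The new and updated tests should run in the usual way, assuming the backend requirements (including redis and pytest) are installed in the active environment:

    python -m pytest tests/test_cache.py tests/llm/test_anthropic.py tests/llm/test_sagemaker.py -v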