From 1f75f0c0821ca413a6617404a3c94f9c1d86c5ce Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 20 Dec 2024 18:13:37 +0000 Subject: [PATCH] fix: tests --- application/llm/anthropic.py | 4 ++-- application/llm/sagemaker.py | 4 ++-- tests/llm/test_anthropic.py | 3 ++- tests/llm/test_sagemaker.py | 2 +- tests/test_cache.py | 25 ++++++++++++++----------- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 4081bcd0..1fa3b5b2 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -17,7 +17,7 @@ class AnthropicLLM(BaseLLM): self.AI_PROMPT = AI_PROMPT def _raw_gen( - self, baseself, model, messages, stream=False, max_tokens=300, **kwargs + self, baseself, model, messages, stream=False, tools=None, max_tokens=300, **kwargs ): context = messages[0]["content"] user_question = messages[-1]["content"] @@ -34,7 +34,7 @@ class AnthropicLLM(BaseLLM): return completion.completion def _raw_gen_stream( - self, baseself, model, messages, stream=True, max_tokens=300, **kwargs + self, baseself, model, messages, stream=True, tools=None, max_tokens=300, **kwargs ): context = messages[0]["content"] user_question = messages[-1]["content"] diff --git a/application/llm/sagemaker.py b/application/llm/sagemaker.py index 63947430..aaf99a12 100644 --- a/application/llm/sagemaker.py +++ b/application/llm/sagemaker.py @@ -76,7 +76,7 @@ class SagemakerAPILLM(BaseLLM): self.endpoint = settings.SAGEMAKER_ENDPOINT self.runtime = runtime - def _raw_gen(self, baseself, model, messages, stream=False, **kwargs): + def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" @@ -105,7 +105,7 @@ class SagemakerAPILLM(BaseLLM): print(result[0]["generated_text"], file=sys.stderr) return result[0]["generated_text"][len(prompt) :] 
- def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs): + def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n" diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py index 689013c0..50ddbe29 100644 --- a/tests/llm/test_anthropic.py +++ b/tests/llm/test_anthropic.py @@ -46,6 +46,7 @@ class TestAnthropicLLM(unittest.TestCase): {"content": "question"} ] mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")] + mock_tools = Mock() with patch("application.cache.get_redis_instance") as mock_make_redis: mock_redis_instance = mock_make_redis.return_value @@ -53,7 +54,7 @@ class TestAnthropicLLM(unittest.TestCase): mock_redis_instance.set = Mock() with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create: - responses = list(self.llm.gen_stream("test_model", messages)) + responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools)) self.assertListEqual(responses, ["response_1", "response_2"]) prompt_expected = "### Context \n context \n ### Question \n question" diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py index d659d498..2b893a9a 100644 --- a/tests/llm/test_sagemaker.py +++ b/tests/llm/test_sagemaker.py @@ -76,7 +76,7 @@ class TestSagemakerAPILLM(unittest.TestCase): with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', return_value=self.response) as mock_invoke_endpoint: - output = list(self.sagemaker.gen_stream(None, self.messages)) + output = list(self.sagemaker.gen_stream(None, self.messages, tools=None)) mock_invoke_endpoint.assert_called_once_with( EndpointName=self.sagemaker.endpoint, ContentType='application/json', diff --git a/tests/test_cache.py b/tests/test_cache.py index 
4270a181..af2b5e00 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -12,18 +12,21 @@ def test_make_gen_cache_key(): {'role': 'system', 'content': 'test_system_message'}, ] model = "test_docgpt" + tools = None # Manually calculate the expected hash - expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}" + messages_str = json.dumps(messages) + tools_str = json.dumps(tools) if tools else "" + expected_combined = f"{model}_{messages_str}_{tools_str}" expected_hash = get_hash(expected_combined) - cache_key = gen_cache_key(*messages, model=model) + cache_key = gen_cache_key(messages, model=model, tools=None) assert cache_key == expected_hash def test_gen_cache_key_invalid_message_format(): # Test when messages is not a list with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context: - gen_cache_key("This is not a list", model="docgpt") + gen_cache_key("This is not a list", model="docgpt", tools=None) assert str(context.exception) == "All messages must be dictionaries." 
# Test for gen_cache decorator @@ -35,14 +38,14 @@ def test_gen_cache_hit(mock_make_redis): mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit @gen_cache - def mock_function(self, model, messages): + def mock_function(self, model, messages, stream, tools): return "new_result" messages = [{'role': 'user', 'content': 'test_user_message'}] model = "test_docgpt" # Act - result = mock_function(None, model, messages) + result = mock_function(None, model, messages, stream=False, tools=None) # Assert assert result == "cached_result" # Should return cached result @@ -58,7 +61,7 @@ def test_gen_cache_miss(mock_make_redis): mock_redis_instance.get.return_value = None # Simulate a cache miss @gen_cache - def mock_function(self, model, messages): + def mock_function(self, model, messages, stream, tools): return "new_result" messages = [ @@ -67,7 +70,7 @@ ] model = "test_docgpt" # Act - result = mock_function(None, model, messages) + result = mock_function(None, model, messages, stream=False, tools=None) # Assert assert result == "new_result" @@ -83,14 +86,14 @@ def test_stream_cache_hit(mock_make_redis): mock_redis_instance.get.return_value = cached_chunk @stream_cache - def mock_function(self, model, messages, stream): + def mock_function(self, model, messages, stream, tools): yield "new_chunk" messages = [{'role': 'user', 'content': 'test_user_message'}] model = "test_docgpt" # Act - result = list(mock_function(None, model, messages, stream=True)) + result = list(mock_function(None, model, messages, stream=True, tools=None)) # Assert assert result == ["chunk1", "chunk2"] # Should return cached chunks @@ -106,7 +109,7 @@ def test_stream_cache_miss(mock_make_redis): mock_redis_instance.get.return_value = None # Simulate a cache miss @stream_cache - def mock_function(self, model, messages, stream): + def mock_function(self, model, messages, stream, tools): yield "new_chunk" messages = [ @@ -117,7 +120,7 @@ def
test_stream_cache_miss(mock_make_redis): model = "test_docgpt" # Act - result = list(mock_function(None, model, messages, stream=True)) + result = list(mock_function(None, model, messages, stream=True, tools=None)) # Assert assert result == ["new_chunk"]