diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py index ee4ba15f..f4c84afa 100644 --- a/tests/llm/test_anthropic.py +++ b/tests/llm/test_anthropic.py @@ -22,17 +22,23 @@ class TestAnthropicLLM(unittest.TestCase): mock_response = Mock() mock_response.completion = "test completion" - with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create: - response = self.llm.gen("test_model", messages) - self.assertEqual(response, "test completion") + with patch("application.cache.make_redis") as mock_make_redis: + mock_redis_instance = mock_make_redis.return_value + mock_redis_instance.get.return_value = None + mock_redis_instance.set = Mock() - prompt_expected = "### Context \n context \n ### Question \n question" - mock_create.assert_called_with( - model="test_model", - max_tokens_to_sample=300, - stream=False, - prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}" - ) + with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create: + response = self.llm.gen("test_model", messages) + self.assertEqual(response, "test completion") + + prompt_expected = "### Context \n context \n ### Question \n question" + mock_create.assert_called_with( + model="test_model", + max_tokens_to_sample=300, + stream=False, + prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}" + ) + mock_redis_instance.set.assert_called_once() def test_gen_stream(self): messages = [ @@ -41,17 +47,23 @@ class TestAnthropicLLM(unittest.TestCase): ] mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")] - with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create: - responses = list(self.llm.gen_stream("test_model", messages)) - self.assertListEqual(responses, ["response_1", "response_2"]) + with patch("application.cache.make_redis") as mock_make_redis: + mock_redis_instance = mock_make_redis.return_value + mock_redis_instance.get.return_value = None + mock_redis_instance.set = Mock() - prompt_expected = "### Context \n context \n ### Question \n question" - mock_create.assert_called_with( - model="test_model", - prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}", - max_tokens_to_sample=300, - stream=True - ) + with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create: + responses = list(self.llm.gen_stream("test_model", messages)) + self.assertListEqual(responses, ["response_1", "response_2"]) + + prompt_expected = "### Context \n context \n ### Question \n question" + mock_create.assert_called_with( + model="test_model", + prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}", + max_tokens_to_sample=300, + stream=True + ) + mock_redis_instance.set.assert_called_once() if __name__ == "__main__": unittest.main() diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py index 0602f597..e45d4c0d 100644 --- a/tests/llm/test_sagemaker.py +++ b/tests/llm/test_sagemaker.py @@ -52,28 +52,38 @@ class TestSagemakerAPILLM(unittest.TestCase): self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result) def test_gen(self): - with patch.object(self.sagemaker.runtime, 'invoke_endpoint', - return_value=self.response) as mock_invoke_endpoint: - output = self.sagemaker.gen(None, self.messages) - mock_invoke_endpoint.assert_called_once_with( - EndpointName=self.sagemaker.endpoint, - ContentType='application/json', - Body=self.body_bytes - ) - self.assertEqual(output, - self.result[0]['generated_text'][len(self.prompt):]) + with patch('application.cache.make_redis') as mock_make_redis: + mock_redis_instance = mock_make_redis.return_value + mock_redis_instance.get.return_value = None + + with patch.object(self.sagemaker.runtime, 'invoke_endpoint', + return_value=self.response) as mock_invoke_endpoint: + output = self.sagemaker.gen(None, self.messages) + mock_invoke_endpoint.assert_called_once_with( + EndpointName=self.sagemaker.endpoint, + ContentType='application/json', + Body=self.body_bytes + ) + self.assertEqual(output, + self.result[0]['generated_text'][len(self.prompt):]) + mock_make_redis.assert_called_once() + mock_redis_instance.set.assert_called_once() def test_gen_stream(self): - with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', - return_value=self.response) as mock_invoke_endpoint: - output = list(self.sagemaker.gen_stream(None, self.messages)) - mock_invoke_endpoint.assert_called_once_with( - EndpointName=self.sagemaker.endpoint, - ContentType='application/json', - Body=self.body_bytes_stream - ) - self.assertEqual(output, []) - + with patch('application.cache.make_redis') as mock_make_redis: + mock_redis_instance = mock_make_redis.return_value + mock_redis_instance.get.return_value = None + + with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', + return_value=self.response) as mock_invoke_endpoint: + output = list(self.sagemaker.gen_stream(None, self.messages)) + mock_invoke_endpoint.assert_called_once_with( + EndpointName=self.sagemaker.endpoint, + ContentType='application/json', + Body=self.body_bytes_stream + ) + self.assertEqual(output, []) + mock_redis_instance.set.assert_called_once() class TestLineIterator(unittest.TestCase): def setUp(self):