add redis mock to anthropic and sagemaker

This commit is contained in:
fadingNA
2024-10-14 11:54:22 -04:00
parent 3e32724729
commit adb2cf35d4
2 changed files with 62 additions and 40 deletions

View File

@@ -22,17 +22,23 @@ class TestAnthropicLLM(unittest.TestCase):
mock_response = Mock()
mock_response.completion = "test completion"
with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
response = self.llm.gen("test_model", messages)
self.assertEqual(response, "test completion")
with patch("application.cache.make_redis") as mock_make_redis:
mock_redis_instance = mock_make_redis.return_value
mock_redis_instance.get.return_value = None
mock_redis_instance.set = Mock()
prompt_expected = "### Context \n context \n ### Question \n question"
mock_create.assert_called_with(
model="test_model",
max_tokens_to_sample=300,
stream=False,
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
)
with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
response = self.llm.gen("test_model", messages)
self.assertEqual(response, "test completion")
prompt_expected = "### Context \n context \n ### Question \n question"
mock_create.assert_called_with(
model="test_model",
max_tokens_to_sample=300,
stream=False,
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
)
mock_redis_instance.set.assert_called_once()
def test_gen_stream(self):
messages = [
@@ -41,17 +47,23 @@ class TestAnthropicLLM(unittest.TestCase):
]
mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
responses = list(self.llm.gen_stream("test_model", messages))
self.assertListEqual(responses, ["response_1", "response_2"])
with patch("application.cache.make_redis") as mock_make_redis:
mock_redis_instance = mock_make_redis.return_value
mock_redis_instance.get.return_value = None
mock_redis_instance.set = Mock()
prompt_expected = "### Context \n context \n ### Question \n question"
mock_create.assert_called_with(
model="test_model",
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
max_tokens_to_sample=300,
stream=True
)
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
responses = list(self.llm.gen_stream("test_model", messages))
self.assertListEqual(responses, ["response_1", "response_2"])
prompt_expected = "### Context \n context \n ### Question \n question"
mock_create.assert_called_with(
model="test_model",
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
max_tokens_to_sample=300,
stream=True
)
mock_redis_instance.set.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -52,28 +52,38 @@ class TestSagemakerAPILLM(unittest.TestCase):
self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)
def test_gen(self):
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
return_value=self.response) as mock_invoke_endpoint:
output = self.sagemaker.gen(None, self.messages)
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
Body=self.body_bytes
)
self.assertEqual(output,
self.result[0]['generated_text'][len(self.prompt):])
with patch('application.cache.make_redis') as mock_make_redis:
mock_redis_instance = mock_make_redis.return_value
mock_redis_instance.get.return_value = None
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
return_value=self.response) as mock_invoke_endpoint:
output = self.sagemaker.gen(None, self.messages)
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
Body=self.body_bytes
)
self.assertEqual(output,
self.result[0]['generated_text'][len(self.prompt):])
mock_make_redis.assert_called_once()
mock_redis_instance.set.assert_called_once()
def test_gen_stream(self):
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
return_value=self.response) as mock_invoke_endpoint:
output = list(self.sagemaker.gen_stream(None, self.messages))
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
Body=self.body_bytes_stream
)
self.assertEqual(output, [])
with patch('application.cache.make_redis') as mock_make_redis:
mock_redis_instance = mock_make_redis.return_value
mock_redis_instance.get.return_value = None
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
return_value=self.response) as mock_invoke_endpoint:
output = list(self.sagemaker.gen_stream(None, self.messages))
mock_invoke_endpoint.assert_called_once_with(
EndpointName=self.sagemaker.endpoint,
ContentType='application/json',
Body=self.body_bytes_stream
)
self.assertEqual(output, [])
mock_redis_instance.set.assert_called_once()
class TestLineIterator(unittest.TestCase):
def setUp(self):