mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-04 21:03:41 +00:00
resolved merge conflicts
This commit is contained in:
@@ -22,17 +22,23 @@ class TestAnthropicLLM(unittest.TestCase):
|
||||
mock_response = Mock()
|
||||
mock_response.completion = "test completion"
|
||||
|
||||
with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
|
||||
response = self.llm.gen("test_model", messages)
|
||||
self.assertEqual(response, "test completion")
|
||||
with patch("application.cache.get_redis_instance") as mock_make_redis:
|
||||
mock_redis_instance = mock_make_redis.return_value
|
||||
mock_redis_instance.get.return_value = None
|
||||
mock_redis_instance.set = Mock()
|
||||
|
||||
prompt_expected = "### Context \n context \n ### Question \n question"
|
||||
mock_create.assert_called_with(
|
||||
model="test_model",
|
||||
max_tokens_to_sample=300,
|
||||
stream=False,
|
||||
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
|
||||
)
|
||||
with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
|
||||
response = self.llm.gen("test_model", messages)
|
||||
self.assertEqual(response, "test completion")
|
||||
|
||||
prompt_expected = "### Context \n context \n ### Question \n question"
|
||||
mock_create.assert_called_with(
|
||||
model="test_model",
|
||||
max_tokens_to_sample=300,
|
||||
stream=False,
|
||||
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
|
||||
)
|
||||
mock_redis_instance.set.assert_called_once()
|
||||
|
||||
def test_gen_stream(self):
|
||||
messages = [
|
||||
@@ -41,17 +47,23 @@ class TestAnthropicLLM(unittest.TestCase):
|
||||
]
|
||||
mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
|
||||
|
||||
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
|
||||
responses = list(self.llm.gen_stream("test_model", messages))
|
||||
self.assertListEqual(responses, ["response_1", "response_2"])
|
||||
with patch("application.cache.get_redis_instance") as mock_make_redis:
|
||||
mock_redis_instance = mock_make_redis.return_value
|
||||
mock_redis_instance.get.return_value = None
|
||||
mock_redis_instance.set = Mock()
|
||||
|
||||
prompt_expected = "### Context \n context \n ### Question \n question"
|
||||
mock_create.assert_called_with(
|
||||
model="test_model",
|
||||
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
|
||||
max_tokens_to_sample=300,
|
||||
stream=True
|
||||
)
|
||||
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
|
||||
responses = list(self.llm.gen_stream("test_model", messages))
|
||||
self.assertListEqual(responses, ["response_1", "response_2"])
|
||||
|
||||
prompt_expected = "### Context \n context \n ### Question \n question"
|
||||
mock_create.assert_called_with(
|
||||
model="test_model",
|
||||
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
|
||||
max_tokens_to_sample=300,
|
||||
stream=True
|
||||
)
|
||||
mock_redis_instance.set.assert_called_once()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -52,28 +52,38 @@ class TestSagemakerAPILLM(unittest.TestCase):
|
||||
self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)
|
||||
|
||||
def test_gen(self):
|
||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
|
||||
return_value=self.response) as mock_invoke_endpoint:
|
||||
output = self.sagemaker.gen(None, self.messages)
|
||||
mock_invoke_endpoint.assert_called_once_with(
|
||||
EndpointName=self.sagemaker.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=self.body_bytes
|
||||
)
|
||||
self.assertEqual(output,
|
||||
self.result[0]['generated_text'][len(self.prompt):])
|
||||
with patch('application.cache.get_redis_instance') as mock_make_redis:
|
||||
mock_redis_instance = mock_make_redis.return_value
|
||||
mock_redis_instance.get.return_value = None
|
||||
|
||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
|
||||
return_value=self.response) as mock_invoke_endpoint:
|
||||
output = self.sagemaker.gen(None, self.messages)
|
||||
mock_invoke_endpoint.assert_called_once_with(
|
||||
EndpointName=self.sagemaker.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=self.body_bytes
|
||||
)
|
||||
self.assertEqual(output,
|
||||
self.result[0]['generated_text'][len(self.prompt):])
|
||||
mock_make_redis.assert_called_once()
|
||||
mock_redis_instance.set.assert_called_once()
|
||||
|
||||
def test_gen_stream(self):
|
||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
|
||||
return_value=self.response) as mock_invoke_endpoint:
|
||||
output = list(self.sagemaker.gen_stream(None, self.messages))
|
||||
mock_invoke_endpoint.assert_called_once_with(
|
||||
EndpointName=self.sagemaker.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=self.body_bytes_stream
|
||||
)
|
||||
self.assertEqual(output, [])
|
||||
|
||||
with patch('application.cache.get_redis_instance') as mock_make_redis:
|
||||
mock_redis_instance = mock_make_redis.return_value
|
||||
mock_redis_instance.get.return_value = None
|
||||
|
||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
|
||||
return_value=self.response) as mock_invoke_endpoint:
|
||||
output = list(self.sagemaker.gen_stream(None, self.messages))
|
||||
mock_invoke_endpoint.assert_called_once_with(
|
||||
EndpointName=self.sagemaker.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=self.body_bytes_stream
|
||||
)
|
||||
self.assertEqual(output, [])
|
||||
mock_redis_instance.set.assert_called_once()
|
||||
class TestLineIterator(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
||||
131
tests/test_cache.py
Normal file
131
tests/test_cache.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import unittest
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
from application.cache import gen_cache_key, stream_cache, gen_cache
|
||||
from application.utils import get_hash
|
||||
|
||||
|
||||
# Test for gen_cache_key function
|
||||
def test_make_gen_cache_key():
|
||||
messages = [
|
||||
{'role': 'user', 'content': 'test_user_message'},
|
||||
{'role': 'system', 'content': 'test_system_message'},
|
||||
]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Manually calculate the expected hash
|
||||
expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}"
|
||||
expected_hash = get_hash(expected_combined)
|
||||
cache_key = gen_cache_key(*messages, model=model)
|
||||
|
||||
assert cache_key == expected_hash
|
||||
|
||||
def test_gen_cache_key_invalid_message_format():
|
||||
# Test when messages is not a list
|
||||
with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
|
||||
gen_cache_key("This is not a list", model="docgpt")
|
||||
assert str(context.exception) == "All messages must be dictionaries."
|
||||
|
||||
# Test for gen_cache decorator
|
||||
@patch('application.cache.get_redis_instance') # Mock the Redis client
|
||||
def test_gen_cache_hit(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit
|
||||
|
||||
@gen_cache
|
||||
def mock_function(self, model, messages):
|
||||
return "new_result"
|
||||
|
||||
messages = [{'role': 'user', 'content': 'test_user_message'}]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = mock_function(None, model, messages)
|
||||
|
||||
# Assert
|
||||
assert result == "cached_result" # Should return cached result
|
||||
mock_redis_instance.get.assert_called_once() # Ensure Redis get was called
|
||||
mock_redis_instance.set.assert_not_called() # Ensure the function result is not cached again
|
||||
|
||||
|
||||
@patch('application.cache.get_redis_instance') # Mock the Redis client
|
||||
def test_gen_cache_miss(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
mock_redis_instance.get.return_value = None # Simulate a cache miss
|
||||
|
||||
@gen_cache
|
||||
def mock_function(self, model, messages):
|
||||
return "new_result"
|
||||
|
||||
messages = [
|
||||
{'role': 'user', 'content': 'test_user_message'},
|
||||
{'role': 'system', 'content': 'test_system_message'},
|
||||
]
|
||||
model = "test_docgpt"
|
||||
# Act
|
||||
result = mock_function(None, model, messages)
|
||||
|
||||
# Assert
|
||||
assert result == "new_result"
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
|
||||
@patch('application.cache.get_redis_instance')
|
||||
def test_stream_cache_hit(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
|
||||
cached_chunk = json.dumps(["chunk1", "chunk2"]).encode('utf-8')
|
||||
mock_redis_instance.get.return_value = cached_chunk
|
||||
|
||||
@stream_cache
|
||||
def mock_function(self, model, messages, stream):
|
||||
yield "new_chunk"
|
||||
|
||||
messages = [{'role': 'user', 'content': 'test_user_message'}]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = list(mock_function(None, model, messages, stream=True))
|
||||
|
||||
# Assert
|
||||
assert result == ["chunk1", "chunk2"] # Should return cached chunks
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
mock_redis_instance.set.assert_not_called()
|
||||
|
||||
|
||||
@patch('application.cache.get_redis_instance')
|
||||
def test_stream_cache_miss(mock_make_redis):
|
||||
# Arrange
|
||||
mock_redis_instance = MagicMock()
|
||||
mock_make_redis.return_value = mock_redis_instance
|
||||
mock_redis_instance.get.return_value = None # Simulate a cache miss
|
||||
|
||||
@stream_cache
|
||||
def mock_function(self, model, messages, stream):
|
||||
yield "new_chunk"
|
||||
|
||||
messages = [
|
||||
{'role': 'user', 'content': 'This is the context'},
|
||||
{'role': 'system', 'content': 'Some other message'},
|
||||
{'role': 'user', 'content': 'What is the answer?'}
|
||||
]
|
||||
model = "test_docgpt"
|
||||
|
||||
# Act
|
||||
result = list(mock_function(None, model, messages, stream=True))
|
||||
|
||||
# Assert
|
||||
assert result == ["new_chunk"]
|
||||
mock_redis_instance.get.assert_called_once()
|
||||
mock_redis_instance.set.assert_called_once()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user