fix: tests

This commit is contained in:
Alex
2024-12-20 18:13:37 +00:00
parent c2a95b5bec
commit 1f75f0c082
5 changed files with 21 additions and 17 deletions

View File

@@ -12,18 +12,21 @@ def test_make_gen_cache_key():
{'role': 'system', 'content': 'test_system_message'},
]
model = "test_docgpt"
tools = None
# Manually calculate the expected hash
expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}"
messages_str = json.dumps(messages)
tools_str = json.dumps(tools) if tools else ""
expected_combined = f"{model}_{messages_str}_{tools_str}"
expected_hash = get_hash(expected_combined)
cache_key = gen_cache_key(*messages, model=model)
cache_key = gen_cache_key(messages, model=model, tools=None)
assert cache_key == expected_hash
def test_gen_cache_key_invalid_message_format():
# Test when messages is not a list
with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context:
gen_cache_key("This is not a list", model="docgpt")
gen_cache_key("This is not a list", model="docgpt", tools=None)
assert str(context.exception) == "All messages must be dictionaries."
# Test for gen_cache decorator
@@ -35,14 +38,14 @@ def test_gen_cache_hit(mock_make_redis):
mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit
@gen_cache
def mock_function(self, model, messages):
def mock_function(self, model, messages, stream, tools):
return "new_result"
messages = [{'role': 'user', 'content': 'test_user_message'}]
model = "test_docgpt"
# Act
result = mock_function(None, model, messages)
result = mock_function(None, model, messages, stream=False, tools=None)
# Assert
assert result == "cached_result" # Should return cached result
@@ -58,7 +61,7 @@ def test_gen_cache_miss(mock_make_redis):
mock_redis_instance.get.return_value = None # Simulate a cache miss
@gen_cache
def mock_function(self, model, messages):
def mock_function(self, model, messages, steam, tools):
return "new_result"
messages = [
@@ -67,7 +70,7 @@ def test_gen_cache_miss(mock_make_redis):
]
model = "test_docgpt"
# Act
result = mock_function(None, model, messages)
result = mock_function(None, model, messages, stream=False, tools=None)
# Assert
assert result == "new_result"
@@ -83,14 +86,14 @@ def test_stream_cache_hit(mock_make_redis):
mock_redis_instance.get.return_value = cached_chunk
@stream_cache
def mock_function(self, model, messages, stream):
def mock_function(self, model, messages, stream, tools):
yield "new_chunk"
messages = [{'role': 'user', 'content': 'test_user_message'}]
model = "test_docgpt"
# Act
result = list(mock_function(None, model, messages, stream=True))
result = list(mock_function(None, model, messages, stream=True, tools=None))
# Assert
assert result == ["chunk1", "chunk2"] # Should return cached chunks
@@ -106,7 +109,7 @@ def test_stream_cache_miss(mock_make_redis):
mock_redis_instance.get.return_value = None # Simulate a cache miss
@stream_cache
def mock_function(self, model, messages, stream):
def mock_function(self, model, messages, stream, tools):
yield "new_chunk"
messages = [
@@ -117,7 +120,7 @@ def test_stream_cache_miss(mock_make_redis):
model = "test_docgpt"
# Act
result = list(mock_function(None, model, messages, stream=True))
result = list(mock_function(None, model, messages, stream=True, tools=None))
# Assert
assert result == ["new_chunk"]