mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
(tests:llm) llms, handlers
This commit is contained in:
232
tests/llm/handlers/test_base.py
Normal file
232
tests/llm/handlers/test_base.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from typing import Any, Dict, Generator
|
||||||
|
|
||||||
|
from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCall:
|
||||||
|
"""Test ToolCall dataclass."""
|
||||||
|
|
||||||
|
def test_tool_call_creation(self):
|
||||||
|
"""Test basic ToolCall creation."""
|
||||||
|
tool_call = ToolCall(
|
||||||
|
id="test_id",
|
||||||
|
name="test_function",
|
||||||
|
arguments={"arg1": "value1"},
|
||||||
|
index=0
|
||||||
|
)
|
||||||
|
assert tool_call.id == "test_id"
|
||||||
|
assert tool_call.name == "test_function"
|
||||||
|
assert tool_call.arguments == {"arg1": "value1"}
|
||||||
|
assert tool_call.index == 0
|
||||||
|
|
||||||
|
def test_tool_call_from_dict(self):
|
||||||
|
"""Test ToolCall creation from dictionary."""
|
||||||
|
data = {
|
||||||
|
"id": "call_123",
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": {"location": "New York"},
|
||||||
|
"index": 1
|
||||||
|
}
|
||||||
|
tool_call = ToolCall.from_dict(data)
|
||||||
|
assert tool_call.id == "call_123"
|
||||||
|
assert tool_call.name == "get_weather"
|
||||||
|
assert tool_call.arguments == {"location": "New York"}
|
||||||
|
assert tool_call.index == 1
|
||||||
|
|
||||||
|
def test_tool_call_from_dict_missing_fields(self):
|
||||||
|
"""Test ToolCall creation with missing fields."""
|
||||||
|
data = {"name": "test_func"}
|
||||||
|
tool_call = ToolCall.from_dict(data)
|
||||||
|
assert tool_call.id == ""
|
||||||
|
assert tool_call.name == "test_func"
|
||||||
|
assert tool_call.arguments == {}
|
||||||
|
assert tool_call.index is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMResponse:
|
||||||
|
"""Test LLMResponse dataclass."""
|
||||||
|
|
||||||
|
def test_llm_response_creation(self):
|
||||||
|
"""Test basic LLMResponse creation."""
|
||||||
|
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||||
|
response = LLMResponse(
|
||||||
|
content="Hello",
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
raw_response={"test": "data"}
|
||||||
|
)
|
||||||
|
assert response.content == "Hello"
|
||||||
|
assert len(response.tool_calls) == 1
|
||||||
|
assert response.finish_reason == "tool_calls"
|
||||||
|
assert response.raw_response == {"test": "data"}
|
||||||
|
|
||||||
|
def test_requires_tool_call_true(self):
|
||||||
|
"""Test requires_tool_call property when tool calls are needed."""
|
||||||
|
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||||
|
response = LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
raw_response={}
|
||||||
|
)
|
||||||
|
assert response.requires_tool_call is True
|
||||||
|
|
||||||
|
def test_requires_tool_call_false_no_tools(self):
|
||||||
|
"""Test requires_tool_call property when no tool calls."""
|
||||||
|
response = LLMResponse(
|
||||||
|
content="Hello",
|
||||||
|
tool_calls=[],
|
||||||
|
finish_reason="stop",
|
||||||
|
raw_response={}
|
||||||
|
)
|
||||||
|
assert response.requires_tool_call is False
|
||||||
|
|
||||||
|
def test_requires_tool_call_false_wrong_finish_reason(self):
|
||||||
|
"""Test requires_tool_call property with tools but wrong finish reason."""
|
||||||
|
tool_calls = [ToolCall(id="1", name="func", arguments={})]
|
||||||
|
response = LLMResponse(
|
||||||
|
content="Hello",
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason="stop",
|
||||||
|
raw_response={}
|
||||||
|
)
|
||||||
|
assert response.requires_tool_call is False
|
||||||
|
|
||||||
|
|
||||||
|
class ConcreteHandler(LLMHandler):
|
||||||
|
"""Concrete implementation for testing abstract base class."""
|
||||||
|
|
||||||
|
def parse_response(self, response: Any) -> LLMResponse:
|
||||||
|
return LLMResponse(
|
||||||
|
content=str(response),
|
||||||
|
tool_calls=[],
|
||||||
|
finish_reason="stop",
|
||||||
|
raw_response=response
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"content": str(result),
|
||||||
|
"tool_call_id": tool_call.id
|
||||||
|
}
|
||||||
|
|
||||||
|
def _iterate_stream(self, response: Any) -> Generator:
|
||||||
|
for chunk in response:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMHandler:
|
||||||
|
"""Test LLMHandler base class."""
|
||||||
|
|
||||||
|
def test_handler_initialization(self):
|
||||||
|
"""Test handler initialization."""
|
||||||
|
handler = ConcreteHandler()
|
||||||
|
assert handler.llm_calls == []
|
||||||
|
assert handler.tool_calls == []
|
||||||
|
|
||||||
|
def test_prepare_messages_no_attachments(self):
|
||||||
|
"""Test prepare_messages with no attachments."""
|
||||||
|
handler = ConcreteHandler()
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
|
||||||
|
mock_agent = Mock()
|
||||||
|
result = handler.prepare_messages(mock_agent, messages, None)
|
||||||
|
assert result == messages
|
||||||
|
|
||||||
|
def test_prepare_messages_with_supported_attachments(self):
|
||||||
|
"""Test prepare_messages with supported attachments."""
|
||||||
|
handler = ConcreteHandler()
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
attachments = [{"mime_type": "image/png", "path": "/test.png"}]
|
||||||
|
|
||||||
|
mock_agent = Mock()
|
||||||
|
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||||
|
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
||||||
|
|
||||||
|
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||||
|
mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with(
|
||||||
|
messages, attachments
|
||||||
|
)
|
||||||
|
assert result == messages
|
||||||
|
|
||||||
|
@patch('application.llm.handlers.base.logger')
|
||||||
|
def test_prepare_messages_with_unsupported_attachments(self, mock_logger):
|
||||||
|
"""Test prepare_messages with unsupported attachments."""
|
||||||
|
handler = ConcreteHandler()
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
attachments = [{"mime_type": "text/plain", "path": "/test.txt"}]
|
||||||
|
|
||||||
|
mock_agent = Mock()
|
||||||
|
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||||
|
|
||||||
|
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
||||||
|
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||||
|
mock_append.assert_called_once_with(messages, attachments)
|
||||||
|
assert result == messages
|
||||||
|
|
||||||
|
def test_prepare_messages_mixed_attachments(self):
|
||||||
|
"""Test prepare_messages with both supported and unsupported attachments."""
|
||||||
|
handler = ConcreteHandler()
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
attachments = [
|
||||||
|
{"mime_type": "image/png", "path": "/test.png"},
|
||||||
|
{"mime_type": "text/plain", "path": "/test.txt"}
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_agent = Mock()
|
||||||
|
mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"]
|
||||||
|
mock_agent.llm.prepare_messages_with_attachments.return_value = messages
|
||||||
|
|
||||||
|
with patch.object(handler, '_append_unsupported_attachments', return_value=messages) as mock_append:
|
||||||
|
result = handler.prepare_messages(mock_agent, messages, attachments)
|
||||||
|
|
||||||
|
# Should call both methods
|
||||||
|
mock_agent.llm.prepare_messages_with_attachments.assert_called_once()
|
||||||
|
mock_append.assert_called_once()
|
||||||
|
assert result == messages
|
||||||
|
|
||||||
|
def test_process_message_flow_non_streaming(self):
|
||||||
|
"""Test process_message_flow for non-streaming."""
|
||||||
|
handler = ConcreteHandler()
|
||||||
|
mock_agent = Mock()
|
||||||
|
initial_response = "test response"
|
||||||
|
tools_dict = {}
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
|
||||||
|
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
||||||
|
with patch.object(handler, 'handle_non_streaming', return_value="final") as mock_handle:
|
||||||
|
result = handler.process_message_flow(
|
||||||
|
mock_agent, initial_response, tools_dict, messages, stream=False
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
||||||
|
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
||||||
|
assert result == "final"
|
||||||
|
|
||||||
|
def test_process_message_flow_streaming(self):
|
||||||
|
"""Test process_message_flow for streaming."""
|
||||||
|
handler = ConcreteHandler()
|
||||||
|
mock_agent = Mock()
|
||||||
|
initial_response = "test response"
|
||||||
|
tools_dict = {}
|
||||||
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
|
|
||||||
|
def mock_generator():
|
||||||
|
yield "chunk1"
|
||||||
|
yield "chunk2"
|
||||||
|
|
||||||
|
with patch.object(handler, 'prepare_messages', return_value=messages) as mock_prepare:
|
||||||
|
with patch.object(handler, 'handle_streaming', return_value=mock_generator()) as mock_handle:
|
||||||
|
result = handler.process_message_flow(
|
||||||
|
mock_agent, initial_response, tools_dict, messages, stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_prepare.assert_called_once_with(mock_agent, messages, None)
|
||||||
|
mock_handle.assert_called_once_with(mock_agent, initial_response, tools_dict, messages)
|
||||||
|
|
||||||
|
# Verify it's a generator
|
||||||
|
chunks = list(result)
|
||||||
|
assert chunks == ["chunk1", "chunk2"]
|
||||||
271
tests/llm/handlers/test_google.py
Normal file
271
tests/llm/handlers/test_google.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from types import SimpleNamespace
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from application.llm.handlers.google import GoogleLLMHandler
|
||||||
|
from application.llm.handlers.base import ToolCall, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleLLMHandler:
|
||||||
|
"""Test GoogleLLMHandler class."""
|
||||||
|
|
||||||
|
def test_handler_initialization(self):
|
||||||
|
"""Test handler initialization."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
assert handler.llm_calls == []
|
||||||
|
assert handler.tool_calls == []
|
||||||
|
|
||||||
|
def test_parse_response_string_input(self):
|
||||||
|
"""Test parsing string response."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
response = "Hello from Google!"
|
||||||
|
|
||||||
|
result = handler.parse_response(response)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResponse)
|
||||||
|
assert result.content == "Hello from Google!"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.raw_response == "Hello from Google!"
|
||||||
|
|
||||||
|
def test_parse_response_with_candidates_text_only(self):
|
||||||
|
"""Test parsing response with candidates containing only text."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_part = SimpleNamespace(text="Google response text")
|
||||||
|
mock_content = SimpleNamespace(parts=[mock_part])
|
||||||
|
mock_candidate = SimpleNamespace(content=mock_content)
|
||||||
|
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "Google response text"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.raw_response == mock_response
|
||||||
|
|
||||||
|
def test_parse_response_with_multiple_text_parts(self):
|
||||||
|
"""Test parsing response with multiple text parts."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_part1 = SimpleNamespace(text="First part")
|
||||||
|
mock_part2 = SimpleNamespace(text="Second part")
|
||||||
|
mock_content = SimpleNamespace(parts=[mock_part1, mock_part2])
|
||||||
|
mock_candidate = SimpleNamespace(content=mock_content)
|
||||||
|
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "First part Second part"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
|
||||||
|
@patch('uuid.uuid4')
|
||||||
|
def test_parse_response_with_function_call(self, mock_uuid):
|
||||||
|
"""Test parsing response with function call."""
|
||||||
|
mock_uuid.return_value = Mock(spec=uuid.UUID)
|
||||||
|
mock_uuid.return_value.__str__ = Mock(return_value="test-uuid-123")
|
||||||
|
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_function_call = SimpleNamespace(
|
||||||
|
name="get_weather",
|
||||||
|
args={"location": "San Francisco"}
|
||||||
|
)
|
||||||
|
mock_part = SimpleNamespace(function_call=mock_function_call)
|
||||||
|
mock_content = SimpleNamespace(parts=[mock_part])
|
||||||
|
mock_candidate = SimpleNamespace(content=mock_content)
|
||||||
|
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == ""
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].id == "test-uuid-123"
|
||||||
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
|
assert result.tool_calls[0].arguments == {"location": "San Francisco"}
|
||||||
|
assert result.finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
@patch('uuid.uuid4')
|
||||||
|
def test_parse_response_with_mixed_parts(self, mock_uuid):
|
||||||
|
"""Test parsing response with both text and function call parts."""
|
||||||
|
mock_uuid.return_value = Mock(spec=uuid.UUID)
|
||||||
|
mock_uuid.return_value.__str__ = Mock(return_value="test-uuid-456")
|
||||||
|
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_text_part = SimpleNamespace(text="I'll check the weather for you.")
|
||||||
|
mock_function_call = SimpleNamespace(
|
||||||
|
name="get_weather",
|
||||||
|
args={"location": "NYC"}
|
||||||
|
)
|
||||||
|
mock_function_part = SimpleNamespace(function_call=mock_function_call)
|
||||||
|
|
||||||
|
mock_content = SimpleNamespace(parts=[mock_text_part, mock_function_part])
|
||||||
|
mock_candidate = SimpleNamespace(content=mock_content)
|
||||||
|
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "I'll check the weather for you."
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
|
assert result.finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
def test_parse_response_empty_candidates(self):
|
||||||
|
"""Test parsing response with empty candidates."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_response = SimpleNamespace(candidates=[])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == ""
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
|
||||||
|
def test_parse_response_parts_with_none_text(self):
|
||||||
|
"""Test parsing response with parts that have None text."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_part1 = SimpleNamespace(text=None)
|
||||||
|
mock_part2 = SimpleNamespace(text="Valid text")
|
||||||
|
mock_content = SimpleNamespace(parts=[mock_part1, mock_part2])
|
||||||
|
mock_candidate = SimpleNamespace(content=mock_content)
|
||||||
|
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "Valid text"
|
||||||
|
|
||||||
|
def test_parse_response_parts_without_text_attribute(self):
|
||||||
|
"""Test parsing response with parts missing text attribute."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_part1 = SimpleNamespace()
|
||||||
|
mock_part2 = SimpleNamespace(text="Valid text")
|
||||||
|
mock_content = SimpleNamespace(parts=[mock_part1, mock_part2])
|
||||||
|
mock_candidate = SimpleNamespace(content=mock_content)
|
||||||
|
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "Valid text"
|
||||||
|
|
||||||
|
@patch('uuid.uuid4')
|
||||||
|
def test_parse_response_direct_function_call(self, mock_uuid):
|
||||||
|
"""Test parsing response with direct function call (not in candidates)."""
|
||||||
|
mock_uuid.return_value = Mock(spec=uuid.UUID)
|
||||||
|
mock_uuid.return_value.__str__ = Mock(return_value="direct-uuid-789")
|
||||||
|
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_function_call = SimpleNamespace(
|
||||||
|
name="calculate",
|
||||||
|
args={"expression": "2+2"}
|
||||||
|
)
|
||||||
|
mock_response = SimpleNamespace(
|
||||||
|
function_call=mock_function_call,
|
||||||
|
text="The calculation result is:"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "The calculation result is:"
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].id == "direct-uuid-789"
|
||||||
|
assert result.tool_calls[0].name == "calculate"
|
||||||
|
assert result.tool_calls[0].arguments == {"expression": "2+2"}
|
||||||
|
assert result.finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
def test_parse_response_direct_function_call_no_text(self):
|
||||||
|
"""Test parsing response with direct function call and no text."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_function_call = SimpleNamespace(
|
||||||
|
name="get_data",
|
||||||
|
args={"id": 123}
|
||||||
|
)
|
||||||
|
mock_response = SimpleNamespace(function_call=mock_function_call)
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == ""
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].name == "get_data"
|
||||||
|
assert result.finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
def test_create_tool_message(self):
|
||||||
|
"""Test creating tool message."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
tool_call = ToolCall(
|
||||||
|
id="call_123",
|
||||||
|
name="get_weather",
|
||||||
|
arguments={"location": "Tokyo"},
|
||||||
|
index=0
|
||||||
|
)
|
||||||
|
result = {"temperature": "25C", "condition": "cloudy"}
|
||||||
|
|
||||||
|
message = handler.create_tool_message(tool_call, result)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"role": "model",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_response": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"response": {"result": result},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert message == expected
|
||||||
|
|
||||||
|
def test_create_tool_message_string_result(self):
|
||||||
|
"""Test creating tool message with string result."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
|
||||||
|
result = "2023-12-01 15:30:00 JST"
|
||||||
|
|
||||||
|
message = handler.create_tool_message(tool_call, result)
|
||||||
|
|
||||||
|
assert message["role"] == "model"
|
||||||
|
assert message["content"][0]["function_response"]["response"]["result"] == result
|
||||||
|
assert message["content"][0]["function_response"]["name"] == "get_time"
|
||||||
|
|
||||||
|
def test_iterate_stream(self):
|
||||||
|
"""Test stream iteration."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_chunks = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
|
||||||
|
result = list(handler._iterate_stream(mock_chunks))
|
||||||
|
|
||||||
|
assert result == mock_chunks
|
||||||
|
|
||||||
|
def test_iterate_stream_empty(self):
|
||||||
|
"""Test stream iteration with empty response."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
result = list(handler._iterate_stream([]))
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_parse_response_parts_without_function_call_attribute(self):
|
||||||
|
"""Test parsing response with parts missing function_call attribute."""
|
||||||
|
handler = GoogleLLMHandler()
|
||||||
|
|
||||||
|
mock_part = SimpleNamespace(text="Normal text")
|
||||||
|
mock_content = SimpleNamespace(parts=[mock_part])
|
||||||
|
mock_candidate = SimpleNamespace(content=mock_content)
|
||||||
|
mock_response = SimpleNamespace(candidates=[mock_candidate])
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "Normal text"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
126
tests/llm/handlers/test_handler_creator.py
Normal file
126
tests/llm/handlers/test_handler_creator.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||||
|
from application.llm.handlers.base import LLMHandler
|
||||||
|
from application.llm.handlers.openai import OpenAILLMHandler
|
||||||
|
from application.llm.handlers.google import GoogleLLMHandler
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMHandlerCreator:
|
||||||
|
"""Test LLMHandlerCreator class."""
|
||||||
|
|
||||||
|
def test_create_openai_handler(self):
|
||||||
|
"""Test creating OpenAI handler."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("openai")
|
||||||
|
|
||||||
|
assert isinstance(handler, OpenAILLMHandler)
|
||||||
|
assert isinstance(handler, LLMHandler)
|
||||||
|
|
||||||
|
def test_create_openai_handler_case_insensitive(self):
|
||||||
|
"""Test creating OpenAI handler with different cases."""
|
||||||
|
handler_upper = LLMHandlerCreator.create_handler("OPENAI")
|
||||||
|
handler_mixed = LLMHandlerCreator.create_handler("OpenAI")
|
||||||
|
|
||||||
|
assert isinstance(handler_upper, OpenAILLMHandler)
|
||||||
|
assert isinstance(handler_mixed, OpenAILLMHandler)
|
||||||
|
|
||||||
|
def test_create_google_handler(self):
|
||||||
|
"""Test creating Google handler."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("google")
|
||||||
|
|
||||||
|
assert isinstance(handler, GoogleLLMHandler)
|
||||||
|
assert isinstance(handler, LLMHandler)
|
||||||
|
|
||||||
|
def test_create_google_handler_case_insensitive(self):
|
||||||
|
"""Test creating Google handler with different cases."""
|
||||||
|
handler_upper = LLMHandlerCreator.create_handler("GOOGLE")
|
||||||
|
handler_mixed = LLMHandlerCreator.create_handler("Google")
|
||||||
|
|
||||||
|
assert isinstance(handler_upper, GoogleLLMHandler)
|
||||||
|
assert isinstance(handler_mixed, GoogleLLMHandler)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_default_handler(self):
|
||||||
|
"""Test creating default handler."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("default")
|
||||||
|
|
||||||
|
assert isinstance(handler, OpenAILLMHandler)
|
||||||
|
|
||||||
|
def test_create_unknown_handler_fallback(self):
|
||||||
|
"""Test creating handler for unknown type falls back to OpenAI."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("unknown_provider")
|
||||||
|
|
||||||
|
assert isinstance(handler, OpenAILLMHandler)
|
||||||
|
|
||||||
|
def test_create_anthropic_handler_fallback(self):
|
||||||
|
"""Test creating Anthropic handler falls back to OpenAI (not supported in handlers)."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("anthropic")
|
||||||
|
|
||||||
|
assert isinstance(handler, OpenAILLMHandler)
|
||||||
|
|
||||||
|
def test_create_empty_string_handler_fallback(self):
|
||||||
|
"""Test creating handler with empty string falls back to OpenAI."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("")
|
||||||
|
|
||||||
|
assert isinstance(handler, OpenAILLMHandler)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_handlers_registry(self):
|
||||||
|
"""Test the handlers registry contains expected mappings."""
|
||||||
|
expected_handlers = {
|
||||||
|
"openai": OpenAILLMHandler,
|
||||||
|
"google": GoogleLLMHandler,
|
||||||
|
"default": OpenAILLMHandler,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert LLMHandlerCreator.handlers == expected_handlers
|
||||||
|
|
||||||
|
def test_create_handler_with_args(self):
|
||||||
|
"""Test creating handler with additional arguments."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("openai")
|
||||||
|
|
||||||
|
assert isinstance(handler, OpenAILLMHandler)
|
||||||
|
assert handler.llm_calls == []
|
||||||
|
assert handler.tool_calls == []
|
||||||
|
|
||||||
|
def test_create_handler_with_kwargs(self):
|
||||||
|
"""Test creating handler with keyword arguments."""
|
||||||
|
handler = LLMHandlerCreator.create_handler("google")
|
||||||
|
|
||||||
|
assert isinstance(handler, GoogleLLMHandler)
|
||||||
|
assert handler.llm_calls == []
|
||||||
|
assert handler.tool_calls == []
|
||||||
|
|
||||||
|
def test_all_registered_handlers_are_valid(self):
|
||||||
|
"""Test that all registered handlers can be instantiated."""
|
||||||
|
for handler_type in LLMHandlerCreator.handlers.keys():
|
||||||
|
handler = LLMHandlerCreator.create_handler(handler_type)
|
||||||
|
assert isinstance(handler, LLMHandler)
|
||||||
|
assert hasattr(handler, 'parse_response')
|
||||||
|
assert hasattr(handler, 'create_tool_message')
|
||||||
|
assert hasattr(handler, '_iterate_stream')
|
||||||
|
|
||||||
|
def test_handler_inheritance(self):
|
||||||
|
"""Test that all created handlers inherit from LLMHandler."""
|
||||||
|
test_types = ["openai", "google", "default", "unknown"]
|
||||||
|
|
||||||
|
for handler_type in test_types:
|
||||||
|
handler = LLMHandlerCreator.create_handler(handler_type)
|
||||||
|
assert isinstance(handler, LLMHandler)
|
||||||
|
|
||||||
|
assert callable(getattr(handler, 'parse_response'))
|
||||||
|
assert callable(getattr(handler, 'create_tool_message'))
|
||||||
|
assert callable(getattr(handler, '_iterate_stream'))
|
||||||
|
|
||||||
|
def test_create_handler_preserves_handler_state(self):
|
||||||
|
"""Test that each created handler has independent state."""
|
||||||
|
handler1 = LLMHandlerCreator.create_handler("openai")
|
||||||
|
handler2 = LLMHandlerCreator.create_handler("openai")
|
||||||
|
|
||||||
|
handler1.llm_calls.append("test_call")
|
||||||
|
|
||||||
|
assert len(handler1.llm_calls) == 1
|
||||||
|
assert len(handler2.llm_calls) == 0
|
||||||
|
assert handler1 is not handler2
|
||||||
210
tests/llm/handlers/test_openai.py
Normal file
210
tests/llm/handlers/test_openai.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from application.llm.handlers.openai import OpenAILLMHandler
|
||||||
|
from application.llm.handlers.base import ToolCall, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAILLMHandler:
|
||||||
|
"""Test OpenAILLMHandler class."""
|
||||||
|
|
||||||
|
def test_handler_initialization(self):
|
||||||
|
"""Test handler initialization."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
assert handler.llm_calls == []
|
||||||
|
assert handler.tool_calls == []
|
||||||
|
|
||||||
|
def test_parse_response_string_input(self):
|
||||||
|
"""Test parsing string response."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
response = "Hello, world!"
|
||||||
|
|
||||||
|
result = handler.parse_response(response)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResponse)
|
||||||
|
assert result.content == "Hello, world!"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.raw_response == "Hello, world!"
|
||||||
|
|
||||||
|
def test_parse_response_with_message_content(self):
|
||||||
|
"""Test parsing response with message content."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
# Mock OpenAI response structure
|
||||||
|
mock_message = SimpleNamespace(content="Test content", tool_calls=None)
|
||||||
|
mock_response = SimpleNamespace(message=mock_message, finish_reason="stop")
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "Test content"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.raw_response == mock_response
|
||||||
|
|
||||||
|
def test_parse_response_with_delta_content(self):
|
||||||
|
"""Test parsing response with delta content (streaming)."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
# Mock streaming response structure
|
||||||
|
mock_delta = SimpleNamespace(content="Stream chunk", tool_calls=None)
|
||||||
|
mock_response = SimpleNamespace(delta=mock_delta, finish_reason="")
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "Stream chunk"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == ""
|
||||||
|
assert result.raw_response == mock_response
|
||||||
|
|
||||||
|
def test_parse_response_with_tool_calls(self):
|
||||||
|
"""Test parsing response with tool calls."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
# Mock tool call structure
|
||||||
|
mock_function = SimpleNamespace(name="get_weather", arguments='{"location": "NYC"}')
|
||||||
|
mock_tool_call = SimpleNamespace(
|
||||||
|
id="call_123",
|
||||||
|
function=mock_function,
|
||||||
|
index=0
|
||||||
|
)
|
||||||
|
mock_message = SimpleNamespace(content="", tool_calls=[mock_tool_call])
|
||||||
|
mock_response = SimpleNamespace(message=mock_message, finish_reason="tool_calls")
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == ""
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].id == "call_123"
|
||||||
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
|
assert result.tool_calls[0].arguments == '{"location": "NYC"}'
|
||||||
|
assert result.tool_calls[0].index == 0
|
||||||
|
assert result.finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
def test_parse_response_with_multiple_tool_calls(self):
|
||||||
|
"""Test parsing response with multiple tool calls."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
# Mock multiple tool calls
|
||||||
|
mock_function1 = SimpleNamespace(name="get_weather", arguments='{"location": "NYC"}')
|
||||||
|
mock_function2 = SimpleNamespace(name="get_time", arguments='{"timezone": "UTC"}')
|
||||||
|
|
||||||
|
mock_tool_call1 = SimpleNamespace(id="call_1", function=mock_function1, index=0)
|
||||||
|
mock_tool_call2 = SimpleNamespace(id="call_2", function=mock_function2, index=1)
|
||||||
|
|
||||||
|
mock_message = SimpleNamespace(content="", tool_calls=[mock_tool_call1, mock_tool_call2])
|
||||||
|
mock_response = SimpleNamespace(message=mock_message, finish_reason="tool_calls")
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 2
|
||||||
|
assert result.tool_calls[0].name == "get_weather"
|
||||||
|
assert result.tool_calls[1].name == "get_time"
|
||||||
|
|
||||||
|
def test_parse_response_empty_tool_calls(self):
|
||||||
|
"""Test parsing response with empty tool_calls."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
mock_message = SimpleNamespace(content="No tools needed", tool_calls=None)
|
||||||
|
mock_response = SimpleNamespace(message=mock_message, finish_reason="stop")
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == "No tools needed"
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
|
||||||
|
def test_parse_response_missing_attributes(self):
|
||||||
|
"""Test parsing response with missing attributes."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
# Mock response with missing attributes
|
||||||
|
mock_message = SimpleNamespace() # No content or tool_calls
|
||||||
|
mock_response = SimpleNamespace(message=mock_message) # No finish_reason
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert result.content == ""
|
||||||
|
assert result.tool_calls == []
|
||||||
|
assert result.finish_reason == ""
|
||||||
|
|
||||||
|
def test_create_tool_message(self):
|
||||||
|
"""Test creating tool message."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
tool_call = ToolCall(
|
||||||
|
id="call_123",
|
||||||
|
name="get_weather",
|
||||||
|
arguments={"location": "NYC"},
|
||||||
|
index=0
|
||||||
|
)
|
||||||
|
result = {"temperature": "72F", "condition": "sunny"}
|
||||||
|
|
||||||
|
message = handler.create_tool_message(tool_call, result)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"role": "tool",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"function_response": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"response": {"result": result},
|
||||||
|
"call_id": "call_123",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert message == expected
|
||||||
|
|
||||||
|
def test_create_tool_message_string_result(self):
|
||||||
|
"""Test creating tool message with string result."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
tool_call = ToolCall(id="call_456", name="get_time", arguments={})
|
||||||
|
result = "2023-12-01 10:30:00"
|
||||||
|
|
||||||
|
message = handler.create_tool_message(tool_call, result)
|
||||||
|
|
||||||
|
assert message["role"] == "tool"
|
||||||
|
assert message["content"][0]["function_response"]["response"]["result"] == result
|
||||||
|
assert message["content"][0]["function_response"]["call_id"] == "call_456"
|
||||||
|
|
||||||
|
def test_iterate_stream(self):
|
||||||
|
"""Test stream iteration."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
# Mock streaming response
|
||||||
|
mock_chunks = ["chunk1", "chunk2", "chunk3"]
|
||||||
|
|
||||||
|
result = list(handler._iterate_stream(mock_chunks))
|
||||||
|
|
||||||
|
assert result == mock_chunks
|
||||||
|
|
||||||
|
def test_iterate_stream_empty(self):
|
||||||
|
"""Test stream iteration with empty response."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
result = list(handler._iterate_stream([]))
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_parse_response_tool_call_missing_attributes(self):
|
||||||
|
"""Test parsing tool calls with missing attributes."""
|
||||||
|
handler = OpenAILLMHandler()
|
||||||
|
|
||||||
|
# Mock tool call with missing attributes
|
||||||
|
mock_function = SimpleNamespace() # No name or arguments
|
||||||
|
mock_tool_call = SimpleNamespace(function=mock_function) # No id or index
|
||||||
|
|
||||||
|
mock_message = SimpleNamespace(content="", tool_calls=[mock_tool_call])
|
||||||
|
mock_response = SimpleNamespace(message=mock_message, finish_reason="tool_calls")
|
||||||
|
|
||||||
|
result = handler.parse_response(mock_response)
|
||||||
|
|
||||||
|
assert len(result.tool_calls) == 1
|
||||||
|
assert result.tool_calls[0].id == ""
|
||||||
|
assert result.tool_calls[0].name == ""
|
||||||
|
assert result.tool_calls[0].arguments == ""
|
||||||
|
assert result.tool_calls[0].index is None
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import patch, Mock
|
|
||||||
from application.llm.anthropic import AnthropicLLM
|
|
||||||
|
|
||||||
class TestAnthropicLLM(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.api_key = "TEST_API_KEY"
|
|
||||||
self.llm = AnthropicLLM(api_key=self.api_key)
|
|
||||||
|
|
||||||
@patch("application.llm.anthropic.settings")
|
|
||||||
def test_init_default_api_key(self, mock_settings):
|
|
||||||
mock_settings.ANTHROPIC_API_KEY = "DEFAULT_API_KEY"
|
|
||||||
llm = AnthropicLLM()
|
|
||||||
self.assertEqual(llm.api_key, "DEFAULT_API_KEY")
|
|
||||||
|
|
||||||
def test_gen(self):
|
|
||||||
messages = [
|
|
||||||
{"content": "context"},
|
|
||||||
{"content": "question"}
|
|
||||||
]
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.completion = "test completion"
|
|
||||||
|
|
||||||
with patch("application.cache.get_redis_instance") as mock_make_redis:
|
|
||||||
mock_redis_instance = mock_make_redis.return_value
|
|
||||||
mock_redis_instance.get.return_value = None
|
|
||||||
mock_redis_instance.set = Mock()
|
|
||||||
|
|
||||||
with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
|
|
||||||
response = self.llm.gen("test_model", messages)
|
|
||||||
self.assertEqual(response, "test completion")
|
|
||||||
|
|
||||||
prompt_expected = "### Context \n context \n ### Question \n question"
|
|
||||||
mock_create.assert_called_with(
|
|
||||||
model="test_model",
|
|
||||||
max_tokens_to_sample=300,
|
|
||||||
stream=False,
|
|
||||||
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
|
|
||||||
)
|
|
||||||
mock_redis_instance.set.assert_called_once()
|
|
||||||
|
|
||||||
def test_gen_stream(self):
|
|
||||||
messages = [
|
|
||||||
{"content": "context"},
|
|
||||||
{"content": "question"}
|
|
||||||
]
|
|
||||||
mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
|
|
||||||
mock_tools = Mock()
|
|
||||||
|
|
||||||
with patch("application.cache.get_redis_instance") as mock_make_redis:
|
|
||||||
mock_redis_instance = mock_make_redis.return_value
|
|
||||||
mock_redis_instance.get.return_value = None
|
|
||||||
mock_redis_instance.set = Mock()
|
|
||||||
|
|
||||||
with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
|
|
||||||
responses = list(self.llm.gen_stream("test_model", messages, tools=mock_tools))
|
|
||||||
self.assertListEqual(responses, ["response_1", "response_2"])
|
|
||||||
|
|
||||||
prompt_expected = "### Context \n context \n ### Question \n question"
|
|
||||||
mock_create.assert_called_with(
|
|
||||||
model="test_model",
|
|
||||||
prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
|
|
||||||
max_tokens_to_sample=300,
|
|
||||||
stream=True
|
|
||||||
)
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
65
tests/llm/test_anthropic_llm.py
Normal file
65
tests/llm/test_anthropic_llm.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
class _FakeCompletion:
|
||||||
|
def __init__(self, text):
|
||||||
|
self.completion = text
|
||||||
|
|
||||||
|
class _FakeCompletions:
|
||||||
|
def __init__(self):
|
||||||
|
self.last_kwargs = None
|
||||||
|
self._stream = [_FakeCompletion("s1"), _FakeCompletion("s2")]
|
||||||
|
|
||||||
|
def create(self, **kwargs):
|
||||||
|
self.last_kwargs = kwargs
|
||||||
|
if kwargs.get("stream"):
|
||||||
|
return self._stream
|
||||||
|
return _FakeCompletion("final")
|
||||||
|
|
||||||
|
class _FakeAnthropic:
|
||||||
|
def __init__(self, api_key=None):
|
||||||
|
self.api_key = api_key
|
||||||
|
self.completions = _FakeCompletions()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_anthropic(monkeypatch):
|
||||||
|
fake = types.ModuleType("anthropic")
|
||||||
|
fake.Anthropic = _FakeAnthropic
|
||||||
|
fake.HUMAN_PROMPT = "<HUMAN>"
|
||||||
|
fake.AI_PROMPT = "<AI>"
|
||||||
|
sys.modules["anthropic"] = fake
|
||||||
|
yield
|
||||||
|
sys.modules.pop("anthropic", None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_raw_gen_builds_prompt_and_returns_completion():
|
||||||
|
from application.llm.anthropic import AnthropicLLM
|
||||||
|
|
||||||
|
llm = AnthropicLLM(api_key="k")
|
||||||
|
msgs = [
|
||||||
|
{"content": "ctx"},
|
||||||
|
{"content": "q"},
|
||||||
|
]
|
||||||
|
out = llm._raw_gen(llm, model="claude-2", messages=msgs, stream=False, max_tokens=55)
|
||||||
|
assert out == "final"
|
||||||
|
last = llm.anthropic.completions.last_kwargs
|
||||||
|
assert last["model"] == "claude-2"
|
||||||
|
assert last["max_tokens_to_sample"] == 55
|
||||||
|
assert last["prompt"].startswith("<HUMAN>") and last["prompt"].endswith("<AI>")
|
||||||
|
assert "### Context" in last["prompt"] and "### Question" in last["prompt"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_raw_gen_stream_yields_chunks():
|
||||||
|
from application.llm.anthropic import AnthropicLLM
|
||||||
|
|
||||||
|
llm = AnthropicLLM(api_key="k")
|
||||||
|
msgs = [
|
||||||
|
{"content": "ctx"},
|
||||||
|
{"content": "q"},
|
||||||
|
]
|
||||||
|
gen = llm._raw_gen_stream(llm, model="claude", messages=msgs, stream=True, max_tokens=10)
|
||||||
|
chunks = list(gen)
|
||||||
|
assert chunks == ["s1", "s2"]
|
||||||
|
|
||||||
152
tests/llm/test_google_llm.py
Normal file
152
tests/llm/test_google_llm.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from application.llm.google_ai import GoogleLLM
|
||||||
|
|
||||||
|
class _FakePart:
|
||||||
|
def __init__(self, text=None, function_call=None, file_data=None):
|
||||||
|
self.text = text
|
||||||
|
self.function_call = function_call
|
||||||
|
self.file_data = file_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_text(text):
|
||||||
|
return _FakePart(text=text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_function_call(name, args):
|
||||||
|
return _FakePart(function_call=types.SimpleNamespace(name=name, args=args))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_function_response(name, response):
|
||||||
|
# not used in assertions but present for completeness
|
||||||
|
return _FakePart(function_call=None, text=str(response))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_uri(file_uri, mime_type):
|
||||||
|
# mimic presence of file data for streaming detection
|
||||||
|
return _FakePart(file_data=types.SimpleNamespace(file_uri=file_uri, mime_type=mime_type))
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeContent:
|
||||||
|
def __init__(self, role, parts):
|
||||||
|
self.role = role
|
||||||
|
self.parts = parts
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTypesModule:
|
||||||
|
Part = _FakePart
|
||||||
|
Content = _FakeContent
|
||||||
|
|
||||||
|
class GenerateContentConfig:
|
||||||
|
def __init__(self):
|
||||||
|
self.system_instruction = None
|
||||||
|
self.tools = None
|
||||||
|
self.response_schema = None
|
||||||
|
self.response_mime_type = None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeModels:
|
||||||
|
def __init__(self):
|
||||||
|
self.last_args = None
|
||||||
|
self.last_kwargs = None
|
||||||
|
|
||||||
|
class _Resp:
|
||||||
|
def __init__(self, text=None, candidates=None):
|
||||||
|
self.text = text
|
||||||
|
self.candidates = candidates or []
|
||||||
|
|
||||||
|
def generate_content(self, *args, **kwargs):
|
||||||
|
self.last_args, self.last_kwargs = args, kwargs
|
||||||
|
return FakeModels._Resp(text="ok")
|
||||||
|
|
||||||
|
def generate_content_stream(self, *args, **kwargs):
|
||||||
|
self.last_args, self.last_kwargs = args, kwargs
|
||||||
|
# Simulate stream of text parts
|
||||||
|
part1 = types.SimpleNamespace(text="a", candidates=None)
|
||||||
|
part2 = types.SimpleNamespace(text="b", candidates=None)
|
||||||
|
return [part1, part2]
|
||||||
|
|
||||||
|
|
||||||
|
class FakeClient:
|
||||||
|
def __init__(self, *_, **__):
|
||||||
|
self.models = FakeModels()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_google_modules(monkeypatch):
|
||||||
|
# Patch the types module used by GoogleLLM
|
||||||
|
import application.llm.google_ai as gmod
|
||||||
|
monkeypatch.setattr(gmod, "types", FakeTypesModule)
|
||||||
|
monkeypatch.setattr(gmod.genai, "Client", FakeClient)
|
||||||
|
|
||||||
|
|
||||||
|
def test_clean_messages_google_basic():
|
||||||
|
llm = GoogleLLM(api_key="key")
|
||||||
|
msgs = [
|
||||||
|
{"role": "assistant", "content": "hi"},
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{"text": "hello"},
|
||||||
|
{"files": [{"file_uri": "gs://x", "mime_type": "image/png"}]},
|
||||||
|
{"function_call": {"name": "fn", "args": {"a": 1}}},
|
||||||
|
]},
|
||||||
|
]
|
||||||
|
cleaned = llm._clean_messages_google(msgs)
|
||||||
|
|
||||||
|
assert all(hasattr(c, "role") and hasattr(c, "parts") for c in cleaned)
|
||||||
|
assert any(c.role == "model" for c in cleaned)
|
||||||
|
assert any(hasattr(p, "text") for c in cleaned for p in c.parts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_gen_calls_google_client_and_returns_text():
|
||||||
|
llm = GoogleLLM(api_key="key")
|
||||||
|
msgs = [{"role": "user", "content": "hello"}]
|
||||||
|
out = llm._raw_gen(llm, model="gemini-2.0", messages=msgs, stream=False)
|
||||||
|
assert out == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_gen_stream_yields_chunks():
|
||||||
|
llm = GoogleLLM(api_key="key")
|
||||||
|
msgs = [{"role": "user", "content": "hello"}]
|
||||||
|
gen = llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True)
|
||||||
|
assert list(gen) == ["a", "b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_structured_output_format_type_mapping():
|
||||||
|
llm = GoogleLLM(api_key="key")
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"type": "string"},
|
||||||
|
"b": {"type": "array", "items": {"type": "integer"}},
|
||||||
|
},
|
||||||
|
"required": ["a"],
|
||||||
|
}
|
||||||
|
out = llm.prepare_structured_output_format(schema)
|
||||||
|
assert out["type"] == "OBJECT"
|
||||||
|
assert out["properties"]["a"]["type"] == "STRING"
|
||||||
|
assert out["properties"]["b"]["type"] == "ARRAY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_messages_with_attachments_appends_files(monkeypatch):
|
||||||
|
llm = GoogleLLM(api_key="key")
|
||||||
|
llm.storage = types.SimpleNamespace(
|
||||||
|
file_exists=lambda path: True,
|
||||||
|
process_file=lambda path, processor_func, **kwargs: "gs://file_uri"
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(llm, "_upload_file_to_google", lambda att: "gs://file_uri")
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hi"}]
|
||||||
|
attachments = [
|
||||||
|
{"path": "/tmp/img.png", "mime_type": "image/png"},
|
||||||
|
{"path": "/tmp/doc.pdf", "mime_type": "application/pdf"},
|
||||||
|
]
|
||||||
|
|
||||||
|
out = llm.prepare_messages_with_attachments(messages, attachments)
|
||||||
|
user_msg = next(m for m in out if m["role"] == "user")
|
||||||
|
assert isinstance(user_msg["content"], list)
|
||||||
|
files_entry = next((p for p in user_msg["content"] if isinstance(p, dict) and "files" in p), None)
|
||||||
|
assert files_entry is not None
|
||||||
|
assert isinstance(files_entry["files"], list) and len(files_entry["files"]) == 2
|
||||||
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from application.llm.openai import OpenAILLM
|
|
||||||
|
|
||||||
class TestOpenAILLM(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.api_key = "test_api_key"
|
|
||||||
self.llm = OpenAILLM(self.api_key)
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
self.assertEqual(self.llm.api_key, self.api_key)
|
|
||||||
158
tests/llm/test_openai_llm.py
Normal file
158
tests/llm/test_openai_llm.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
import json
|
||||||
|
import types
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from application.llm.openai import OpenAILLM
|
||||||
|
|
||||||
|
|
||||||
|
class FakeChatCompletions:
|
||||||
|
def __init__(self):
|
||||||
|
self.last_kwargs = None
|
||||||
|
|
||||||
|
class _Msg:
|
||||||
|
def __init__(self, content=None, tool_calls=None):
|
||||||
|
self.content = content
|
||||||
|
self.tool_calls = tool_calls
|
||||||
|
|
||||||
|
class _Delta:
|
||||||
|
def __init__(self, content=None):
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
class _Choice:
|
||||||
|
def __init__(self, content=None, delta=None, finish_reason="stop"):
|
||||||
|
self.message = FakeChatCompletions._Msg(content=content)
|
||||||
|
self.delta = FakeChatCompletions._Delta(content=delta)
|
||||||
|
self.finish_reason = finish_reason
|
||||||
|
|
||||||
|
class _StreamLine:
|
||||||
|
def __init__(self, deltas):
|
||||||
|
self.choices = [FakeChatCompletions._Choice(delta=d) for d in deltas]
|
||||||
|
|
||||||
|
class _Response:
|
||||||
|
def __init__(self, choices=None, lines=None):
|
||||||
|
self._choices = choices or []
|
||||||
|
self._lines = lines or []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def choices(self):
|
||||||
|
return self._choices
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for line in self._lines:
|
||||||
|
yield line
|
||||||
|
|
||||||
|
def create(self, **kwargs):
|
||||||
|
self.last_kwargs = kwargs
|
||||||
|
# default non-streaming: return content
|
||||||
|
if not kwargs.get("stream"):
|
||||||
|
return FakeChatCompletions._Response(choices=[
|
||||||
|
FakeChatCompletions._Choice(content="hello world")
|
||||||
|
])
|
||||||
|
# streaming: yield line objects each with choices[0].delta.content
|
||||||
|
return FakeChatCompletions._Response(lines=[
|
||||||
|
FakeChatCompletions._StreamLine(["part1"]),
|
||||||
|
FakeChatCompletions._StreamLine(["part2"]),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class FakeClient:
|
||||||
|
def __init__(self):
|
||||||
|
self.chat = types.SimpleNamespace(completions=FakeChatCompletions())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def openai_llm(monkeypatch):
|
||||||
|
llm = OpenAILLM(api_key="sk-test", user_api_key=None)
|
||||||
|
llm.storage = types.SimpleNamespace(
|
||||||
|
get_file=lambda path: types.SimpleNamespace(read=lambda: b"img"),
|
||||||
|
file_exists=lambda path: True,
|
||||||
|
process_file=lambda path, processor_func, **kwargs: "file_id_123",
|
||||||
|
)
|
||||||
|
llm.client = FakeClient()
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
def test_clean_messages_openai_variants(openai_llm):
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "model", "content": "asst"},
|
||||||
|
{"role": "user", "content": [
|
||||||
|
{"text": "hello"},
|
||||||
|
{"function_call": {"call_id": "c1", "name": "fn", "args": {"a": 1}}},
|
||||||
|
{"function_response": {"call_id": "c1", "name": "fn", "response": {"result": 42}}},
|
||||||
|
{"type": "image_url", "image_url": {"url": ""}},
|
||||||
|
]},
|
||||||
|
]
|
||||||
|
|
||||||
|
cleaned = openai_llm._clean_messages_openai(messages)
|
||||||
|
|
||||||
|
roles = [m["role"] for m in cleaned]
|
||||||
|
assert roles.count("assistant") >= 1
|
||||||
|
assert any(m["role"] == "tool" for m in cleaned)
|
||||||
|
|
||||||
|
assert any(isinstance(m["content"], list) and any(
|
||||||
|
part.get("type") == "image_url" for part in m["content"] if isinstance(part, dict)
|
||||||
|
) for m in cleaned if m["role"] == "user")
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_gen_calls_openai_client_and_returns_content(openai_llm):
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
]
|
||||||
|
content = openai_llm._raw_gen(openai_llm, model="gpt-4o", messages=msgs, stream=False)
|
||||||
|
assert content == "hello world"
|
||||||
|
|
||||||
|
passed = openai_llm.client.chat.completions.last_kwargs
|
||||||
|
assert passed["model"] == "gpt-4o"
|
||||||
|
assert isinstance(passed["messages"], list)
|
||||||
|
assert passed["stream"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_raw_gen_stream_yields_chunks(openai_llm):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
|
]
|
||||||
|
gen = openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs, stream=True)
|
||||||
|
chunks = list(gen)
|
||||||
|
assert "part1" in "".join(chunks)
|
||||||
|
assert "part2" in "".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm):
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"type": "string"},
|
||||||
|
"b": {"type": "number"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result = openai_llm.prepare_structured_output_format(schema)
|
||||||
|
assert result["type"] == "json_schema"
|
||||||
|
js = result["json_schema"]
|
||||||
|
assert js["strict"] is True
|
||||||
|
assert set(js["schema"]["required"]) == {"a", "b"}
|
||||||
|
assert js["schema"]["additionalProperties"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_messages_with_attachments_image_and_pdf(openai_llm, monkeypatch):
|
||||||
|
|
||||||
|
monkeypatch.setattr(openai_llm, "_get_base64_image", lambda att: "AAA=")
|
||||||
|
monkeypatch.setattr(openai_llm, "_upload_file_to_openai", lambda att: "file_xyz")
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Hi"}]
|
||||||
|
attachments = [
|
||||||
|
{"path": "/tmp/img.png", "mime_type": "image/png"},
|
||||||
|
{"path": "/tmp/doc.pdf", "mime_type": "application/pdf"},
|
||||||
|
]
|
||||||
|
out = openai_llm.prepare_messages_with_attachments(messages, attachments)
|
||||||
|
|
||||||
|
# last user message should have list content with text and two attachments
|
||||||
|
user_msg = next(m for m in out if m["role"] == "user")
|
||||||
|
assert isinstance(user_msg["content"], list)
|
||||||
|
types_in_content = [p.get("type") for p in user_msg["content"] if isinstance(p, dict)]
|
||||||
|
assert "image_url" in types_in_content or any(
|
||||||
|
isinstance(p, dict) and p.get("image_url") for p in user_msg["content"]
|
||||||
|
)
|
||||||
|
assert any(isinstance(p, dict) and p.get("file", {}).get("file_id") == "file_xyz" for p in user_msg["content"])
|
||||||
|
|
||||||
Reference in New Issue
Block a user