mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-06 16:25:04 +00:00
chore: handlers tests
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
@@ -56,6 +56,73 @@ class TestBaseAgentInitialization:
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.user == "user123"
|
||||
|
||||
def test_dependency_injection_llm(self, agent_base_params, mock_llm_handler_creator):
|
||||
"""When llm is provided, LLMCreator.create_llm is NOT called."""
|
||||
injected_llm = Mock()
|
||||
agent_base_params["llm"] = injected_llm
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.llm is injected_llm
|
||||
|
||||
def test_dependency_injection_llm_handler(self, agent_base_params, mock_llm_creator):
|
||||
"""When llm_handler is provided, LLMHandlerCreator is NOT called."""
|
||||
injected_handler = Mock()
|
||||
agent_base_params["llm_handler"] = injected_handler
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.llm_handler is injected_handler
|
||||
|
||||
def test_dependency_injection_tool_executor(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
"""When tool_executor is provided, a new one is NOT created."""
|
||||
injected_executor = Mock()
|
||||
injected_executor.tool_calls = []
|
||||
agent_base_params["tool_executor"] = injected_executor
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.tool_executor is injected_executor
|
||||
|
||||
def test_json_schema_normalized(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.json_schema == {"type": "object"}
|
||||
|
||||
def test_json_schema_wrapped(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["json_schema"] = {"schema": {"type": "string"}}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.json_schema == {"type": "string"}
|
||||
|
||||
def test_json_schema_invalid_ignored(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["json_schema"] = {"bad": "no type"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.json_schema is None
|
||||
|
||||
def test_retrieved_docs_defaults_to_empty(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.retrieved_docs == []
|
||||
|
||||
def test_attachments_defaults_to_empty(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["attachments"] = None
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.attachments == []
|
||||
|
||||
def test_limited_token_mode_defaults(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
assert agent.limited_token_mode is False
|
||||
assert agent.limited_request_mode is False
|
||||
assert agent.current_token_count == 0
|
||||
assert agent.context_limit_reached is False
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentBuildMessages:
|
||||
@@ -602,3 +669,656 @@ class TestBaseAgentHandleResponse:
|
||||
assert len(results) == 2
|
||||
assert results[0]["type"] == "tool_call"
|
||||
assert results[1]["answer"] == "Final answer"
|
||||
|
||||
def test_handle_response_dict_event_passthrough(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Dict events with 'type' key pass through without wrapping."""
|
||||
|
||||
def mock_process(*args):
|
||||
yield {"type": "info", "data": {"message": "processing"}}
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results == [{"type": "info", "data": {"message": "processing"}}]
|
||||
|
||||
def test_handle_response_message_object_from_handler(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Response objects with .message.content from handler are unwrapped."""
|
||||
event = Mock()
|
||||
event.message = Mock()
|
||||
event.message.content = "from handler"
|
||||
|
||||
def mock_process(*args):
|
||||
yield event
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["answer"] == "from handler"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# gen() — the @log_activity decorated entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentGen:
|
||||
|
||||
def test_gen_delegates_to_gen_inner(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
# ClassicAgent._gen_inner is abstract — we patch it
|
||||
with patch.object(agent, "_gen_inner") as mock_inner:
|
||||
mock_inner.return_value = iter([{"answer": "ok"}])
|
||||
results = list(agent.gen("hello"))
|
||||
|
||||
assert any(r.get("answer") == "ok" for r in results)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tool_calls property
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseAgentToolCallsProperty:
|
||||
|
||||
def test_getter(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tool_executor.tool_calls = ["a", "b"]
|
||||
assert agent.tool_calls == ["a", "b"]
|
||||
|
||||
def test_setter(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tool_calls = ["x"]
|
||||
assert agent.tool_executor.tool_calls == ["x"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _calculate_current_context_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCalculateContextTokens:
|
||||
|
||||
def test_delegates_to_token_counter(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
|
||||
with patch(
|
||||
"application.api.answer.services.compression.token_counter.TokenCounter"
|
||||
) as MockTC:
|
||||
MockTC.count_message_tokens.return_value = 42
|
||||
result = agent._calculate_current_context_tokens(messages)
|
||||
assert result == 42
|
||||
MockTC.count_message_tokens.assert_called_once_with(messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_context_limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestCheckContextLimit:
|
||||
|
||||
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
|
||||
return ClassicAgent(**agent_base_params)
|
||||
|
||||
def test_below_threshold_returns_false(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
result = agent._check_context_limit(messages)
|
||||
assert result is False
|
||||
assert agent.current_token_count == 100
|
||||
|
||||
def test_at_threshold_returns_true(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
|
||||
# threshold = 10000 * 0.8 = 8000; tokens = 8001 → True
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=8001):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
result = agent._check_context_limit(messages)
|
||||
assert result is True
|
||||
|
||||
def test_error_returns_false(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
with patch.object(
|
||||
agent,
|
||||
"_calculate_current_context_tokens",
|
||||
side_effect=RuntimeError("boom"),
|
||||
):
|
||||
result = agent._check_context_limit([])
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_context_size
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestValidateContextSize:
|
||||
|
||||
def test_at_limit_logs_warning(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=10000):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
# Should not raise
|
||||
agent._validate_context_size([{"role": "user", "content": "x"}])
|
||||
assert agent.current_token_count == 10000
|
||||
|
||||
def test_below_threshold_no_warning(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
agent._validate_context_size([])
|
||||
assert agent.current_token_count == 100
|
||||
|
||||
def test_approaching_threshold(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
# 8500 / 10000 = 85% → above 80% threshold but below 100%
|
||||
with patch.object(agent, "_calculate_current_context_tokens", return_value=8500):
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=10000
|
||||
):
|
||||
agent._validate_context_size([])
|
||||
assert agent.current_token_count == 8500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_text_middle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTruncateTextMiddle:
|
||||
|
||||
def test_short_text_unchanged(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
result = agent._truncate_text_middle("short", max_tokens=100)
|
||||
assert result == "short"
|
||||
|
||||
def test_long_text_truncated(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
long_text = "A" * 1000
|
||||
|
||||
def fake_tokens(text):
|
||||
return len(text) // 4
|
||||
|
||||
with patch("application.utils.num_tokens_from_string", side_effect=fake_tokens):
|
||||
result = agent._truncate_text_middle(long_text, max_tokens=50)
|
||||
assert "[... content truncated to fit context limit ...]" in result
|
||||
assert len(result) < len(long_text)
|
||||
|
||||
def test_zero_target_returns_empty(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
with patch("application.utils.num_tokens_from_string", return_value=100):
|
||||
result = agent._truncate_text_middle("some text", max_tokens=0)
|
||||
assert result == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _truncate_history_to_fit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestTruncateHistoryToFit:
|
||||
|
||||
def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
|
||||
return ClassicAgent(**agent_base_params)
|
||||
|
||||
def test_empty_history(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
assert agent._truncate_history_to_fit([], 100) == []
|
||||
|
||||
def test_zero_budget(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [{"prompt": "a", "response": "b"}]
|
||||
assert agent._truncate_history_to_fit(history, 0) == []
|
||||
|
||||
def test_fits_all(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [
|
||||
{"prompt": "q1", "response": "a1"},
|
||||
{"prompt": "q2", "response": "a2"},
|
||||
]
|
||||
with patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
result = agent._truncate_history_to_fit(history, 10000)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_partial_fit_keeps_most_recent(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [
|
||||
{"prompt": "old", "response": "old_ans"},
|
||||
{"prompt": "new", "response": "new_ans"},
|
||||
]
|
||||
# Each message = 10 tokens (prompt + response), budget = 15 → only 1 fits
|
||||
with patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
result = agent._truncate_history_to_fit(history, 15)
|
||||
assert len(result) == 1
|
||||
assert result[0]["prompt"] == "new"
|
||||
|
||||
def test_history_with_tool_calls(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = self._make_agent(
|
||||
agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
)
|
||||
history = [
|
||||
{
|
||||
"prompt": "q",
|
||||
"response": "a",
|
||||
"tool_calls": [
|
||||
{
|
||||
"tool_name": "t",
|
||||
"action_name": "act",
|
||||
"arguments": "{}",
|
||||
"result": "ok",
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
with patch("application.utils.num_tokens_from_string", return_value=3):
|
||||
result = agent._truncate_history_to_fit(history, 100)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_messages — compressed_summary and query truncation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBuildMessagesAdvanced:
|
||||
|
||||
def test_compressed_summary_appended(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent_base_params["compressed_summary"] = "Previous conversation summary"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=100000
|
||||
), patch("application.utils.num_tokens_from_string", return_value=10):
|
||||
messages = agent._build_messages("System prompt", "query")
|
||||
|
||||
system_content = messages[0]["content"]
|
||||
assert "Previous conversation summary" in system_content
|
||||
|
||||
def test_query_truncated_when_too_large(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
def fake_tokens(text):
|
||||
call_count["n"] += 1
|
||||
return len(text)
|
||||
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=200
|
||||
), patch("application.utils.num_tokens_from_string", side_effect=fake_tokens):
|
||||
with patch.object(agent, "_truncate_text_middle", return_value="truncated"):
|
||||
with patch.object(agent, "_truncate_history_to_fit", return_value=[]):
|
||||
messages = agent._build_messages("sys", "A" * 500)
|
||||
|
||||
# The method should have been called for truncation
|
||||
assert messages[-1]["role"] == "user"
|
||||
|
||||
def test_build_messages_with_tool_call_missing_call_id(
|
||||
self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
|
||||
):
|
||||
"""Tool calls without call_id get a generated UUID."""
|
||||
history = [
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"action_name": "search",
|
||||
"arguments": "{}",
|
||||
"result": "found",
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
agent_base_params["chat_history"] = history
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
with patch(
|
||||
"application.core.model_utils.get_token_limit", return_value=100000
|
||||
), patch("application.utils.num_tokens_from_string", return_value=5):
|
||||
messages = agent._build_messages("sys", "q")
|
||||
|
||||
tool_msgs = [m for m in messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _llm_gen — edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLLMGenAdvanced:
|
||||
|
||||
def test_llm_gen_with_attachments(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
agent_base_params["attachments"] = [{"id": "att1", "mime_type": "image/png"}]
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "_usage_attachments" in call_kwargs
|
||||
|
||||
def test_llm_gen_without_log_context(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
|
||||
# Should not raise even without log_context
|
||||
agent._llm_gen(messages, log_context=None)
|
||||
mock_llm.gen_stream.assert_called_once()
|
||||
|
||||
def test_llm_gen_google_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
mock_llm.prepare_structured_output_format = Mock(
|
||||
return_value={"schema": "test"}
|
||||
)
|
||||
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent_base_params["llm_name"] = "google"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages, log_context)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "response_schema" in call_kwargs
|
||||
|
||||
def test_llm_gen_no_tools_when_unsupported(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm._supports_tools = False
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
agent.tools = [{"type": "function", "function": {"name": "test"}}]
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "tools" not in call_kwargs
|
||||
|
||||
def test_llm_gen_no_structured_output_when_unsupported(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=False)
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "response_format" not in call_kwargs
|
||||
assert "response_schema" not in call_kwargs
|
||||
|
||||
def test_llm_gen_no_format_when_prepare_returns_none(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
mock_llm.prepare_structured_output_format = Mock(return_value=None)
|
||||
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent_base_params["llm_name"] = "openai"
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
messages = [{"role": "user", "content": "test"}]
|
||||
agent._llm_gen(messages)
|
||||
|
||||
call_kwargs = mock_llm.gen_stream.call_args[1]
|
||||
assert "response_format" not in call_kwargs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _llm_handler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestLLMHandlerMethod:
|
||||
|
||||
def test_delegates_to_handler(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
mock_llm_handler.process_message_flow = Mock(return_value="result")
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
resp = Mock()
|
||||
result = agent._llm_handler(resp, {}, [], log_context)
|
||||
|
||||
mock_llm_handler.process_message_flow.assert_called_once()
|
||||
assert result == "result"
|
||||
assert len(log_context.stacks) == 1
|
||||
assert log_context.stacks[0]["component"] == "llm_handler"
|
||||
|
||||
def test_without_log_context(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
):
|
||||
mock_llm_handler.process_message_flow = Mock(return_value="r")
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
result = agent._llm_handler(Mock(), {}, [], log_context=None)
|
||||
assert result == "r"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_response — structured output on all code paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestHandleResponseStructuredAllPaths:
|
||||
|
||||
def test_message_response_with_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Structured output on the message.content early-return path."""
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
agent_base_params["json_schema"] = {"type": "object"}
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
|
||||
response = Mock()
|
||||
response.message = Mock()
|
||||
response.message.content = "structured msg"
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["structured"] is True
|
||||
assert results[0]["schema"] == {"type": "object"}
|
||||
assert results[0]["answer"] == "structured msg"
|
||||
|
||||
def test_handler_string_event_with_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Structured output on string events from the handler."""
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
agent_base_params["json_schema"] = {"type": "array"}
|
||||
|
||||
def mock_process(*args):
|
||||
yield "handler string"
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["structured"] is True
|
||||
assert results[0]["schema"] == {"type": "array"}
|
||||
|
||||
def test_handler_message_event_with_structured_output(
|
||||
self,
|
||||
agent_base_params,
|
||||
mock_llm,
|
||||
mock_llm_handler,
|
||||
mock_llm_creator,
|
||||
mock_llm_handler_creator,
|
||||
log_context,
|
||||
):
|
||||
"""Structured output on message-object events from the handler."""
|
||||
mock_llm._supports_structured_output = Mock(return_value=True)
|
||||
agent_base_params["json_schema"] = {"type": "number"}
|
||||
|
||||
event = Mock()
|
||||
event.message = Mock()
|
||||
event.message.content = "from handler msg"
|
||||
|
||||
def mock_process(*args):
|
||||
yield event
|
||||
|
||||
mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
|
||||
|
||||
agent = ClassicAgent(**agent_base_params)
|
||||
response = Mock()
|
||||
response.message = None
|
||||
|
||||
results = list(agent._handle_response(response, {}, [], log_context))
|
||||
assert results[0]["structured"] is True
|
||||
assert results[0]["schema"] == {"type": "number"}
|
||||
assert results[0]["answer"] == "from handler msg"
|
||||
|
||||
556
tests/agents/test_workflow_schemas.py
Normal file
556
tests/agents/test_workflow_schemas.py
Normal file
@@ -0,0 +1,556 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from bson import ObjectId
|
||||
from pydantic import ValidationError
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
AgentType,
|
||||
ConditionCase,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
Position,
|
||||
StateOperation,
|
||||
Workflow,
|
||||
WorkflowCreate,
|
||||
WorkflowEdge,
|
||||
WorkflowEdgeCreate,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
WorkflowNodeCreate,
|
||||
WorkflowRun,
|
||||
WorkflowRunCreate,
|
||||
)
|
||||
|
||||
|
||||
# ── Enum tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNodeType:
|
||||
@pytest.mark.unit
|
||||
def test_values(self):
|
||||
assert NodeType.START == "start"
|
||||
assert NodeType.END == "end"
|
||||
assert NodeType.AGENT == "agent"
|
||||
assert NodeType.NOTE == "note"
|
||||
assert NodeType.STATE == "state"
|
||||
assert NodeType.CONDITION == "condition"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_all_members(self):
|
||||
assert set(NodeType) == {
|
||||
NodeType.START,
|
||||
NodeType.END,
|
||||
NodeType.AGENT,
|
||||
NodeType.NOTE,
|
||||
NodeType.STATE,
|
||||
NodeType.CONDITION,
|
||||
}
|
||||
|
||||
|
||||
class TestAgentType:
|
||||
@pytest.mark.unit
|
||||
def test_values(self):
|
||||
assert AgentType.CLASSIC == "classic"
|
||||
assert AgentType.REACT == "react"
|
||||
assert AgentType.AGENTIC == "agentic"
|
||||
assert AgentType.RESEARCH == "research"
|
||||
|
||||
|
||||
class TestExecutionStatus:
|
||||
@pytest.mark.unit
|
||||
def test_values(self):
|
||||
assert ExecutionStatus.PENDING == "pending"
|
||||
assert ExecutionStatus.RUNNING == "running"
|
||||
assert ExecutionStatus.COMPLETED == "completed"
|
||||
assert ExecutionStatus.FAILED == "failed"
|
||||
|
||||
|
||||
# ── Position ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPosition:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
p = Position()
|
||||
assert p.x == 0.0
|
||||
assert p.y == 0.0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_values(self):
|
||||
p = Position(x=10.5, y=-3.2)
|
||||
assert p.x == 10.5
|
||||
assert p.y == -3.2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_fields_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
Position(x=0, y=0, z=1)
|
||||
|
||||
|
||||
# ── AgentNodeConfig ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentNodeConfig:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
c = AgentNodeConfig()
|
||||
assert c.agent_type == AgentType.CLASSIC
|
||||
assert c.llm_name is None
|
||||
assert c.system_prompt == "You are a helpful assistant."
|
||||
assert c.prompt_template == ""
|
||||
assert c.output_variable is None
|
||||
assert c.stream_to_user is True
|
||||
assert c.tools == []
|
||||
assert c.sources == []
|
||||
assert c.chunks == "2"
|
||||
assert c.retriever == ""
|
||||
assert c.model_id is None
|
||||
assert c.json_schema is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_values(self):
|
||||
c = AgentNodeConfig(
|
||||
agent_type=AgentType.REACT,
|
||||
llm_name="gpt-4",
|
||||
tools=["search"],
|
||||
sources=["src1"],
|
||||
chunks="5",
|
||||
model_id="m1",
|
||||
json_schema={"type": "object"},
|
||||
)
|
||||
assert c.agent_type == AgentType.REACT
|
||||
assert c.llm_name == "gpt-4"
|
||||
assert c.tools == ["search"]
|
||||
assert c.sources == ["src1"]
|
||||
assert c.chunks == "5"
|
||||
assert c.model_id == "m1"
|
||||
assert c.json_schema == {"type": "object"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_fields_allowed(self):
|
||||
c = AgentNodeConfig(custom_field="value")
|
||||
assert c.custom_field == "value"
|
||||
|
||||
|
||||
# ── ConditionCase / ConditionNodeConfig ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestConditionCase:
|
||||
@pytest.mark.unit
|
||||
def test_alias(self):
|
||||
c = ConditionCase(expression="x > 1", sourceHandle="handle-1")
|
||||
assert c.source_handle == "handle-1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
c = ConditionCase(sourceHandle="h")
|
||||
assert c.name is None
|
||||
assert c.expression == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ConditionCase(sourceHandle="h", extra="nope")
|
||||
|
||||
|
||||
class TestConditionNodeConfig:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
c = ConditionNodeConfig()
|
||||
assert c.mode == "simple"
|
||||
assert c.cases == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_with_cases(self):
|
||||
c = ConditionNodeConfig(
|
||||
mode="advanced",
|
||||
cases=[{"expression": "x > 1", "sourceHandle": "h1"}],
|
||||
)
|
||||
assert c.mode == "advanced"
|
||||
assert len(c.cases) == 1
|
||||
assert c.cases[0].source_handle == "h1"
|
||||
|
||||
|
||||
# ── StateOperation ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStateOperation:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
s = StateOperation()
|
||||
assert s.expression == ""
|
||||
assert s.target_variable == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
StateOperation(expression="a", target_variable="b", extra="no")
|
||||
|
||||
|
||||
# ── WorkflowEdgeCreate / WorkflowEdge ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowEdgeCreate:
|
||||
@pytest.mark.unit
|
||||
def test_aliases(self):
|
||||
e = WorkflowEdgeCreate(
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
sourceHandle="sh",
|
||||
targetHandle="th",
|
||||
)
|
||||
assert e.source_id == "n1"
|
||||
assert e.target_id == "n2"
|
||||
assert e.source_handle == "sh"
|
||||
assert e.target_handle == "th"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_optional_handles_default_none(self):
|
||||
e = WorkflowEdgeCreate(id="e1", workflow_id="w1", source="n1", target="n2")
|
||||
assert e.source_handle is None
|
||||
assert e.target_handle is None
|
||||
|
||||
|
||||
class TestWorkflowEdge:
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
e = WorkflowEdge(
|
||||
_id=oid,
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
)
|
||||
assert e.mongo_id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string_id_passthrough(self):
|
||||
e = WorkflowEdge(
|
||||
_id="string-id",
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
)
|
||||
assert e.mongo_id == "string-id"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_id(self):
|
||||
e = WorkflowEdge(id="e1", workflow_id="w1", source="n1", target="n2")
|
||||
assert e.mongo_id is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
e = WorkflowEdge(
|
||||
id="e1",
|
||||
workflow_id="w1",
|
||||
source="n1",
|
||||
target="n2",
|
||||
sourceHandle="sh",
|
||||
targetHandle="th",
|
||||
)
|
||||
doc = e.to_mongo_doc()
|
||||
assert doc == {
|
||||
"id": "e1",
|
||||
"workflow_id": "w1",
|
||||
"source_id": "n1",
|
||||
"target_id": "n2",
|
||||
"source_handle": "sh",
|
||||
"target_handle": "th",
|
||||
}
|
||||
|
||||
|
||||
# ── WorkflowNodeCreate / WorkflowNode ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowNodeCreate:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
n = WorkflowNodeCreate(id="n1", workflow_id="w1", type=NodeType.AGENT)
|
||||
assert n.title == "Node"
|
||||
assert n.description is None
|
||||
assert n.position.x == 0.0
|
||||
assert n.position.y == 0.0
|
||||
assert n.config == {}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_position_from_dict(self):
|
||||
n = WorkflowNodeCreate(
|
||||
id="n1",
|
||||
workflow_id="w1",
|
||||
type=NodeType.START,
|
||||
position={"x": 100, "y": 200},
|
||||
)
|
||||
assert isinstance(n.position, Position)
|
||||
assert n.position.x == 100
|
||||
assert n.position.y == 200
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_position_from_position_object(self):
|
||||
pos = Position(x=5, y=10)
|
||||
n = WorkflowNodeCreate(
|
||||
id="n1", workflow_id="w1", type=NodeType.END, position=pos
|
||||
)
|
||||
assert n.position is pos
|
||||
|
||||
|
||||
class TestWorkflowNode:
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
n = WorkflowNode(
|
||||
_id=oid, id="n1", workflow_id="w1", type=NodeType.AGENT
|
||||
)
|
||||
assert n.mongo_id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
n = WorkflowNode(
|
||||
id="n1",
|
||||
workflow_id="w1",
|
||||
type=NodeType.AGENT,
|
||||
title="My Node",
|
||||
description="desc",
|
||||
position={"x": 10, "y": 20},
|
||||
config={"key": "val"},
|
||||
)
|
||||
doc = n.to_mongo_doc()
|
||||
assert doc == {
|
||||
"id": "n1",
|
||||
"workflow_id": "w1",
|
||||
"type": "agent",
|
||||
"title": "My Node",
|
||||
"description": "desc",
|
||||
"position": {"x": 10.0, "y": 20.0},
|
||||
"config": {"key": "val"},
|
||||
}
|
||||
|
||||
|
||||
# ── WorkflowCreate / Workflow ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowCreate:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
w = WorkflowCreate()
|
||||
assert w.name == "New Workflow"
|
||||
assert w.description is None
|
||||
assert w.user is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_custom_values(self):
|
||||
w = WorkflowCreate(name="Test", description="d", user="u1")
|
||||
assert w.name == "Test"
|
||||
assert w.description == "d"
|
||||
assert w.user == "u1"
|
||||
|
||||
|
||||
class TestWorkflow:
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
w = Workflow(_id=oid)
|
||||
assert w.id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_string_id(self):
|
||||
w = Workflow(_id="abc")
|
||||
assert w.id == "abc"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_none_id(self):
|
||||
w = Workflow()
|
||||
assert w.id is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_datetime_defaults(self):
|
||||
before = datetime.now(timezone.utc)
|
||||
w = Workflow()
|
||||
after = datetime.now(timezone.utc)
|
||||
assert before <= w.created_at <= after
|
||||
assert before <= w.updated_at <= after
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
w = Workflow(name="W", description="d", user="u1")
|
||||
doc = w.to_mongo_doc()
|
||||
assert doc["name"] == "W"
|
||||
assert doc["description"] == "d"
|
||||
assert doc["user"] == "u1"
|
||||
assert "created_at" in doc
|
||||
assert "updated_at" in doc
|
||||
|
||||
|
||||
# ── WorkflowGraph ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowGraph:
|
||||
@pytest.fixture
|
||||
def graph(self):
|
||||
workflow = Workflow(name="test")
|
||||
nodes = [
|
||||
WorkflowNode(id="start", workflow_id="w1", type=NodeType.START),
|
||||
WorkflowNode(id="agent1", workflow_id="w1", type=NodeType.AGENT),
|
||||
WorkflowNode(id="end", workflow_id="w1", type=NodeType.END),
|
||||
]
|
||||
edges = [
|
||||
WorkflowEdge(
|
||||
id="e1", workflow_id="w1", source="start", target="agent1"
|
||||
),
|
||||
WorkflowEdge(
|
||||
id="e2", workflow_id="w1", source="agent1", target="end"
|
||||
),
|
||||
]
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_node_by_id_found(self, graph):
|
||||
node = graph.get_node_by_id("agent1")
|
||||
assert node is not None
|
||||
assert node.id == "agent1"
|
||||
assert node.type == NodeType.AGENT
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_node_by_id_not_found(self, graph):
|
||||
assert graph.get_node_by_id("nonexistent") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_start_node(self, graph):
|
||||
start = graph.get_start_node()
|
||||
assert start is not None
|
||||
assert start.id == "start"
|
||||
assert start.type == NodeType.START
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_start_node_missing(self):
|
||||
g = WorkflowGraph(
|
||||
workflow=Workflow(),
|
||||
nodes=[
|
||||
WorkflowNode(id="a", workflow_id="w", type=NodeType.AGENT),
|
||||
],
|
||||
)
|
||||
assert g.get_start_node() is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_outgoing_edges(self, graph):
|
||||
edges = graph.get_outgoing_edges("start")
|
||||
assert len(edges) == 1
|
||||
assert edges[0].target_id == "agent1"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_get_outgoing_edges_none(self, graph):
|
||||
edges = graph.get_outgoing_edges("end")
|
||||
assert edges == []
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_empty_graph(self):
|
||||
g = WorkflowGraph(workflow=Workflow())
|
||||
assert g.nodes == []
|
||||
assert g.edges == []
|
||||
assert g.get_start_node() is None
|
||||
|
||||
|
||||
# ── NodeExecutionLog ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNodeExecutionLog:
|
||||
@pytest.mark.unit
|
||||
def test_required_fields(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
log = NodeExecutionLog(
|
||||
node_id="n1",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.RUNNING,
|
||||
started_at=now,
|
||||
)
|
||||
assert log.node_id == "n1"
|
||||
assert log.completed_at is None
|
||||
assert log.error is None
|
||||
assert log.state_snapshot == {}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_full_log(self):
|
||||
started = datetime.now(timezone.utc)
|
||||
completed = datetime.now(timezone.utc)
|
||||
log = NodeExecutionLog(
|
||||
node_id="n1",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=started,
|
||||
completed_at=completed,
|
||||
error=None,
|
||||
state_snapshot={"key": "value"},
|
||||
)
|
||||
assert log.completed_at == completed
|
||||
assert log.state_snapshot == {"key": "value"}
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extra_forbidden(self):
|
||||
with pytest.raises(ValidationError):
|
||||
NodeExecutionLog(
|
||||
node_id="n",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.PENDING,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
extra="no",
|
||||
)
|
||||
|
||||
|
||||
# ── WorkflowRunCreate / WorkflowRun ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestWorkflowRunCreate:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
r = WorkflowRunCreate(workflow_id="w1")
|
||||
assert r.workflow_id == "w1"
|
||||
assert r.inputs == {}
|
||||
|
||||
|
||||
class TestWorkflowRun:
|
||||
@pytest.mark.unit
|
||||
def test_defaults(self):
|
||||
r = WorkflowRun(workflow_id="w1")
|
||||
assert r.status == ExecutionStatus.PENDING
|
||||
assert r.inputs == {}
|
||||
assert r.outputs == {}
|
||||
assert r.steps == []
|
||||
assert r.completed_at is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_objectid_conversion(self):
|
||||
oid = ObjectId()
|
||||
r = WorkflowRun(_id=oid, workflow_id="w1")
|
||||
assert r.id == str(oid)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_to_mongo_doc(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
log = NodeExecutionLog(
|
||||
node_id="n1",
|
||||
node_type="agent",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
started_at=now,
|
||||
)
|
||||
r = WorkflowRun(
|
||||
workflow_id="w1",
|
||||
status=ExecutionStatus.RUNNING,
|
||||
inputs={"q": "hello"},
|
||||
outputs={"a": "world"},
|
||||
steps=[log],
|
||||
)
|
||||
doc = r.to_mongo_doc()
|
||||
assert doc["workflow_id"] == "w1"
|
||||
assert doc["status"] == "running"
|
||||
assert doc["inputs"] == {"q": "hello"}
|
||||
assert doc["outputs"] == {"a": "world"}
|
||||
assert len(doc["steps"]) == 1
|
||||
assert doc["steps"][0]["node_id"] == "n1"
|
||||
assert doc["completed_at"] is None
|
||||
240
tests/api/user/test_tasks.py
Normal file
240
tests/api/user/test_tasks.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from datetime import timedelta
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestIngestTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.ingest_worker")
|
||||
def test_calls_ingest_worker(self, mock_worker):
|
||||
from application.api.user.tasks import ingest
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf")
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
|
||||
file_name_map=None,
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.ingest_worker")
|
||||
def test_passes_file_name_map(self, mock_worker):
|
||||
from application.api.user.tasks import ingest
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
name_map = {"a.pdf": "b.pdf"}
|
||||
|
||||
ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
|
||||
file_name_map=name_map)
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
|
||||
file_name_map=name_map,
|
||||
)
|
||||
|
||||
|
||||
class TestIngestRemoteTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.remote_worker")
|
||||
def test_calls_remote_worker(self, mock_worker):
|
||||
from application.api.user.tasks import ingest_remote
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest_remote({"url": "http://x"}, "job1", "user1", "web")
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY, {"url": "http://x"}, "job1", "user1", "web"
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestReingestSourceTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.worker.reingest_source_worker")
|
||||
def test_calls_reingest_worker(self, mock_worker):
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = reingest_source_task("source123", "user1")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "source123", "user1")
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestScheduleSyncsTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.sync_worker")
|
||||
def test_calls_sync_worker(self, mock_worker):
|
||||
from application.api.user.tasks import schedule_syncs
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = schedule_syncs("daily")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "daily")
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestSyncSourceTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.sync")
|
||||
def test_calls_sync(self, mock_sync):
|
||||
from application.api.user.tasks import sync_source
|
||||
|
||||
mock_sync.return_value = {"status": "ok"}
|
||||
|
||||
result = sync_source(
|
||||
{"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
|
||||
)
|
||||
|
||||
mock_sync.assert_called_once_with(
|
||||
ANY, {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestStoreAttachmentTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.attachment_worker")
|
||||
def test_calls_attachment_worker(self, mock_worker):
|
||||
from application.api.user.tasks import store_attachment
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = store_attachment({"file": "info"}, "user1")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, {"file": "info"}, "user1")
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestProcessAgentWebhookTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.agent_webhook_worker")
|
||||
def test_calls_agent_webhook_worker(self, mock_worker):
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = process_agent_webhook("agent123", {"event": "test"})
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "agent123", {"event": "test"})
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestIngestConnectorTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.worker.ingest_connector")
|
||||
def test_calls_ingest_connector_defaults(self, mock_worker):
|
||||
from application.api.user.tasks import ingest_connector_task
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest_connector_task("job1", "user1", "gdrive")
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY,
|
||||
"job1",
|
||||
"user1",
|
||||
"gdrive",
|
||||
session_token=None,
|
||||
file_ids=None,
|
||||
folder_ids=None,
|
||||
recursive=True,
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.worker.ingest_connector")
|
||||
def test_calls_ingest_connector_custom(self, mock_worker):
|
||||
from application.api.user.tasks import ingest_connector_task
|
||||
|
||||
mock_worker.return_value = {"status": "ok"}
|
||||
|
||||
result = ingest_connector_task(
|
||||
"job1",
|
||||
"user1",
|
||||
"sharepoint",
|
||||
session_token="tok",
|
||||
file_ids=["f1"],
|
||||
folder_ids=["d1"],
|
||||
recursive=False,
|
||||
retriever="duckdb",
|
||||
operation_mode="sync",
|
||||
doc_id="doc1",
|
||||
sync_frequency="daily",
|
||||
)
|
||||
|
||||
mock_worker.assert_called_once_with(
|
||||
ANY,
|
||||
"job1",
|
||||
"user1",
|
||||
"sharepoint",
|
||||
session_token="tok",
|
||||
file_ids=["f1"],
|
||||
folder_ids=["d1"],
|
||||
recursive=False,
|
||||
retriever="duckdb",
|
||||
operation_mode="sync",
|
||||
doc_id="doc1",
|
||||
sync_frequency="daily",
|
||||
)
|
||||
assert result == {"status": "ok"}
|
||||
|
||||
|
||||
class TestSetupPeriodicTasks:
|
||||
@pytest.mark.unit
|
||||
def test_registers_periodic_tasks(self):
|
||||
from application.api.user.tasks import setup_periodic_tasks
|
||||
|
||||
sender = MagicMock()
|
||||
|
||||
setup_periodic_tasks(sender)
|
||||
|
||||
assert sender.add_periodic_task.call_count == 3
|
||||
|
||||
calls = sender.add_periodic_task.call_args_list
|
||||
|
||||
# daily
|
||||
assert calls[0][0][0] == timedelta(days=1)
|
||||
# weekly
|
||||
assert calls[1][0][0] == timedelta(weeks=1)
|
||||
# monthly
|
||||
assert calls[2][0][0] == timedelta(days=30)
|
||||
|
||||
|
||||
class TestMcpOauthTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.mcp_oauth")
|
||||
def test_calls_mcp_oauth(self, mock_worker):
|
||||
from application.api.user.tasks import mcp_oauth_task
|
||||
|
||||
mock_worker.return_value = {"url": "http://auth"}
|
||||
|
||||
result = mcp_oauth_task({"server": "mcp"}, "user1")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, {"server": "mcp"}, "user1")
|
||||
assert result == {"url": "http://auth"}
|
||||
|
||||
|
||||
class TestMcpOauthStatusTask:
|
||||
@pytest.mark.unit
|
||||
@patch("application.api.user.tasks.mcp_oauth_status")
|
||||
def test_calls_mcp_oauth_status(self, mock_worker):
|
||||
from application.api.user.tasks import mcp_oauth_status_task
|
||||
|
||||
mock_worker.return_value = {"status": "authorized"}
|
||||
|
||||
result = mcp_oauth_status_task("task123")
|
||||
|
||||
mock_worker.assert_called_once_with(ANY, "task123")
|
||||
assert result == {"status": "authorized"}
|
||||
382
tests/core/test_model_utils.py
Normal file
382
tests/core/test_model_utils.py
Normal file
@@ -0,0 +1,382 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
ModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
"""Reset ModelRegistry singleton between tests."""
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
yield
|
||||
ModelRegistry._instance = None
|
||||
ModelRegistry._initialized = False
|
||||
|
||||
|
||||
def _make_model(
|
||||
model_id="test-model",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="Test Model",
|
||||
context_window=128000,
|
||||
supports_tools=True,
|
||||
supports_structured_output=False,
|
||||
supported_attachment_types=None,
|
||||
enabled=True,
|
||||
base_url=None,
|
||||
):
|
||||
return AvailableModel(
|
||||
id=model_id,
|
||||
provider=provider,
|
||||
display_name=display_name,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=supports_tools,
|
||||
supports_structured_output=supports_structured_output,
|
||||
supported_attachment_types=supported_attachment_types or [],
|
||||
context_window=context_window,
|
||||
),
|
||||
enabled=enabled,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
|
||||
# ── get_api_key_for_provider ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetApiKeyForProvider:
|
||||
"""settings is lazily imported inside the function body, so we patch
|
||||
at application.core.settings.settings (the actual module attribute)."""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_openai_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.OPENAI_API_KEY = "sk-openai"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
mock_settings.OPEN_ROUTER_API_KEY = None
|
||||
mock_settings.NOVITA_API_KEY = None
|
||||
mock_settings.ANTHROPIC_API_KEY = None
|
||||
mock_settings.GOOGLE_API_KEY = None
|
||||
mock_settings.GROQ_API_KEY = None
|
||||
mock_settings.HUGGINGFACE_API_KEY = None
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("openai") == "sk-openai"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_anthropic_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.ANTHROPIC_API_KEY = "sk-anthropic"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("anthropic") == "sk-anthropic"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_google_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.GOOGLE_API_KEY = "sk-google"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("google") == "sk-google"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_groq_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.GROQ_API_KEY = "sk-groq"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("groq") == "sk-groq"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_openrouter_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.OPEN_ROUTER_API_KEY = "sk-or"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("openrouter") == "sk-or"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_novita_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.NOVITA_API_KEY = "sk-novita"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("novita") == "sk-novita"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_huggingface_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.HUGGINGFACE_API_KEY = "hf-key"
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("huggingface") == "hf-key"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_docsgpt_returns_fallback(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("docsgpt") == "sk-fallback"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_llama_cpp_returns_fallback(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("llama.cpp") == "sk-fallback"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_unknown_provider_returns_fallback(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-fallback"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("unknown_provider") == "sk-fallback"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_azure_openai_key(self):
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.API_KEY = "sk-azure"
|
||||
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
assert get_api_key_for_provider("azure_openai") == "sk-azure"
|
||||
|
||||
|
||||
# ── get_all_available_models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetAllAvailableModels:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_returns_enabled_models_as_dict(self, mock_get_instance):
|
||||
model_a = _make_model("model-a", display_name="Model A")
|
||||
model_b = _make_model("model-b", display_name="Model B")
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_enabled_models.return_value = [model_a, model_b]
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_all_available_models
|
||||
|
||||
result = get_all_available_models()
|
||||
|
||||
assert "model-a" in result
|
||||
assert "model-b" in result
|
||||
assert result["model-a"]["display_name"] == "Model A"
|
||||
assert result["model-b"]["display_name"] == "Model B"
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_empty_registry(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_enabled_models.return_value = []
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_all_available_models
|
||||
|
||||
assert get_all_available_models() == {}
|
||||
|
||||
|
||||
# ── validate_model_id ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestValidateModelId:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_exists(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.model_exists.return_value = True
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import validate_model_id
|
||||
|
||||
assert validate_model_id("gpt-4") is True
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_not_exists(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.model_exists.return_value = False
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import validate_model_id
|
||||
|
||||
assert validate_model_id("nonexistent") is False
|
||||
|
||||
|
||||
# ── get_model_capabilities ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetModelCapabilities:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_found(self, mock_get_instance):
|
||||
model = _make_model(
|
||||
"gpt-4",
|
||||
context_window=8192,
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=["image/png"],
|
||||
)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
|
||||
caps = get_model_capabilities("gpt-4")
|
||||
|
||||
assert caps is not None
|
||||
assert caps["supported_attachment_types"] == ["image/png"]
|
||||
assert caps["supports_tools"] is True
|
||||
assert caps["supports_structured_output"] is True
|
||||
assert caps["context_window"] == 8192
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
|
||||
assert get_model_capabilities("nonexistent") is None
|
||||
|
||||
|
||||
# ── get_default_model_id ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetDefaultModelId:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_returns_default(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.default_model_id = "gpt-4"
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_default_model_id
|
||||
|
||||
assert get_default_model_id() == "gpt-4"
|
||||
|
||||
|
||||
# ── get_provider_from_model_id ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetProviderFromModelId:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_found(self, mock_get_instance):
|
||||
model = _make_model("gpt-4", provider=ModelProvider.OPENAI)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_provider_from_model_id
|
||||
|
||||
assert get_provider_from_model_id("gpt-4") == "openai"
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_provider_from_model_id
|
||||
|
||||
assert get_provider_from_model_id("nonexistent") is None
|
||||
|
||||
|
||||
# ── get_token_limit ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetTokenLimit:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_found(self, mock_get_instance):
|
||||
model = _make_model("gpt-4", context_window=8192)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
assert get_token_limit("gpt-4") == 8192
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found_returns_default(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
with patch("application.core.settings.settings") as mock_settings:
|
||||
mock_settings.DEFAULT_LLM_TOKEN_LIMIT = 128000
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
assert get_token_limit("nonexistent") == 128000
|
||||
|
||||
|
||||
# ── get_base_url_for_model ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetBaseUrlForModel:
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_with_base_url(self, mock_get_instance):
|
||||
model = _make_model("custom-model", base_url="http://localhost:8080")
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
assert get_base_url_for_model("custom-model") == "http://localhost:8080"
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_without_base_url(self, mock_get_instance):
|
||||
model = _make_model("gpt-4", base_url=None)
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = model
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
assert get_base_url_for_model("gpt-4") is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch("application.core.model_utils.ModelRegistry.get_instance")
|
||||
def test_model_not_found(self, mock_get_instance):
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_model.return_value = None
|
||||
mock_get_instance.return_value = mock_registry
|
||||
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
assert get_base_url_for_model("nonexistent") is None
|
||||
File diff suppressed because it is too large
Load Diff
204
tests/llm/test_base_llm.py
Normal file
204
tests/llm/test_base_llm.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Unit tests for application/llm/base.py — BaseLLM.
|
||||
|
||||
Covers initialisation, static helpers, supports_* introspection,
|
||||
structured-output defaults, and attachment-type defaults.
|
||||
Fallback behaviour is covered separately in test_fallback.py.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from application.llm.base import BaseLLM
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete stub so we can instantiate the abstract base
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class StubLLM(BaseLLM):
|
||||
"""Minimal concrete BaseLLM for unit-testing non-abstract members."""
|
||||
|
||||
def _raw_gen(self, baseself, model, messages, stream, tools=None, **kw):
|
||||
return "raw_gen_result"
|
||||
|
||||
def _raw_gen_stream(self, baseself, model, messages, stream, tools=None, **kw):
|
||||
yield "chunk"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __init__
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestBaseLLMInit:
|
||||
|
||||
def test_defaults(self):
|
||||
llm = StubLLM()
|
||||
assert llm.decoded_token is None
|
||||
assert llm.agent_id is None
|
||||
assert llm.model_id is None
|
||||
assert llm.base_url is None
|
||||
assert llm.token_usage == {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
assert llm._backup_models == []
|
||||
assert llm._fallback_llm is None
|
||||
|
||||
def test_agent_id_cast_to_str(self):
|
||||
llm = StubLLM(agent_id=42)
|
||||
assert llm.agent_id == "42"
|
||||
|
||||
def test_agent_id_none_stays_none(self):
|
||||
llm = StubLLM(agent_id=None)
|
||||
assert llm.agent_id is None
|
||||
|
||||
def test_custom_params(self):
|
||||
token = {"sub": "u1"}
|
||||
llm = StubLLM(
|
||||
decoded_token=token,
|
||||
agent_id="abc",
|
||||
model_id="gpt-4",
|
||||
base_url="http://x",
|
||||
backup_models=["m1", "m2"],
|
||||
)
|
||||
assert llm.decoded_token is token
|
||||
assert llm.agent_id == "abc"
|
||||
assert llm.model_id == "gpt-4"
|
||||
assert llm.base_url == "http://x"
|
||||
assert llm._backup_models == ["m1", "m2"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _remove_null_values
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRemoveNullValues:
|
||||
|
||||
def test_removes_none_values(self):
|
||||
result = BaseLLM._remove_null_values({"a": 1, "b": None, "c": "x"})
|
||||
assert result == {"a": 1, "c": "x"}
|
||||
|
||||
def test_keeps_falsy_non_none(self):
|
||||
result = BaseLLM._remove_null_values({"a": 0, "b": "", "c": False, "d": []})
|
||||
assert result == {"a": 0, "b": "", "c": False, "d": []}
|
||||
|
||||
def test_non_dict_passthrough(self):
|
||||
assert BaseLLM._remove_null_values("hello") == "hello"
|
||||
assert BaseLLM._remove_null_values(42) == 42
|
||||
assert BaseLLM._remove_null_values([1, 2]) == [1, 2]
|
||||
|
||||
def test_empty_dict(self):
|
||||
assert BaseLLM._remove_null_values({}) == {}
|
||||
|
||||
def test_all_none(self):
|
||||
assert BaseLLM._remove_null_values({"a": None, "b": None}) == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# supports_tools / _supports_tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSupportsTools:
|
||||
|
||||
def test_supports_tools_true_when_callable(self):
|
||||
llm = StubLLM()
|
||||
assert llm.supports_tools() is True
|
||||
|
||||
def test_supports_tools_false_when_not_callable(self):
|
||||
llm = StubLLM()
|
||||
llm._supports_tools = "not_callable"
|
||||
assert llm.supports_tools() is False
|
||||
|
||||
def test_default_supports_tools_raises(self):
|
||||
llm = StubLLM()
|
||||
with pytest.raises(NotImplementedError):
|
||||
llm._supports_tools()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# supports_structured_output / _supports_structured_output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestSupportsStructuredOutput:
|
||||
|
||||
def test_supports_structured_output_true(self):
|
||||
llm = StubLLM()
|
||||
assert llm.supports_structured_output() is True
|
||||
|
||||
def test_default_supports_structured_output_returns_false(self):
|
||||
llm = StubLLM()
|
||||
assert llm._supports_structured_output() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prepare_structured_output_format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPrepareStructuredOutputFormat:
|
||||
|
||||
def test_returns_none_by_default(self):
|
||||
llm = StubLLM()
|
||||
assert llm.prepare_structured_output_format({"type": "object"}) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_supported_attachment_types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestGetSupportedAttachmentTypes:
|
||||
|
||||
def test_returns_empty_list(self):
|
||||
llm = StubLLM()
|
||||
assert llm.get_supported_attachment_types() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fallback_llm — caching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestFallbackLLMCaching:
|
||||
|
||||
def test_returns_cached_instance(self, monkeypatch):
|
||||
"""Once resolved, the same fallback instance is returned."""
|
||||
sentinel = StubLLM()
|
||||
llm = StubLLM()
|
||||
llm._fallback_llm = sentinel
|
||||
assert llm.fallback_llm is sentinel
|
||||
|
||||
def test_none_when_no_backup_and_no_global(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"application.llm.base.settings",
|
||||
MagicMock(FALLBACK_LLM_PROVIDER=None),
|
||||
)
|
||||
llm = StubLLM(backup_models=[])
|
||||
assert llm.fallback_llm is None
|
||||
|
||||
def test_global_fallback_init_failure_returns_none(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"application.llm.base.settings",
|
||||
MagicMock(
|
||||
FALLBACK_LLM_PROVIDER="openai",
|
||||
FALLBACK_LLM_NAME="gpt-4",
|
||||
FALLBACK_LLM_API_KEY="k",
|
||||
API_KEY="k",
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.llm.llm_creator.LLMCreator.create_llm",
|
||||
Mock(side_effect=RuntimeError("boom")),
|
||||
)
|
||||
llm = StubLLM(backup_models=[])
|
||||
assert llm.fallback_llm is None
|
||||
Reference in New Issue
Block a user