chore: handlers tests

This commit is contained in:
Alex
2026-03-30 12:53:50 +01:00
parent 3f6d6f15ea
commit ed0063aada
6 changed files with 3473 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from application.agents.classic_agent import ClassicAgent
@@ -56,6 +56,73 @@ class TestBaseAgentInitialization:
agent = ClassicAgent(**agent_base_params)
assert agent.user == "user123"
def test_dependency_injection_llm(self, agent_base_params, mock_llm_handler_creator):
    """An injected llm instance is adopted as-is (LLMCreator.create_llm is bypassed)."""
    provided = Mock()
    agent_base_params["llm"] = provided
    assert ClassicAgent(**agent_base_params).llm is provided
def test_dependency_injection_llm_handler(self, agent_base_params, mock_llm_creator):
    """An injected llm_handler is adopted as-is (LLMHandlerCreator is bypassed)."""
    provided = Mock()
    agent_base_params["llm_handler"] = provided
    assert ClassicAgent(**agent_base_params).llm_handler is provided
def test_dependency_injection_tool_executor(
    self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
    """An injected tool_executor is adopted as-is (no new executor is built)."""
    provided = Mock()
    provided.tool_calls = []
    agent_base_params["tool_executor"] = provided
    assert ClassicAgent(**agent_base_params).tool_executor is provided
def test_json_schema_normalized(
    self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
    """A plain schema dict that carries a 'type' key is stored unchanged."""
    agent_base_params["json_schema"] = {"type": "object"}
    built = ClassicAgent(**agent_base_params)
    assert built.json_schema == {"type": "object"}
def test_json_schema_wrapped(
    self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
    """A schema wrapped under a 'schema' key is unwrapped to the inner dict."""
    agent_base_params["json_schema"] = {"schema": {"type": "string"}}
    built = ClassicAgent(**agent_base_params)
    assert built.json_schema == {"type": "string"}
def test_json_schema_invalid_ignored(
    self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
    """A dict that is neither a schema nor a wrapper is discarded (None)."""
    agent_base_params["json_schema"] = {"bad": "no type"}
    built = ClassicAgent(**agent_base_params)
    assert built.json_schema is None
def test_retrieved_docs_defaults_to_empty(
    self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
    """retrieved_docs starts out as an empty list."""
    assert ClassicAgent(**agent_base_params).retrieved_docs == []
def test_attachments_defaults_to_empty(
    self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
    """An explicit None for attachments is normalized to an empty list."""
    agent_base_params["attachments"] = None
    assert ClassicAgent(**agent_base_params).attachments == []
def test_limited_token_mode_defaults(
    self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
):
    """All token/request limiting state starts fully reset."""
    built = ClassicAgent(**agent_base_params)
    assert built.limited_token_mode is False
    assert built.limited_request_mode is False
    assert built.current_token_count == 0
    assert built.context_limit_reached is False
@pytest.mark.unit
class TestBaseAgentBuildMessages:
@@ -602,3 +669,656 @@ class TestBaseAgentHandleResponse:
assert len(results) == 2
assert results[0]["type"] == "tool_call"
assert results[1]["answer"] == "Final answer"
def test_handle_response_dict_event_passthrough(
    self,
    agent_base_params,
    mock_llm_handler,
    mock_llm_creator,
    mock_llm_handler_creator,
    log_context,
):
    """Dict events carrying a 'type' key are yielded unchanged, not wrapped."""

    def fake_flow(*args):
        yield {"type": "info", "data": {"message": "processing"}}

    mock_llm_handler.process_message_flow = Mock(side_effect=fake_flow)
    agent = ClassicAgent(**agent_base_params)
    outer = Mock()
    outer.message = None
    emitted = list(agent._handle_response(outer, {}, [], log_context))
    assert emitted == [{"type": "info", "data": {"message": "processing"}}]
def test_handle_response_message_object_from_handler(
    self,
    agent_base_params,
    mock_llm_handler,
    mock_llm_creator,
    mock_llm_handler_creator,
    log_context,
):
    """Handler events shaped like responses (.message.content) are unwrapped."""
    handler_event = Mock()
    handler_event.message = Mock()
    handler_event.message.content = "from handler"

    def fake_flow(*args):
        yield handler_event

    mock_llm_handler.process_message_flow = Mock(side_effect=fake_flow)
    agent = ClassicAgent(**agent_base_params)
    outer = Mock()
    outer.message = None
    emitted = list(agent._handle_response(outer, {}, [], log_context))
    assert emitted[0]["answer"] == "from handler"
# ---------------------------------------------------------------------------
# gen() — the @log_activity decorated entry point
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseAgentGen:
    def test_gen_delegates_to_gen_inner(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """gen() yields whatever the (patched, abstract) _gen_inner produces."""
        agent = ClassicAgent(**agent_base_params)
        # _gen_inner is abstract on ClassicAgent's base, so stub it in.
        with patch.object(agent, "_gen_inner") as inner:
            inner.return_value = iter([{"answer": "ok"}])
            # Consume inside the patch: gen() is lazy.
            emitted = list(agent.gen("hello"))
        assert any(item.get("answer") == "ok" for item in emitted)
# ---------------------------------------------------------------------------
# tool_calls property
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseAgentToolCallsProperty:
    def test_getter(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Reading agent.tool_calls proxies the executor's list."""
        agent = ClassicAgent(**agent_base_params)
        agent.tool_executor.tool_calls = ["a", "b"]
        assert agent.tool_calls == ["a", "b"]

    def test_setter(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Assigning agent.tool_calls writes through to the executor."""
        agent = ClassicAgent(**agent_base_params)
        agent.tool_calls = ["x"]
        assert agent.tool_executor.tool_calls == ["x"]
# ---------------------------------------------------------------------------
# _calculate_current_context_tokens
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCalculateContextTokens:
    def test_delegates_to_token_counter(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Counting is delegated verbatim to TokenCounter.count_message_tokens."""
        agent = ClassicAgent(**agent_base_params)
        convo = [{"role": "user", "content": "hello"}]
        with patch(
            "application.api.answer.services.compression.token_counter.TokenCounter"
        ) as counter_cls:
            counter_cls.count_message_tokens.return_value = 42
            assert agent._calculate_current_context_tokens(convo) == 42
            counter_cls.count_message_tokens.assert_called_once_with(convo)
# ---------------------------------------------------------------------------
# _check_context_limit
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestCheckContextLimit:
    def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
        # Shared constructor so every test exercises the same wiring.
        return ClassicAgent(**agent_base_params)

    def test_below_threshold_returns_false(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Usage well under the limit reports False and records the count."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        convo = [{"role": "user", "content": "hi"}]
        with patch.object(
            agent, "_calculate_current_context_tokens", return_value=100
        ), patch("application.core.model_utils.get_token_limit", return_value=10000):
            outcome = agent._check_context_limit(convo)
        assert outcome is False
        assert agent.current_token_count == 100

    def test_at_threshold_returns_true(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Crossing the threshold flips the check to True."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        convo = [{"role": "user", "content": "hi"}]
        # threshold = 10000 * 0.8 = 8000; tokens = 8001 → True
        with patch.object(
            agent, "_calculate_current_context_tokens", return_value=8001
        ), patch("application.core.model_utils.get_token_limit", return_value=10000):
            outcome = agent._check_context_limit(convo)
        assert outcome is True

    def test_error_returns_false(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """A token-count failure degrades to 'no limit reached' instead of raising."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        with patch.object(
            agent,
            "_calculate_current_context_tokens",
            side_effect=RuntimeError("boom"),
        ):
            assert agent._check_context_limit([]) is False
# ---------------------------------------------------------------------------
# _validate_context_size
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestValidateContextSize:
    def test_at_limit_logs_warning(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Hitting the limit exactly must not raise; the count is still recorded."""
        agent = ClassicAgent(**agent_base_params)
        with patch.object(
            agent, "_calculate_current_context_tokens", return_value=10000
        ), patch("application.core.model_utils.get_token_limit", return_value=10000):
            agent._validate_context_size([{"role": "user", "content": "x"}])
        assert agent.current_token_count == 10000

    def test_below_threshold_no_warning(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Small contexts validate silently."""
        agent = ClassicAgent(**agent_base_params)
        with patch.object(
            agent, "_calculate_current_context_tokens", return_value=100
        ), patch("application.core.model_utils.get_token_limit", return_value=10000):
            agent._validate_context_size([])
        assert agent.current_token_count == 100

    def test_approaching_threshold(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """85% usage (above the 80% warning band, below 100%) validates cleanly."""
        agent = ClassicAgent(**agent_base_params)
        with patch.object(
            agent, "_calculate_current_context_tokens", return_value=8500
        ), patch("application.core.model_utils.get_token_limit", return_value=10000):
            agent._validate_context_size([])
        assert agent.current_token_count == 8500
# ---------------------------------------------------------------------------
# _truncate_text_middle
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestTruncateTextMiddle:
    def test_short_text_unchanged(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Text already within budget is returned untouched."""
        agent = ClassicAgent(**agent_base_params)
        with patch("application.utils.num_tokens_from_string", return_value=5):
            assert agent._truncate_text_middle("short", max_tokens=100) == "short"

    def test_long_text_truncated(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Oversized text is shortened and carries the truncation marker."""
        agent = ClassicAgent(**agent_base_params)
        original = "A" * 1000
        # Rough heuristic: ~4 characters per token.
        with patch(
            "application.utils.num_tokens_from_string",
            side_effect=lambda text: len(text) // 4,
        ):
            shortened = agent._truncate_text_middle(original, max_tokens=50)
        assert "[... content truncated to fit context limit ...]" in shortened
        assert len(shortened) < len(original)

    def test_zero_target_returns_empty(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """A zero-token budget yields the empty string."""
        agent = ClassicAgent(**agent_base_params)
        with patch("application.utils.num_tokens_from_string", return_value=100):
            assert agent._truncate_text_middle("some text", max_tokens=0) == ""
# ---------------------------------------------------------------------------
# _truncate_history_to_fit
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestTruncateHistoryToFit:
    def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
        # Shared constructor so every test exercises the same wiring.
        return ClassicAgent(**agent_base_params)

    def test_empty_history(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """No history → nothing to keep."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        assert agent._truncate_history_to_fit([], 100) == []

    def test_zero_budget(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """A zero token budget drops everything."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        assert agent._truncate_history_to_fit([{"prompt": "a", "response": "b"}], 0) == []

    def test_fits_all(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """A generous budget keeps every exchange."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        exchanges = [
            {"prompt": "q1", "response": "a1"},
            {"prompt": "q2", "response": "a2"},
        ]
        with patch("application.utils.num_tokens_from_string", return_value=5):
            kept = agent._truncate_history_to_fit(exchanges, 10000)
        assert len(kept) == 2

    def test_partial_fit_keeps_most_recent(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """When only part fits, the newest exchange survives."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        exchanges = [
            {"prompt": "old", "response": "old_ans"},
            {"prompt": "new", "response": "new_ans"},
        ]
        # Each message = 10 tokens (prompt + response), budget = 15 → only 1 fits
        with patch("application.utils.num_tokens_from_string", return_value=5):
            kept = agent._truncate_history_to_fit(exchanges, 15)
        assert len(kept) == 1
        assert kept[0]["prompt"] == "new"

    def test_history_with_tool_calls(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Tool-call metadata in an exchange is counted without breaking the fit."""
        agent = self._make_agent(
            agent_base_params, mock_llm_creator, mock_llm_handler_creator
        )
        exchanges = [
            {
                "prompt": "q",
                "response": "a",
                "tool_calls": [
                    {
                        "tool_name": "t",
                        "action_name": "act",
                        "arguments": "{}",
                        "result": "ok",
                    }
                ],
            }
        ]
        with patch("application.utils.num_tokens_from_string", return_value=3):
            kept = agent._truncate_history_to_fit(exchanges, 100)
        assert len(kept) == 1
# ---------------------------------------------------------------------------
# _build_messages — compressed_summary and query truncation
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBuildMessagesAdvanced:
    def test_compressed_summary_appended(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """A stored compressed_summary is folded into the system message."""
        agent_base_params["compressed_summary"] = "Previous conversation summary"
        agent = ClassicAgent(**agent_base_params)
        with patch(
            "application.core.model_utils.get_token_limit", return_value=100000
        ), patch("application.utils.num_tokens_from_string", return_value=10):
            messages = agent._build_messages("System prompt", "query")
        system_content = messages[0]["content"]
        assert "Previous conversation summary" in system_content

    def test_query_truncated_when_too_large(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """An oversized query triggers middle-truncation.

        Fix: the original version tracked token-counter invocations in a
        ``call_count`` dict but never asserted on it, so the test proved
        nothing about truncation. Assert on the patched truncation helper
        instead of keeping the dead counter.
        """
        agent = ClassicAgent(**agent_base_params)
        # 1 token per character → a 500-char query blows the 200-token limit.
        with patch(
            "application.core.model_utils.get_token_limit", return_value=200
        ), patch("application.utils.num_tokens_from_string", side_effect=len):
            with patch.object(
                agent, "_truncate_text_middle", return_value="truncated"
            ) as truncate_mock:
                with patch.object(agent, "_truncate_history_to_fit", return_value=[]):
                    messages = agent._build_messages("sys", "A" * 500)
        # The truncation helper must actually have run for the long query.
        truncate_mock.assert_called()
        assert messages[-1]["role"] == "user"

    def test_build_messages_with_tool_call_missing_call_id(
        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
    ):
        """Tool calls without call_id get a generated UUID."""
        history = [
            {
                "tool_calls": [
                    {
                        "action_name": "search",
                        "arguments": "{}",
                        "result": "found",
                    }
                ]
            }
        ]
        agent_base_params["chat_history"] = history
        agent = ClassicAgent(**agent_base_params)
        with patch(
            "application.core.model_utils.get_token_limit", return_value=100000
        ), patch("application.utils.num_tokens_from_string", return_value=5):
            messages = agent._build_messages("sys", "q")
        tool_msgs = [m for m in messages if m["role"] == "tool"]
        assert len(tool_msgs) == 1
# ---------------------------------------------------------------------------
# _llm_gen — edge cases
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestLLMGenAdvanced:
    def test_llm_gen_with_attachments(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_creator,
        mock_llm_handler_creator,
    ):
        """Attachments are forwarded to gen_stream via _usage_attachments."""
        agent_base_params["attachments"] = [{"id": "att1", "mime_type": "image/png"}]
        agent = ClassicAgent(**agent_base_params)
        agent._llm_gen([{"role": "user", "content": "test"}])
        assert "_usage_attachments" in mock_llm.gen_stream.call_args[1]

    def test_llm_gen_without_log_context(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_creator,
        mock_llm_handler_creator,
    ):
        """A missing log_context is tolerated; gen_stream still fires once."""
        agent = ClassicAgent(**agent_base_params)
        agent._llm_gen([{"role": "user", "content": "test"}], log_context=None)
        mock_llm.gen_stream.assert_called_once()

    def test_llm_gen_google_structured_output(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_creator,
        mock_llm_handler_creator,
        log_context,
    ):
        """Google-flavored structured output goes through response_schema."""
        mock_llm._supports_structured_output = Mock(return_value=True)
        mock_llm.prepare_structured_output_format = Mock(
            return_value={"schema": "test"}
        )
        agent_base_params["json_schema"] = {"type": "object"}
        agent_base_params["llm_name"] = "google"
        agent = ClassicAgent(**agent_base_params)
        agent._llm_gen([{"role": "user", "content": "test"}], log_context)
        assert "response_schema" in mock_llm.gen_stream.call_args[1]

    def test_llm_gen_no_tools_when_unsupported(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_creator,
        mock_llm_handler_creator,
    ):
        """Tools configured on the agent are dropped if the LLM can't use them."""
        mock_llm._supports_tools = False
        agent = ClassicAgent(**agent_base_params)
        agent.tools = [{"type": "function", "function": {"name": "test"}}]
        agent._llm_gen([{"role": "user", "content": "test"}])
        assert "tools" not in mock_llm.gen_stream.call_args[1]

    def test_llm_gen_no_structured_output_when_unsupported(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_creator,
        mock_llm_handler_creator,
    ):
        """No structured-output kwargs are sent when the LLM rejects schemas."""
        mock_llm._supports_structured_output = Mock(return_value=False)
        agent_base_params["json_schema"] = {"type": "object"}
        agent = ClassicAgent(**agent_base_params)
        agent._llm_gen([{"role": "user", "content": "test"}])
        sent = mock_llm.gen_stream.call_args[1]
        assert "response_format" not in sent
        assert "response_schema" not in sent

    def test_llm_gen_no_format_when_prepare_returns_none(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_creator,
        mock_llm_handler_creator,
    ):
        """A None from prepare_structured_output_format suppresses the kwarg."""
        mock_llm._supports_structured_output = Mock(return_value=True)
        mock_llm.prepare_structured_output_format = Mock(return_value=None)
        agent_base_params["json_schema"] = {"type": "object"}
        agent_base_params["llm_name"] = "openai"
        agent = ClassicAgent(**agent_base_params)
        agent._llm_gen([{"role": "user", "content": "test"}])
        assert "response_format" not in mock_llm.gen_stream.call_args[1]
# ---------------------------------------------------------------------------
# _llm_handler
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestLLMHandlerMethod:
    def test_delegates_to_handler(
        self,
        agent_base_params,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        log_context,
    ):
        """_llm_handler forwards to process_message_flow and records a stack entry."""
        mock_llm_handler.process_message_flow = Mock(return_value="result")
        agent = ClassicAgent(**agent_base_params)
        outcome = agent._llm_handler(Mock(), {}, [], log_context)
        mock_llm_handler.process_message_flow.assert_called_once()
        assert outcome == "result"
        assert len(log_context.stacks) == 1
        assert log_context.stacks[0]["component"] == "llm_handler"

    def test_without_log_context(
        self,
        agent_base_params,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
    ):
        """Delegation works the same with no log_context supplied."""
        mock_llm_handler.process_message_flow = Mock(return_value="r")
        agent = ClassicAgent(**agent_base_params)
        assert agent._llm_handler(Mock(), {}, [], log_context=None) == "r"
# ---------------------------------------------------------------------------
# _handle_response — structured output on all code paths
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestHandleResponseStructuredAllPaths:
    def test_message_response_with_structured_output(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_creator,
        mock_llm_handler_creator,
        log_context,
    ):
        """Structured output on the message.content early-return path."""
        mock_llm._supports_structured_output = Mock(return_value=True)
        agent_base_params["json_schema"] = {"type": "object"}
        agent = ClassicAgent(**agent_base_params)
        direct = Mock()
        direct.message = Mock()
        direct.message.content = "structured msg"
        emitted = list(agent._handle_response(direct, {}, [], log_context))
        first = emitted[0]
        assert first["structured"] is True
        assert first["schema"] == {"type": "object"}
        assert first["answer"] == "structured msg"

    def test_handler_string_event_with_structured_output(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        log_context,
    ):
        """Structured output on string events from the handler."""
        mock_llm._supports_structured_output = Mock(return_value=True)
        agent_base_params["json_schema"] = {"type": "array"}

        def fake_flow(*args):
            yield "handler string"

        mock_llm_handler.process_message_flow = Mock(side_effect=fake_flow)
        agent = ClassicAgent(**agent_base_params)
        outer = Mock()
        outer.message = None
        emitted = list(agent._handle_response(outer, {}, [], log_context))
        assert emitted[0]["structured"] is True
        assert emitted[0]["schema"] == {"type": "array"}

    def test_handler_message_event_with_structured_output(
        self,
        agent_base_params,
        mock_llm,
        mock_llm_handler,
        mock_llm_creator,
        mock_llm_handler_creator,
        log_context,
    ):
        """Structured output on message-object events from the handler."""
        mock_llm._supports_structured_output = Mock(return_value=True)
        agent_base_params["json_schema"] = {"type": "number"}
        handler_event = Mock()
        handler_event.message = Mock()
        handler_event.message.content = "from handler msg"

        def fake_flow(*args):
            yield handler_event

        mock_llm_handler.process_message_flow = Mock(side_effect=fake_flow)
        agent = ClassicAgent(**agent_base_params)
        outer = Mock()
        outer.message = None
        emitted = list(agent._handle_response(outer, {}, [], log_context))
        first = emitted[0]
        assert first["structured"] is True
        assert first["schema"] == {"type": "number"}
        assert first["answer"] == "from handler msg"

View File

@@ -0,0 +1,556 @@
from datetime import datetime, timezone
import pytest
from bson import ObjectId
from pydantic import ValidationError
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionCase,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
Position,
StateOperation,
Workflow,
WorkflowCreate,
WorkflowEdge,
WorkflowEdgeCreate,
WorkflowGraph,
WorkflowNode,
WorkflowNodeCreate,
WorkflowRun,
WorkflowRunCreate,
)
# ── Enum tests ───────────────────────────────────────────────────────────────
class TestNodeType:
    @pytest.mark.unit
    def test_values(self):
        """Each member compares equal to its lowercase string value."""
        expected = {
            NodeType.START: "start",
            NodeType.END: "end",
            NodeType.AGENT: "agent",
            NodeType.NOTE: "note",
            NodeType.STATE: "state",
            NodeType.CONDITION: "condition",
        }
        for member, text in expected.items():
            assert member == text

    @pytest.mark.unit
    def test_all_members(self):
        """The enum holds exactly the six known node kinds — no strays."""
        assert set(NodeType) == {
            NodeType.START,
            NodeType.END,
            NodeType.AGENT,
            NodeType.NOTE,
            NodeType.STATE,
            NodeType.CONDITION,
        }
class TestAgentType:
    @pytest.mark.unit
    def test_values(self):
        """Each agent flavor maps to its lowercase string value."""
        pairs = [
            (AgentType.CLASSIC, "classic"),
            (AgentType.REACT, "react"),
            (AgentType.AGENTIC, "agentic"),
            (AgentType.RESEARCH, "research"),
        ]
        for member, text in pairs:
            assert member == text
class TestExecutionStatus:
    @pytest.mark.unit
    def test_values(self):
        """Each lifecycle state maps to its lowercase string value."""
        pairs = [
            (ExecutionStatus.PENDING, "pending"),
            (ExecutionStatus.RUNNING, "running"),
            (ExecutionStatus.COMPLETED, "completed"),
            (ExecutionStatus.FAILED, "failed"),
        ]
        for member, text in pairs:
            assert member == text
# ── Position ─────────────────────────────────────────────────────────────────
class TestPosition:
    @pytest.mark.unit
    def test_defaults(self):
        """Both coordinates default to the origin."""
        origin = Position()
        assert origin.x == 0.0
        assert origin.y == 0.0

    @pytest.mark.unit
    def test_custom_values(self):
        """Arbitrary float coordinates (including negatives) round-trip."""
        point = Position(x=10.5, y=-3.2)
        assert point.x == 10.5
        assert point.y == -3.2

    @pytest.mark.unit
    def test_extra_fields_forbidden(self):
        """Unknown fields are rejected by the model's strict config."""
        with pytest.raises(ValidationError):
            Position(x=0, y=0, z=1)
# ── AgentNodeConfig ──────────────────────────────────────────────────────────
class TestAgentNodeConfig:
    @pytest.mark.unit
    def test_defaults(self):
        """A bare config is a streaming classic agent with empty wiring."""
        cfg = AgentNodeConfig()
        assert cfg.agent_type == AgentType.CLASSIC
        assert cfg.llm_name is None
        assert cfg.system_prompt == "You are a helpful assistant."
        assert cfg.prompt_template == ""
        assert cfg.output_variable is None
        assert cfg.stream_to_user is True
        assert cfg.tools == []
        assert cfg.sources == []
        assert cfg.chunks == "2"
        assert cfg.retriever == ""
        assert cfg.model_id is None
        assert cfg.json_schema is None

    @pytest.mark.unit
    def test_custom_values(self):
        """Explicit constructor arguments are stored verbatim."""
        cfg = AgentNodeConfig(
            agent_type=AgentType.REACT,
            llm_name="gpt-4",
            tools=["search"],
            sources=["src1"],
            chunks="5",
            model_id="m1",
            json_schema={"type": "object"},
        )
        assert cfg.agent_type == AgentType.REACT
        assert cfg.llm_name == "gpt-4"
        assert cfg.tools == ["search"]
        assert cfg.sources == ["src1"]
        assert cfg.chunks == "5"
        assert cfg.model_id == "m1"
        assert cfg.json_schema == {"type": "object"}

    @pytest.mark.unit
    def test_extra_fields_allowed(self):
        """Unlike the strict models, unknown fields are accepted here."""
        cfg = AgentNodeConfig(custom_field="value")
        assert cfg.custom_field == "value"
# ── ConditionCase / ConditionNodeConfig ──────────────────────────────────────
class TestConditionCase:
    @pytest.mark.unit
    def test_alias(self):
        """The camelCase alias sourceHandle populates source_handle."""
        case = ConditionCase(expression="x > 1", sourceHandle="handle-1")
        assert case.source_handle == "handle-1"

    @pytest.mark.unit
    def test_defaults(self):
        """name defaults to None and expression to the empty string."""
        case = ConditionCase(sourceHandle="h")
        assert case.name is None
        assert case.expression == ""

    @pytest.mark.unit
    def test_extra_forbidden(self):
        """Unknown fields are rejected."""
        with pytest.raises(ValidationError):
            ConditionCase(sourceHandle="h", extra="nope")
class TestConditionNodeConfig:
    @pytest.mark.unit
    def test_defaults(self):
        """Mode defaults to 'simple' with no cases."""
        cfg = ConditionNodeConfig()
        assert cfg.mode == "simple"
        assert cfg.cases == []

    @pytest.mark.unit
    def test_with_cases(self):
        """Dict cases are coerced into ConditionCase objects (alias respected)."""
        cfg = ConditionNodeConfig(
            mode="advanced",
            cases=[{"expression": "x > 1", "sourceHandle": "h1"}],
        )
        assert cfg.mode == "advanced"
        assert len(cfg.cases) == 1
        assert cfg.cases[0].source_handle == "h1"
# ── StateOperation ───────────────────────────────────────────────────────────
class TestStateOperation:
    @pytest.mark.unit
    def test_defaults(self):
        """Both fields default to empty strings."""
        op = StateOperation()
        assert op.expression == ""
        assert op.target_variable == ""

    @pytest.mark.unit
    def test_extra_forbidden(self):
        """Unknown fields are rejected."""
        with pytest.raises(ValidationError):
            StateOperation(expression="a", target_variable="b", extra="no")
# ── WorkflowEdgeCreate / WorkflowEdge ───────────────────────────────────────
class TestWorkflowEdgeCreate:
    @pytest.mark.unit
    def test_aliases(self):
        """source/target + camelCase handles map onto the snake_case fields."""
        edge = WorkflowEdgeCreate(
            id="e1",
            workflow_id="w1",
            source="n1",
            target="n2",
            sourceHandle="sh",
            targetHandle="th",
        )
        assert edge.source_id == "n1"
        assert edge.target_id == "n2"
        assert edge.source_handle == "sh"
        assert edge.target_handle == "th"

    @pytest.mark.unit
    def test_optional_handles_default_none(self):
        """Handles are optional and default to None."""
        edge = WorkflowEdgeCreate(id="e1", workflow_id="w1", source="n1", target="n2")
        assert edge.source_handle is None
        assert edge.target_handle is None
class TestWorkflowEdge:
    @pytest.mark.unit
    def test_objectid_conversion(self):
        """A BSON ObjectId passed as _id surfaces as its string form."""
        oid = ObjectId()
        edge = WorkflowEdge(
            _id=oid,
            id="e1",
            workflow_id="w1",
            source="n1",
            target="n2",
        )
        assert edge.mongo_id == str(oid)

    @pytest.mark.unit
    def test_string_id_passthrough(self):
        """A string _id is stored unchanged."""
        edge = WorkflowEdge(
            _id="string-id",
            id="e1",
            workflow_id="w1",
            source="n1",
            target="n2",
        )
        assert edge.mongo_id == "string-id"

    @pytest.mark.unit
    def test_none_id(self):
        """Omitting _id leaves mongo_id unset."""
        edge = WorkflowEdge(id="e1", workflow_id="w1", source="n1", target="n2")
        assert edge.mongo_id is None

    @pytest.mark.unit
    def test_to_mongo_doc(self):
        """Serialization emits snake_case keys and drops mongo_id."""
        edge = WorkflowEdge(
            id="e1",
            workflow_id="w1",
            source="n1",
            target="n2",
            sourceHandle="sh",
            targetHandle="th",
        )
        assert edge.to_mongo_doc() == {
            "id": "e1",
            "workflow_id": "w1",
            "source_id": "n1",
            "target_id": "n2",
            "source_handle": "sh",
            "target_handle": "th",
        }
# ── WorkflowNodeCreate / WorkflowNode ───────────────────────────────────────
class TestWorkflowNodeCreate:
    @pytest.mark.unit
    def test_defaults(self):
        """A minimal node gets a generic title, origin position, empty config."""
        node = WorkflowNodeCreate(id="n1", workflow_id="w1", type=NodeType.AGENT)
        assert node.title == "Node"
        assert node.description is None
        assert node.position.x == 0.0
        assert node.position.y == 0.0
        assert node.config == {}

    @pytest.mark.unit
    def test_position_from_dict(self):
        """A plain {x, y} dict is coerced into a Position instance."""
        node = WorkflowNodeCreate(
            id="n1",
            workflow_id="w1",
            type=NodeType.START,
            position={"x": 100, "y": 200},
        )
        assert isinstance(node.position, Position)
        assert node.position.x == 100
        assert node.position.y == 200

    @pytest.mark.unit
    def test_position_from_position_object(self):
        """An existing Position is used directly, not copied."""
        existing = Position(x=5, y=10)
        node = WorkflowNodeCreate(
            id="n1", workflow_id="w1", type=NodeType.END, position=existing
        )
        assert node.position is existing
class TestWorkflowNode:
    @pytest.mark.unit
    def test_objectid_conversion(self):
        """A BSON ObjectId _id surfaces as its string form."""
        oid = ObjectId()
        node = WorkflowNode(
            _id=oid, id="n1", workflow_id="w1", type=NodeType.AGENT
        )
        assert node.mongo_id == str(oid)

    @pytest.mark.unit
    def test_to_mongo_doc(self):
        """Serialization flattens type to its value and position to floats."""
        node = WorkflowNode(
            id="n1",
            workflow_id="w1",
            type=NodeType.AGENT,
            title="My Node",
            description="desc",
            position={"x": 10, "y": 20},
            config={"key": "val"},
        )
        assert node.to_mongo_doc() == {
            "id": "n1",
            "workflow_id": "w1",
            "type": "agent",
            "title": "My Node",
            "description": "desc",
            "position": {"x": 10.0, "y": 20.0},
            "config": {"key": "val"},
        }
# ── WorkflowCreate / Workflow ───────────────────────────────────────────────
class TestWorkflowCreate:
    @pytest.mark.unit
    def test_defaults(self):
        """A bare creation payload gets the stock name and no owner."""
        payload = WorkflowCreate()
        assert payload.name == "New Workflow"
        assert payload.description is None
        assert payload.user is None

    @pytest.mark.unit
    def test_custom_values(self):
        """Explicit name/description/user are stored verbatim."""
        payload = WorkflowCreate(name="Test", description="d", user="u1")
        assert payload.name == "Test"
        assert payload.description == "d"
        assert payload.user == "u1"
class TestWorkflow:
    @pytest.mark.unit
    def test_objectid_conversion(self):
        """A BSON ObjectId _id is exposed through .id as a string."""
        oid = ObjectId()
        assert Workflow(_id=oid).id == str(oid)

    @pytest.mark.unit
    def test_string_id(self):
        """A string _id is passed through unchanged."""
        assert Workflow(_id="abc").id == "abc"

    @pytest.mark.unit
    def test_none_id(self):
        """No _id → .id is None."""
        assert Workflow().id is None

    @pytest.mark.unit
    def test_datetime_defaults(self):
        """Both timestamps default to 'now' in UTC (bracketed by the test)."""
        before = datetime.now(timezone.utc)
        flow = Workflow()
        after = datetime.now(timezone.utc)
        assert before <= flow.created_at <= after
        assert before <= flow.updated_at <= after

    @pytest.mark.unit
    def test_to_mongo_doc(self):
        """Serialization keeps the scalar fields and both timestamps."""
        flow = Workflow(name="W", description="d", user="u1")
        doc = flow.to_mongo_doc()
        assert doc["name"] == "W"
        assert doc["description"] == "d"
        assert doc["user"] == "u1"
        assert "created_at" in doc
        assert "updated_at" in doc
# ── WorkflowGraph ───────────────────────────────────────────────────────────
class TestWorkflowGraph:
    @pytest.fixture
    def graph(self):
        """A minimal linear graph: start → agent1 → end."""
        flow = Workflow(name="test")
        node_list = [
            WorkflowNode(id="start", workflow_id="w1", type=NodeType.START),
            WorkflowNode(id="agent1", workflow_id="w1", type=NodeType.AGENT),
            WorkflowNode(id="end", workflow_id="w1", type=NodeType.END),
        ]
        edge_list = [
            WorkflowEdge(
                id="e1", workflow_id="w1", source="start", target="agent1"
            ),
            WorkflowEdge(
                id="e2", workflow_id="w1", source="agent1", target="end"
            ),
        ]
        return WorkflowGraph(workflow=flow, nodes=node_list, edges=edge_list)

    @pytest.mark.unit
    def test_get_node_by_id_found(self, graph):
        """Lookup by id returns the matching node."""
        found = graph.get_node_by_id("agent1")
        assert found is not None
        assert found.id == "agent1"
        assert found.type == NodeType.AGENT

    @pytest.mark.unit
    def test_get_node_by_id_not_found(self, graph):
        """An unknown id yields None rather than raising."""
        assert graph.get_node_by_id("nonexistent") is None

    @pytest.mark.unit
    def test_get_start_node(self, graph):
        """The START-typed node is located."""
        entry = graph.get_start_node()
        assert entry is not None
        assert entry.id == "start"
        assert entry.type == NodeType.START

    @pytest.mark.unit
    def test_get_start_node_missing(self):
        """A graph without a START node reports None."""
        no_start = WorkflowGraph(
            workflow=Workflow(),
            nodes=[
                WorkflowNode(id="a", workflow_id="w", type=NodeType.AGENT),
            ],
        )
        assert no_start.get_start_node() is None

    @pytest.mark.unit
    def test_get_outgoing_edges(self, graph):
        """Outgoing edges are filtered by source node."""
        outgoing = graph.get_outgoing_edges("start")
        assert len(outgoing) == 1
        assert outgoing[0].target_id == "agent1"

    @pytest.mark.unit
    def test_get_outgoing_edges_none(self, graph):
        """A terminal node has no outgoing edges."""
        assert graph.get_outgoing_edges("end") == []

    @pytest.mark.unit
    def test_empty_graph(self):
        """A graph with no nodes or edges is valid and start-less."""
        empty = WorkflowGraph(workflow=Workflow())
        assert empty.nodes == []
        assert empty.edges == []
        assert empty.get_start_node() is None
# ── NodeExecutionLog ─────────────────────────────────────────────────────────
class TestNodeExecutionLog:
    """Unit tests for the NodeExecutionLog model."""

    @pytest.mark.unit
    def test_required_fields(self):
        # Only the four required fields are supplied; optional ones default.
        started = datetime.now(timezone.utc)
        entry = NodeExecutionLog(
            node_id="n1",
            node_type="agent",
            status=ExecutionStatus.RUNNING,
            started_at=started,
        )
        assert entry.node_id == "n1"
        assert entry.completed_at is None
        assert entry.error is None
        assert entry.state_snapshot == {}

    @pytest.mark.unit
    def test_full_log(self):
        t0 = datetime.now(timezone.utc)
        t1 = datetime.now(timezone.utc)
        entry = NodeExecutionLog(
            node_id="n1",
            node_type="agent",
            status=ExecutionStatus.COMPLETED,
            started_at=t0,
            completed_at=t1,
            error=None,
            state_snapshot={"key": "value"},
        )
        assert entry.completed_at == t1
        assert entry.state_snapshot == {"key": "value"}

    @pytest.mark.unit
    def test_extra_forbidden(self):
        # Unknown fields are expected to be rejected with a ValidationError.
        with pytest.raises(ValidationError):
            NodeExecutionLog(
                node_id="n",
                node_type="agent",
                status=ExecutionStatus.PENDING,
                started_at=datetime.now(timezone.utc),
                extra="no",
            )
# ── WorkflowRunCreate / WorkflowRun ─────────────────────────────────────────
class TestWorkflowRunCreate:
    """Unit tests for the WorkflowRunCreate request model."""

    @pytest.mark.unit
    def test_defaults(self):
        created = WorkflowRunCreate(workflow_id="w1")
        assert created.workflow_id == "w1"
        assert created.inputs == {}
class TestWorkflowRun:
    """Unit tests for the WorkflowRun model and its Mongo serialisation."""

    @pytest.mark.unit
    def test_defaults(self):
        run = WorkflowRun(workflow_id="w1")
        assert run.status == ExecutionStatus.PENDING
        assert run.inputs == {}
        assert run.outputs == {}
        assert run.steps == []
        assert run.completed_at is None

    @pytest.mark.unit
    def test_objectid_conversion(self):
        # A Mongo ObjectId supplied as _id is normalised to its string form.
        object_id = ObjectId()
        run = WorkflowRun(_id=object_id, workflow_id="w1")
        assert run.id == str(object_id)

    @pytest.mark.unit
    def test_to_mongo_doc(self):
        step = NodeExecutionLog(
            node_id="n1",
            node_type="agent",
            status=ExecutionStatus.COMPLETED,
            started_at=datetime.now(timezone.utc),
        )
        run = WorkflowRun(
            workflow_id="w1",
            status=ExecutionStatus.RUNNING,
            inputs={"q": "hello"},
            outputs={"a": "world"},
            steps=[step],
        )
        doc = run.to_mongo_doc()
        assert doc["workflow_id"] == "w1"
        # Enum status is serialised to its string value.
        assert doc["status"] == "running"
        assert doc["inputs"] == {"q": "hello"}
        assert doc["outputs"] == {"a": "world"}
        assert len(doc["steps"]) == 1
        assert doc["steps"][0]["node_id"] == "n1"
        assert doc["completed_at"] is None

View File

@@ -0,0 +1,240 @@
from datetime import timedelta
from unittest.mock import ANY, MagicMock, patch
import pytest
class TestIngestTask:
    """The `ingest` Celery task delegates to ingest_worker."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.ingest_worker")
    def test_calls_ingest_worker(self, worker_mock):
        from application.api.user.tasks import ingest

        worker_mock.return_value = {"status": "ok"}
        outcome = ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf")
        # First positional arg is the bound task instance, hence ANY.
        worker_mock.assert_called_once_with(
            ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
            file_name_map=None,
        )
        assert outcome == {"status": "ok"}

    @pytest.mark.unit
    @patch("application.api.user.tasks.ingest_worker")
    def test_passes_file_name_map(self, worker_mock):
        from application.api.user.tasks import ingest

        worker_mock.return_value = {"status": "ok"}
        mapping = {"a.pdf": "b.pdf"}
        ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
               file_name_map=mapping)
        worker_mock.assert_called_once_with(
            ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
            file_name_map=mapping,
        )
class TestIngestRemoteTask:
    """The `ingest_remote` Celery task delegates to remote_worker."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.remote_worker")
    def test_calls_remote_worker(self, worker_mock):
        from application.api.user.tasks import ingest_remote

        worker_mock.return_value = {"status": "ok"}
        outcome = ingest_remote({"url": "http://x"}, "job1", "user1", "web")
        worker_mock.assert_called_once_with(
            ANY, {"url": "http://x"}, "job1", "user1", "web"
        )
        assert outcome == {"status": "ok"}
class TestReingestSourceTask:
    """reingest_source_task delegates to the worker in application.worker."""

    @pytest.mark.unit
    @patch("application.worker.reingest_source_worker")
    def test_calls_reingest_worker(self, worker_mock):
        from application.api.user.tasks import reingest_source_task

        worker_mock.return_value = {"status": "ok"}
        outcome = reingest_source_task("source123", "user1")
        worker_mock.assert_called_once_with(ANY, "source123", "user1")
        assert outcome == {"status": "ok"}
class TestScheduleSyncsTask:
    """schedule_syncs delegates to sync_worker with the frequency label."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.sync_worker")
    def test_calls_sync_worker(self, worker_mock):
        from application.api.user.tasks import schedule_syncs

        worker_mock.return_value = {"status": "ok"}
        outcome = schedule_syncs("daily")
        worker_mock.assert_called_once_with(ANY, "daily")
        assert outcome == {"status": "ok"}
class TestSyncSourceTask:
    """sync_source forwards every argument to the sync worker unchanged."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.sync")
    def test_calls_sync(self, sync_mock):
        from application.api.user.tasks import sync_source

        sync_mock.return_value = {"status": "ok"}
        outcome = sync_source(
            {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
        )
        sync_mock.assert_called_once_with(
            ANY, {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1"
        )
        assert outcome == {"status": "ok"}
class TestStoreAttachmentTask:
    """store_attachment delegates to attachment_worker."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.attachment_worker")
    def test_calls_attachment_worker(self, worker_mock):
        from application.api.user.tasks import store_attachment

        worker_mock.return_value = {"status": "ok"}
        outcome = store_attachment({"file": "info"}, "user1")
        worker_mock.assert_called_once_with(ANY, {"file": "info"}, "user1")
        assert outcome == {"status": "ok"}
class TestProcessAgentWebhookTask:
    """process_agent_webhook delegates to agent_webhook_worker."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.agent_webhook_worker")
    def test_calls_agent_webhook_worker(self, worker_mock):
        from application.api.user.tasks import process_agent_webhook

        worker_mock.return_value = {"status": "ok"}
        outcome = process_agent_webhook("agent123", {"event": "test"})
        worker_mock.assert_called_once_with(ANY, "agent123", {"event": "test"})
        assert outcome == {"status": "ok"}
class TestIngestConnectorTask:
    """ingest_connector_task forwards all connector options to the worker."""

    @pytest.mark.unit
    @patch("application.worker.ingest_connector")
    def test_calls_ingest_connector_defaults(self, worker_mock):
        from application.api.user.tasks import ingest_connector_task

        worker_mock.return_value = {"status": "ok"}
        outcome = ingest_connector_task("job1", "user1", "gdrive")
        # Omitted options fall back to the task's declared defaults.
        worker_mock.assert_called_once_with(
            ANY,
            "job1",
            "user1",
            "gdrive",
            session_token=None,
            file_ids=None,
            folder_ids=None,
            recursive=True,
            retriever="classic",
            operation_mode="upload",
            doc_id=None,
            sync_frequency="never",
        )
        assert outcome == {"status": "ok"}

    @pytest.mark.unit
    @patch("application.worker.ingest_connector")
    def test_calls_ingest_connector_custom(self, worker_mock):
        from application.api.user.tasks import ingest_connector_task

        worker_mock.return_value = {"status": "ok"}
        # Single source of truth for the non-default options, passed through
        # both to the task call and to the expected worker call.
        options = dict(
            session_token="tok",
            file_ids=["f1"],
            folder_ids=["d1"],
            recursive=False,
            retriever="duckdb",
            operation_mode="sync",
            doc_id="doc1",
            sync_frequency="daily",
        )
        outcome = ingest_connector_task("job1", "user1", "sharepoint", **options)
        worker_mock.assert_called_once_with(
            ANY, "job1", "user1", "sharepoint", **options
        )
        assert outcome == {"status": "ok"}
class TestSetupPeriodicTasks:
    """setup_periodic_tasks registers daily, weekly and monthly schedules."""

    @pytest.mark.unit
    def test_registers_periodic_tasks(self):
        from application.api.user.tasks import setup_periodic_tasks

        sender = MagicMock()
        setup_periodic_tasks(sender)
        assert sender.add_periodic_task.call_count == 3
        registered = sender.add_periodic_task.call_args_list
        # First positional argument of each registration is its interval.
        expected_intervals = [
            timedelta(days=1),    # daily
            timedelta(weeks=1),   # weekly
            timedelta(days=30),   # monthly
        ]
        for call, interval in zip(registered, expected_intervals):
            assert call[0][0] == interval
class TestMcpOauthTask:
    """mcp_oauth_task delegates to the mcp_oauth worker."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.mcp_oauth")
    def test_calls_mcp_oauth(self, worker_mock):
        from application.api.user.tasks import mcp_oauth_task

        worker_mock.return_value = {"url": "http://auth"}
        outcome = mcp_oauth_task({"server": "mcp"}, "user1")
        worker_mock.assert_called_once_with(ANY, {"server": "mcp"}, "user1")
        assert outcome == {"url": "http://auth"}
class TestMcpOauthStatusTask:
    """mcp_oauth_status_task delegates to the mcp_oauth_status worker."""

    @pytest.mark.unit
    @patch("application.api.user.tasks.mcp_oauth_status")
    def test_calls_mcp_oauth_status(self, worker_mock):
        from application.api.user.tasks import mcp_oauth_status_task

        worker_mock.return_value = {"status": "authorized"}
        outcome = mcp_oauth_status_task("task123")
        worker_mock.assert_called_once_with(ANY, "task123")
        assert outcome == {"status": "authorized"}

View File

@@ -0,0 +1,382 @@
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
ModelRegistry,
)
@pytest.fixture(autouse=True)
def _reset_registry():
    """Drop the ModelRegistry singleton before and after every test."""

    def _clear():
        ModelRegistry._instance = None
        ModelRegistry._initialized = False

    # Clear before the test so earlier imports can't leak a registry in,
    # and after it so this test can't leak one out.
    _clear()
    yield
    _clear()
def _make_model(
    model_id="test-model",
    provider=ModelProvider.OPENAI,
    display_name="Test Model",
    context_window=128000,
    supports_tools=True,
    supports_structured_output=False,
    supported_attachment_types=None,
    enabled=True,
    base_url=None,
):
    """Build an AvailableModel with sensible test defaults.

    Every keyword can be overridden by individual tests; the attachment-type
    list defaults to empty rather than None to match the model's field type.
    """
    caps = ModelCapabilities(
        supports_tools=supports_tools,
        supports_structured_output=supports_structured_output,
        supported_attachment_types=supported_attachment_types or [],
        context_window=context_window,
    )
    return AvailableModel(
        id=model_id,
        provider=provider,
        display_name=display_name,
        capabilities=caps,
        enabled=enabled,
        base_url=base_url,
    )
# ── get_api_key_for_provider ─────────────────────────────────────────────────
class TestGetApiKeyForProvider:
    """settings is imported lazily inside get_api_key_for_provider, so the
    patch target is application.core.settings.settings (the module attribute)."""

    @pytest.mark.unit
    def test_openai_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(
                OPENAI_API_KEY="sk-openai",
                API_KEY="sk-fallback",
                OPEN_ROUTER_API_KEY=None,
                NOVITA_API_KEY=None,
                ANTHROPIC_API_KEY=None,
                GOOGLE_API_KEY=None,
                GROQ_API_KEY=None,
                HUGGINGFACE_API_KEY=None,
            )
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("openai") == "sk-openai"

    @pytest.mark.unit
    def test_anthropic_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(ANTHROPIC_API_KEY="sk-anthropic", API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("anthropic") == "sk-anthropic"

    @pytest.mark.unit
    def test_google_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(GOOGLE_API_KEY="sk-google", API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("google") == "sk-google"

    @pytest.mark.unit
    def test_groq_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(GROQ_API_KEY="sk-groq", API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("groq") == "sk-groq"

    @pytest.mark.unit
    def test_openrouter_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(OPEN_ROUTER_API_KEY="sk-or", API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("openrouter") == "sk-or"

    @pytest.mark.unit
    def test_novita_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(NOVITA_API_KEY="sk-novita", API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("novita") == "sk-novita"

    @pytest.mark.unit
    def test_huggingface_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(HUGGINGFACE_API_KEY="hf-key", API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("huggingface") == "hf-key"

    @pytest.mark.unit
    def test_docsgpt_returns_fallback(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("docsgpt") == "sk-fallback"

    @pytest.mark.unit
    def test_llama_cpp_returns_fallback(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("llama.cpp") == "sk-fallback"

    @pytest.mark.unit
    def test_unknown_provider_returns_fallback(self):
        # Providers without a dedicated key fall back to the generic API_KEY.
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(API_KEY="sk-fallback")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("unknown_provider") == "sk-fallback"

    @pytest.mark.unit
    def test_azure_openai_key(self):
        with patch("application.core.settings.settings") as cfg:
            cfg.configure_mock(API_KEY="sk-azure")
            from application.core.model_utils import get_api_key_for_provider

            assert get_api_key_for_provider("azure_openai") == "sk-azure"
# ── get_all_available_models ─────────────────────────────────────────────────
class TestGetAllAvailableModels:
    """get_all_available_models maps enabled models to a dict keyed by id."""

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_returns_enabled_models_as_dict(self, get_instance_mock):
        first = _make_model("model-a", display_name="Model A")
        second = _make_model("model-b", display_name="Model B")
        registry = MagicMock()
        registry.get_enabled_models.return_value = [first, second]
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_all_available_models

        listing = get_all_available_models()
        assert "model-a" in listing
        assert "model-b" in listing
        assert listing["model-a"]["display_name"] == "Model A"
        assert listing["model-b"]["display_name"] == "Model B"

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_empty_registry(self, get_instance_mock):
        registry = MagicMock()
        registry.get_enabled_models.return_value = []
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_all_available_models

        assert get_all_available_models() == {}
# ── validate_model_id ────────────────────────────────────────────────────────
class TestValidateModelId:
    """validate_model_id reflects the registry's model_exists answer."""

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_exists(self, get_instance_mock):
        registry = MagicMock()
        registry.model_exists.return_value = True
        get_instance_mock.return_value = registry

        from application.core.model_utils import validate_model_id

        assert validate_model_id("gpt-4") is True

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_not_exists(self, get_instance_mock):
        registry = MagicMock()
        registry.model_exists.return_value = False
        get_instance_mock.return_value = registry

        from application.core.model_utils import validate_model_id

        assert validate_model_id("nonexistent") is False
# ── get_model_capabilities ───────────────────────────────────────────────────
class TestGetModelCapabilities:
    """get_model_capabilities exposes a model's capabilities as a dict."""

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_found(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = _make_model(
            "gpt-4",
            context_window=8192,
            supports_tools=True,
            supports_structured_output=True,
            supported_attachment_types=["image/png"],
        )
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_model_capabilities

        caps = get_model_capabilities("gpt-4")
        assert caps is not None
        assert caps["supported_attachment_types"] == ["image/png"]
        assert caps["supports_tools"] is True
        assert caps["supports_structured_output"] is True
        assert caps["context_window"] == 8192

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_not_found(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = None
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_model_capabilities

        assert get_model_capabilities("nonexistent") is None
# ── get_default_model_id ─────────────────────────────────────────────────────
class TestGetDefaultModelId:
    """get_default_model_id surfaces the registry's default_model_id."""

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_returns_default(self, get_instance_mock):
        registry = MagicMock()
        registry.default_model_id = "gpt-4"
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_default_model_id

        assert get_default_model_id() == "gpt-4"
# ── get_provider_from_model_id ───────────────────────────────────────────────
class TestGetProviderFromModelId:
    """get_provider_from_model_id returns the provider string or None."""

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_found(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = _make_model(
            "gpt-4", provider=ModelProvider.OPENAI
        )
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_provider_from_model_id

        assert get_provider_from_model_id("gpt-4") == "openai"

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_not_found(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = None
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_provider_from_model_id

        assert get_provider_from_model_id("nonexistent") is None
# ── get_token_limit ──────────────────────────────────────────────────────────
class TestGetTokenLimit:
    """get_token_limit uses the model's context window, else the setting."""

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_found(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = _make_model("gpt-4", context_window=8192)
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_token_limit

        assert get_token_limit("gpt-4") == 8192

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_not_found_returns_default(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = None
        get_instance_mock.return_value = registry
        # Unknown models fall back to settings.DEFAULT_LLM_TOKEN_LIMIT.
        with patch("application.core.settings.settings") as cfg:
            cfg.DEFAULT_LLM_TOKEN_LIMIT = 128000
            from application.core.model_utils import get_token_limit

            assert get_token_limit("nonexistent") == 128000
# ── get_base_url_for_model ───────────────────────────────────────────────────
class TestGetBaseUrlForModel:
    """get_base_url_for_model returns the model's base_url when present."""

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_with_base_url(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = _make_model(
            "custom-model", base_url="http://localhost:8080"
        )
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_base_url_for_model

        assert get_base_url_for_model("custom-model") == "http://localhost:8080"

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_without_base_url(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = _make_model("gpt-4", base_url=None)
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_base_url_for_model

        assert get_base_url_for_model("gpt-4") is None

    @pytest.mark.unit
    @patch("application.core.model_utils.ModelRegistry.get_instance")
    def test_model_not_found(self, get_instance_mock):
        registry = MagicMock()
        registry.get_model.return_value = None
        get_instance_mock.return_value = registry

        from application.core.model_utils import get_base_url_for_model

        assert get_base_url_for_model("nonexistent") is None

File diff suppressed because it is too large Load Diff

204
tests/llm/test_base_llm.py Normal file
View File

@@ -0,0 +1,204 @@
"""Unit tests for application/llm/base.py — BaseLLM.
Covers initialisation, static helpers, supports_* introspection,
structured-output defaults, and attachment-type defaults.
Fallback behaviour is covered separately in test_fallback.py.
"""
from unittest.mock import MagicMock, Mock
import pytest
from application.llm.base import BaseLLM
# ---------------------------------------------------------------------------
# Concrete stub so we can instantiate the abstract base
# ---------------------------------------------------------------------------
class StubLLM(BaseLLM):
    """Minimal concrete BaseLLM for unit-testing non-abstract members."""

    # NOTE(review): the extra `baseself` parameter presumably mirrors a project
    # convention where BaseLLM.gen passes the instance through to the raw hook
    # — confirm against BaseLLM's call sites before relying on it.
    def _raw_gen(self, baseself, model, messages, stream, tools=None, **kw):
        # Sentinel value so tests can assert the non-streaming path was taken.
        return "raw_gen_result"

    def _raw_gen_stream(self, baseself, model, messages, stream, tools=None, **kw):
        # Single-chunk generator so streaming paths have something to iterate.
        yield "chunk"
# ---------------------------------------------------------------------------
# __init__
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestBaseLLMInit:
    """Constructor defaults and parameter handling of BaseLLM.__init__."""

    def test_defaults(self):
        stub = StubLLM()
        assert stub.decoded_token is None
        assert stub.agent_id is None
        assert stub.model_id is None
        assert stub.base_url is None
        assert stub.token_usage == {"prompt_tokens": 0, "generated_tokens": 0}
        assert stub._backup_models == []
        assert stub._fallback_llm is None

    def test_agent_id_cast_to_str(self):
        # Non-None agent ids are coerced to strings.
        assert StubLLM(agent_id=42).agent_id == "42"

    def test_agent_id_none_stays_none(self):
        assert StubLLM(agent_id=None).agent_id is None

    def test_custom_params(self):
        decoded = {"sub": "u1"}
        stub = StubLLM(
            decoded_token=decoded,
            agent_id="abc",
            model_id="gpt-4",
            base_url="http://x",
            backup_models=["m1", "m2"],
        )
        # decoded_token must be stored by reference, not copied.
        assert stub.decoded_token is decoded
        assert stub.agent_id == "abc"
        assert stub.model_id == "gpt-4"
        assert stub.base_url == "http://x"
        assert stub._backup_models == ["m1", "m2"]
# ---------------------------------------------------------------------------
# _remove_null_values
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestRemoveNullValues:
    """_remove_null_values strips None-valued keys from dicts only."""

    def test_removes_none_values(self):
        assert BaseLLM._remove_null_values({"a": 1, "b": None, "c": "x"}) == {
            "a": 1,
            "c": "x",
        }

    def test_keeps_falsy_non_none(self):
        # Falsy-but-not-None values (0, "", False, []) must survive.
        payload = {"a": 0, "b": "", "c": False, "d": []}
        assert BaseLLM._remove_null_values(payload) == payload

    def test_non_dict_passthrough(self):
        # Non-dict inputs are returned unchanged.
        for value in ("hello", 42, [1, 2]):
            assert BaseLLM._remove_null_values(value) == value

    def test_empty_dict(self):
        assert BaseLLM._remove_null_values({}) == {}

    def test_all_none(self):
        assert BaseLLM._remove_null_values({"a": None, "b": None}) == {}
# ---------------------------------------------------------------------------
# supports_tools / _supports_tools
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestSupportsTools:
    """supports_tools() reports whether _supports_tools is callable."""

    def test_supports_tools_true_when_callable(self):
        assert StubLLM().supports_tools() is True

    def test_supports_tools_false_when_not_callable(self):
        # Shadowing the hook with a non-callable flips the answer to False.
        stub = StubLLM()
        stub._supports_tools = "not_callable"
        assert stub.supports_tools() is False

    def test_default_supports_tools_raises(self):
        # The base hook is abstract-by-convention: invoking it must raise.
        with pytest.raises(NotImplementedError):
            StubLLM()._supports_tools()
# ---------------------------------------------------------------------------
# supports_structured_output / _supports_structured_output
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestSupportsStructuredOutput:
    """Structured-output introspection on the base class."""

    def test_supports_structured_output_true(self):
        # The public wrapper reports True because the hook is callable.
        assert StubLLM().supports_structured_output() is True

    def test_default_supports_structured_output_returns_false(self):
        # The private hook itself defaults to False.
        assert StubLLM()._supports_structured_output() is False
# ---------------------------------------------------------------------------
# prepare_structured_output_format
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestPrepareStructuredOutputFormat:
    """The base class declines to build a structured-output format."""

    def test_returns_none_by_default(self):
        assert StubLLM().prepare_structured_output_format({"type": "object"}) is None
# ---------------------------------------------------------------------------
# get_supported_attachment_types
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestGetSupportedAttachmentTypes:
    """The base class advertises no attachment support."""

    def test_returns_empty_list(self):
        assert StubLLM().get_supported_attachment_types() == []
# ---------------------------------------------------------------------------
# fallback_llm — caching
# ---------------------------------------------------------------------------
@pytest.mark.unit
class TestFallbackLLMCaching:
    """Resolution and caching of the fallback_llm property."""

    def test_returns_cached_instance(self, monkeypatch):
        """Once resolved, the same fallback instance is returned."""
        cached = StubLLM()
        owner = StubLLM()
        owner._fallback_llm = cached
        assert owner.fallback_llm is cached

    def test_none_when_no_backup_and_no_global(self, monkeypatch):
        # No backup models and no global fallback provider -> nothing to return.
        monkeypatch.setattr(
            "application.llm.base.settings",
            MagicMock(FALLBACK_LLM_PROVIDER=None),
        )
        assert StubLLM(backup_models=[]).fallback_llm is None

    def test_global_fallback_init_failure_returns_none(self, monkeypatch):
        monkeypatch.setattr(
            "application.llm.base.settings",
            MagicMock(
                FALLBACK_LLM_PROVIDER="openai",
                FALLBACK_LLM_NAME="gpt-4",
                FALLBACK_LLM_API_KEY="k",
                API_KEY="k",
            ),
        )
        monkeypatch.setattr(
            "application.llm.llm_creator.LLMCreator.create_llm",
            Mock(side_effect=RuntimeError("boom")),
        )
        # Creation blows up -> the property yields None rather than raising.
        assert StubLLM(backup_models=[]).fallback_llm is None