diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py
index 7510c563..05e831f0 100644
--- a/tests/agents/test_base_agent.py
+++ b/tests/agents/test_base_agent.py
@@ -1,4 +1,4 @@
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
 from application.agents.classic_agent import ClassicAgent
@@ -56,6 +56,73 @@ class TestBaseAgentInitialization:
         agent = ClassicAgent(**agent_base_params)
         assert agent.user == "user123"
 
+    def test_dependency_injection_llm(self, agent_base_params, mock_llm_handler_creator):
+        """When llm is provided, LLMCreator.create_llm is NOT called."""
+        injected_llm = Mock()
+        agent_base_params["llm"] = injected_llm
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.llm is injected_llm
+
+    def test_dependency_injection_llm_handler(self, agent_base_params, mock_llm_creator):
+        """When llm_handler is provided, LLMHandlerCreator is NOT called."""
+        injected_handler = Mock()
+        agent_base_params["llm_handler"] = injected_handler
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.llm_handler is injected_handler
+
+    def test_dependency_injection_tool_executor(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        """When tool_executor is provided, a new one is NOT created."""
+        injected_executor = Mock()
+        injected_executor.tool_calls = []
+        agent_base_params["tool_executor"] = injected_executor
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.tool_executor is injected_executor
+
+    def test_json_schema_normalized(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent_base_params["json_schema"] = {"type": "object"}
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.json_schema == {"type": "object"}
+
+    def test_json_schema_wrapped(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent_base_params["json_schema"] = {"schema": {"type": "string"}}
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.json_schema == {"type": "string"}
+
+    def test_json_schema_invalid_ignored(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent_base_params["json_schema"] = {"bad": "no type"}
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.json_schema is None
+
+    def test_retrieved_docs_defaults_to_empty(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.retrieved_docs == []
+
+    def test_attachments_defaults_to_empty(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent_base_params["attachments"] = None
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.attachments == []
+
+    def test_limited_token_mode_defaults(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        assert agent.limited_token_mode is False
+        assert agent.limited_request_mode is False
+        assert agent.current_token_count == 0
+        assert agent.context_limit_reached is False
+
 
 @pytest.mark.unit
 class TestBaseAgentBuildMessages:
@@ -602,3 +669,656 @@ class TestBaseAgentHandleResponse:
         assert len(results) == 2
         assert results[0]["type"] == "tool_call"
         assert results[1]["answer"] == "Final answer"
+
+    def test_handle_response_dict_event_passthrough(
+        self,
+        agent_base_params,
+        mock_llm_handler,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+        log_context,
+    ):
+        """Dict events with 'type' key pass through without wrapping."""
+
+        def mock_process(*args):
+            yield {"type": "info", "data": {"message": "processing"}}
+
+        mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
+
+        agent = ClassicAgent(**agent_base_params)
+        response = Mock()
+        response.message = None
+
+        results = list(agent._handle_response(response, {}, [], log_context))
+        assert results == [{"type": "info", "data": {"message": "processing"}}]
+
+    def test_handle_response_message_object_from_handler(
+        self,
+        agent_base_params,
+        mock_llm_handler,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+        log_context,
+    ):
+        """Response objects with .message.content from handler are unwrapped."""
+        event = Mock()
+        event.message = Mock()
+        event.message.content = "from handler"
+
+        def mock_process(*args):
+            yield event
+
+        mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
+
+        agent = ClassicAgent(**agent_base_params)
+        response = Mock()
+        response.message = None
+
+        results = list(agent._handle_response(response, {}, [], log_context))
+        assert results[0]["answer"] == "from handler"
+
+
+# ---------------------------------------------------------------------------
+# gen() — the @log_activity decorated entry point
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestBaseAgentGen:
+
+    def test_gen_delegates_to_gen_inner(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+
+        # ClassicAgent._gen_inner is abstract — we patch it
+        with patch.object(agent, "_gen_inner") as mock_inner:
+            mock_inner.return_value = iter([{"answer": "ok"}])
+            results = list(agent.gen("hello"))
+
+        assert any(r.get("answer") == "ok" for r in results)
+
+
+# ---------------------------------------------------------------------------
+# tool_calls property
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestBaseAgentToolCallsProperty:
+
+    def test_getter(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        agent.tool_executor.tool_calls = ["a", "b"]
+        assert agent.tool_calls == ["a", "b"]
+
+    def test_setter(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        agent.tool_calls = ["x"]
+        assert agent.tool_executor.tool_calls == ["x"]
+
+
+# ---------------------------------------------------------------------------
+# _calculate_current_context_tokens
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestCalculateContextTokens:
+
+    def test_delegates_to_token_counter(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        messages = [{"role": "user", "content": "hello"}]
+
+        with patch(
+            "application.api.answer.services.compression.token_counter.TokenCounter"
+        ) as MockTC:
+            MockTC.count_message_tokens.return_value = 42
+            result = agent._calculate_current_context_tokens(messages)
+            assert result == 42
+            MockTC.count_message_tokens.assert_called_once_with(messages)
+
+
+# ---------------------------------------------------------------------------
+# _check_context_limit
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestCheckContextLimit:
+
+    def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator):
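+        # The creator fixtures are accepted purely so their patches are active
+        # while the agent is constructed; only agent_base_params is used.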
+        return ClassicAgent(**agent_base_params)
+
+    def test_below_threshold_returns_false(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = self._make_agent(
+            agent_base_params, mock_llm_creator, mock_llm_handler_creator
+        )
+        messages = [{"role": "user", "content": "hi"}]
+
+        with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
+            with patch(
+                "application.core.model_utils.get_token_limit", return_value=10000
+            ):
+                result = agent._check_context_limit(messages)
+        assert result is False
+        assert agent.current_token_count == 100
+
+    def test_at_threshold_returns_true(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = self._make_agent(
+            agent_base_params, mock_llm_creator, mock_llm_handler_creator
+        )
+        messages = [{"role": "user", "content": "hi"}]
+
+        # threshold = 10000 * 0.8 = 8000; tokens = 8001 → True
+        with patch.object(agent, "_calculate_current_context_tokens", return_value=8001):
+            with patch(
+                "application.core.model_utils.get_token_limit", return_value=10000
+            ):
+                result = agent._check_context_limit(messages)
+        assert result is True
+
+    def test_error_returns_false(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = self._make_agent(
+            agent_base_params, mock_llm_creator, mock_llm_handler_creator
+        )
+        with patch.object(
+            agent,
+            "_calculate_current_context_tokens",
+            side_effect=RuntimeError("boom"),
+        ):
+            result = agent._check_context_limit([])
+        assert result is False
+
+
+# ---------------------------------------------------------------------------
+# _validate_context_size
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestValidateContextSize:
+
+    def test_at_limit_logs_warning(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        with patch.object(agent, "_calculate_current_context_tokens", return_value=10000):
+            with patch(
+                "application.core.model_utils.get_token_limit", return_value=10000
+            ):
+                # Should not raise
+                agent._validate_context_size([{"role": "user", "content": "x"}])
+        assert agent.current_token_count == 10000
+
+    def test_below_threshold_no_warning(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        with patch.object(agent, "_calculate_current_context_tokens", return_value=100):
+            with patch(
+                "application.core.model_utils.get_token_limit", return_value=10000
+            ):
+                agent._validate_context_size([])
+        assert agent.current_token_count == 100
+
+    def test_approaching_threshold(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        # 8500 / 10000 = 85% → above 80% threshold but below 100%
+        with patch.object(agent, "_calculate_current_context_tokens", return_value=8500):
+            with patch(
+                "application.core.model_utils.get_token_limit", return_value=10000
+            ):
+                agent._validate_context_size([])
+        assert agent.current_token_count == 8500
+
+
+# ---------------------------------------------------------------------------
+# _truncate_text_middle
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestTruncateTextMiddle:
+
+    def test_short_text_unchanged(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
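+        # Token counting is stubbed so "short" counts as 5 tokens, well under
+        # max_tokens, and should come back unchanged.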
patch("application.utils.num_tokens_from_string", return_value=5): + result = agent._truncate_text_middle("short", max_tokens=100) + assert result == "short" + + def test_long_text_truncated( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + long_text = "A" * 1000 + + def fake_tokens(text): + return len(text) // 4 + + with patch("application.utils.num_tokens_from_string", side_effect=fake_tokens): + result = agent._truncate_text_middle(long_text, max_tokens=50) + assert "[... content truncated to fit context limit ...]" in result + assert len(result) < len(long_text) + + def test_zero_target_returns_empty( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = ClassicAgent(**agent_base_params) + with patch("application.utils.num_tokens_from_string", return_value=100): + result = agent._truncate_text_middle("some text", max_tokens=0) + assert result == "" + + +# --------------------------------------------------------------------------- +# _truncate_history_to_fit +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestTruncateHistoryToFit: + + def _make_agent(self, agent_base_params, mock_llm_creator, mock_llm_handler_creator): + return ClassicAgent(**agent_base_params) + + def test_empty_history( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator + ) + assert agent._truncate_history_to_fit([], 100) == [] + + def test_zero_budget( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator + ) + history = [{"prompt": "a", "response": "b"}] + assert agent._truncate_history_to_fit(history, 0) == [] + + def test_fits_all( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator + ) + history = [ + {"prompt": "q1", "response": "a1"}, + {"prompt": "q2", "response": "a2"}, + ] + with patch("application.utils.num_tokens_from_string", return_value=5): + result = agent._truncate_history_to_fit(history, 10000) + assert len(result) == 2 + + def test_partial_fit_keeps_most_recent( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator + ) + history = [ + {"prompt": "old", "response": "old_ans"}, + {"prompt": "new", "response": "new_ans"}, + ] + # Each message = 10 tokens (prompt + response), budget = 15 → only 1 fits + with patch("application.utils.num_tokens_from_string", return_value=5): + result = agent._truncate_history_to_fit(history, 15) + assert len(result) == 1 + assert result[0]["prompt"] == "new" + + def test_history_with_tool_calls( + self, agent_base_params, mock_llm_creator, mock_llm_handler_creator + ): + agent = self._make_agent( + agent_base_params, mock_llm_creator, mock_llm_handler_creator + ) + history = [ + { + "prompt": "q", + "response": "a", + "tool_calls": [ + { + "tool_name": "t", + "action_name": "act", + "arguments": "{}", + "result": "ok", + } + ], + } + ] + with patch("application.utils.num_tokens_from_string", return_value=3): + result = agent._truncate_history_to_fit(history, 100) + assert len(result) == 1 + + +# 
+# ---------------------------------------------------------------------------
+# _build_messages — compressed_summary and query truncation
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestBuildMessagesAdvanced:
+
+    def test_compressed_summary_appended(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent_base_params["compressed_summary"] = "Previous conversation summary"
+        agent = ClassicAgent(**agent_base_params)
+
+        with patch(
+            "application.core.model_utils.get_token_limit", return_value=100000
+        ), patch("application.utils.num_tokens_from_string", return_value=10):
+            messages = agent._build_messages("System prompt", "query")
+
+        system_content = messages[0]["content"]
+        assert "Previous conversation summary" in system_content
+
+    def test_query_truncated_when_too_large(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        agent = ClassicAgent(**agent_base_params)
+
+        call_count = {"n": 0}
+
+        def fake_tokens(text):
+            call_count["n"] += 1
+            return len(text)
+
+        with patch(
+            "application.core.model_utils.get_token_limit", return_value=200
+        ), patch("application.utils.num_tokens_from_string", side_effect=fake_tokens):
+            with patch.object(agent, "_truncate_text_middle", return_value="truncated"):
+                with patch.object(agent, "_truncate_history_to_fit", return_value=[]):
+                    messages = agent._build_messages("sys", "A" * 500)
+
+        # The method should have been called for truncation
+        assert messages[-1]["role"] == "user"
+
+    def test_build_messages_with_tool_call_missing_call_id(
+        self, agent_base_params, mock_llm_creator, mock_llm_handler_creator
+    ):
+        """Tool calls without call_id get a generated UUID."""
+        history = [
+            {
+                "tool_calls": [
+                    {
+                        "action_name": "search",
+                        "arguments": "{}",
+                        "result": "found",
+                    }
+                ]
+            }
+        ]
+        agent_base_params["chat_history"] = history
+        agent = ClassicAgent(**agent_base_params)
+
+        with patch(
+            "application.core.model_utils.get_token_limit", return_value=100000
+        ), patch("application.utils.num_tokens_from_string", return_value=5):
+            messages = agent._build_messages("sys", "q")
+
+        tool_msgs = [m for m in messages if m["role"] == "tool"]
+        assert len(tool_msgs) == 1
+
+
+# ---------------------------------------------------------------------------
+# _llm_gen — edge cases
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestLLMGenAdvanced:
+
+    def test_llm_gen_with_attachments(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+    ):
+        agent_base_params["attachments"] = [{"id": "att1", "mime_type": "image/png"}]
+        agent = ClassicAgent(**agent_base_params)
+
+        messages = [{"role": "user", "content": "test"}]
+        agent._llm_gen(messages)
+
+        call_kwargs = mock_llm.gen_stream.call_args[1]
+        assert "_usage_attachments" in call_kwargs
+
+    def test_llm_gen_without_log_context(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+    ):
+        agent = ClassicAgent(**agent_base_params)
+        messages = [{"role": "user", "content": "test"}]
+
+        # Should not raise even without log_context
+        agent._llm_gen(messages, log_context=None)
+        mock_llm.gen_stream.assert_called_once()
+
+    def test_llm_gen_google_structured_output(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+        log_context,
+    ):
+        mock_llm._supports_structured_output = Mock(return_value=True)
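+        # Assumption drawn from the assertions in this class: Google-style
+        # models receive the schema via a "response_schema" kwarg, while
+        # OpenAI-style models use "response_format" instead.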
+        mock_llm.prepare_structured_output_format = Mock(
+            return_value={"schema": "test"}
+        )
+
+        agent_base_params["json_schema"] = {"type": "object"}
+        agent_base_params["llm_name"] = "google"
+        agent = ClassicAgent(**agent_base_params)
+
+        messages = [{"role": "user", "content": "test"}]
+        agent._llm_gen(messages, log_context)
+
+        call_kwargs = mock_llm.gen_stream.call_args[1]
+        assert "response_schema" in call_kwargs
+
+    def test_llm_gen_no_tools_when_unsupported(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+    ):
+        mock_llm._supports_tools = False
+        agent = ClassicAgent(**agent_base_params)
+        agent.tools = [{"type": "function", "function": {"name": "test"}}]
+
+        messages = [{"role": "user", "content": "test"}]
+        agent._llm_gen(messages)
+
+        call_kwargs = mock_llm.gen_stream.call_args[1]
+        assert "tools" not in call_kwargs
+
+    def test_llm_gen_no_structured_output_when_unsupported(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+    ):
+        mock_llm._supports_structured_output = Mock(return_value=False)
+        agent_base_params["json_schema"] = {"type": "object"}
+        agent = ClassicAgent(**agent_base_params)
+
+        messages = [{"role": "user", "content": "test"}]
+        agent._llm_gen(messages)
+
+        call_kwargs = mock_llm.gen_stream.call_args[1]
+        assert "response_format" not in call_kwargs
+        assert "response_schema" not in call_kwargs
+
+    def test_llm_gen_no_format_when_prepare_returns_none(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+    ):
+        mock_llm._supports_structured_output = Mock(return_value=True)
+        mock_llm.prepare_structured_output_format = Mock(return_value=None)
+
+        agent_base_params["json_schema"] = {"type": "object"}
+        agent_base_params["llm_name"] = "openai"
+        agent = ClassicAgent(**agent_base_params)
+
+        messages = [{"role": "user", "content": "test"}]
+        agent._llm_gen(messages)
+
+        call_kwargs = mock_llm.gen_stream.call_args[1]
+        assert "response_format" not in call_kwargs
+
+
+# ---------------------------------------------------------------------------
+# _llm_handler
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestLLMHandlerMethod:
+
+    def test_delegates_to_handler(
+        self,
+        agent_base_params,
+        mock_llm_handler,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+        log_context,
+    ):
+        mock_llm_handler.process_message_flow = Mock(return_value="result")
+
+        agent = ClassicAgent(**agent_base_params)
+        resp = Mock()
+        result = agent._llm_handler(resp, {}, [], log_context)
+
+        mock_llm_handler.process_message_flow.assert_called_once()
+        assert result == "result"
+        assert len(log_context.stacks) == 1
+        assert log_context.stacks[0]["component"] == "llm_handler"
+
+    def test_without_log_context(
+        self,
+        agent_base_params,
+        mock_llm_handler,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+    ):
+        mock_llm_handler.process_message_flow = Mock(return_value="r")
+        agent = ClassicAgent(**agent_base_params)
+        result = agent._llm_handler(Mock(), {}, [], log_context=None)
+        assert result == "r"
+
+
+# ---------------------------------------------------------------------------
+# _handle_response — structured output on all code paths
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.unit
+class TestHandleResponseStructuredAllPaths:
+
+    def test_message_response_with_structured_output(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+        log_context,
+    ):
+        """Structured output on the message.content early-return path."""
+        mock_llm._supports_structured_output = Mock(return_value=True)
+        agent_base_params["json_schema"] = {"type": "object"}
+        agent = ClassicAgent(**agent_base_params)
+
+        response = Mock()
+        response.message = Mock()
+        response.message.content = "structured msg"
+
+        results = list(agent._handle_response(response, {}, [], log_context))
+        assert results[0]["structured"] is True
+        assert results[0]["schema"] == {"type": "object"}
+        assert results[0]["answer"] == "structured msg"
+
+    def test_handler_string_event_with_structured_output(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_handler,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+        log_context,
+    ):
+        """Structured output on string events from the handler."""
+        mock_llm._supports_structured_output = Mock(return_value=True)
+        agent_base_params["json_schema"] = {"type": "array"}
+
+        def mock_process(*args):
+            yield "handler string"
+
+        mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
+
+        agent = ClassicAgent(**agent_base_params)
+        response = Mock()
+        response.message = None
+
+        results = list(agent._handle_response(response, {}, [], log_context))
+        assert results[0]["structured"] is True
+        assert results[0]["schema"] == {"type": "array"}
+
+    def test_handler_message_event_with_structured_output(
+        self,
+        agent_base_params,
+        mock_llm,
+        mock_llm_handler,
+        mock_llm_creator,
+        mock_llm_handler_creator,
+        log_context,
+    ):
+        """Structured output on message-object events from the handler."""
+        mock_llm._supports_structured_output = Mock(return_value=True)
+        agent_base_params["json_schema"] = {"type": "number"}
+
+        event = Mock()
+        event.message = Mock()
+        event.message.content = "from handler msg"
+
+        def mock_process(*args):
+            yield event
+
+        mock_llm_handler.process_message_flow = Mock(side_effect=mock_process)
+
+        agent = ClassicAgent(**agent_base_params)
+        response = Mock()
+        response.message = None
+
+        results = list(agent._handle_response(response, {}, [], log_context))
+        assert results[0]["structured"] is True
+        assert results[0]["schema"] == {"type": "number"}
+        assert results[0]["answer"] == "from handler msg"
diff --git a/tests/agents/test_workflow_schemas.py b/tests/agents/test_workflow_schemas.py
new file mode 100644
index 00000000..21b3aedd
--- /dev/null
+++ b/tests/agents/test_workflow_schemas.py
@@ -0,0 +1,556 @@
+from datetime import datetime, timezone
+
+import pytest
+from bson import ObjectId
+from pydantic import ValidationError
+
+from application.agents.workflows.schemas import (
+    AgentNodeConfig,
+    AgentType,
+    ConditionCase,
+    ConditionNodeConfig,
+    ExecutionStatus,
+    NodeExecutionLog,
+    NodeType,
+    Position,
+    StateOperation,
+    Workflow,
+    WorkflowCreate,
+    WorkflowEdge,
+    WorkflowEdgeCreate,
+    WorkflowGraph,
+    WorkflowNode,
+    WorkflowNodeCreate,
+    WorkflowRun,
+    WorkflowRunCreate,
+)
+
+
+# ── Enum tests ───────────────────────────────────────────────────────────────
+
+
+class TestNodeType:
+    @pytest.mark.unit
+    def test_values(self):
+        assert NodeType.START == "start"
+        assert NodeType.END == "end"
+        assert NodeType.AGENT == "agent"
+        assert NodeType.NOTE == "note"
+        assert NodeType.STATE == "state"
+        assert NodeType.CONDITION == "condition"
+
+    @pytest.mark.unit
+    def test_all_members(self):
+        assert set(NodeType) == {
+            NodeType.START,
+            NodeType.END,
+            NodeType.AGENT,
+            NodeType.NOTE,
+            NodeType.STATE,
+            NodeType.CONDITION,
+        }
+
+
+class TestAgentType:
+    @pytest.mark.unit
+    def test_values(self):
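+        # Members compare equal to plain strings, implying str-based enums.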
+        assert AgentType.CLASSIC == "classic"
+        assert AgentType.REACT == "react"
+        assert AgentType.AGENTIC == "agentic"
+        assert AgentType.RESEARCH == "research"
+
+
+class TestExecutionStatus:
+    @pytest.mark.unit
+    def test_values(self):
+        assert ExecutionStatus.PENDING == "pending"
+        assert ExecutionStatus.RUNNING == "running"
+        assert ExecutionStatus.COMPLETED == "completed"
+        assert ExecutionStatus.FAILED == "failed"
+
+
+# ── Position ─────────────────────────────────────────────────────────────────
+
+
+class TestPosition:
+    @pytest.mark.unit
+    def test_defaults(self):
+        p = Position()
+        assert p.x == 0.0
+        assert p.y == 0.0
+
+    @pytest.mark.unit
+    def test_custom_values(self):
+        p = Position(x=10.5, y=-3.2)
+        assert p.x == 10.5
+        assert p.y == -3.2
+
+    @pytest.mark.unit
+    def test_extra_fields_forbidden(self):
+        with pytest.raises(ValidationError):
+            Position(x=0, y=0, z=1)
+
+
+# ── AgentNodeConfig ──────────────────────────────────────────────────────────
+
+
+class TestAgentNodeConfig:
+    @pytest.mark.unit
+    def test_defaults(self):
+        c = AgentNodeConfig()
+        assert c.agent_type == AgentType.CLASSIC
+        assert c.llm_name is None
+        assert c.system_prompt == "You are a helpful assistant."
+        assert c.prompt_template == ""
+        assert c.output_variable is None
+        assert c.stream_to_user is True
+        assert c.tools == []
+        assert c.sources == []
+        assert c.chunks == "2"
+        assert c.retriever == ""
+        assert c.model_id is None
+        assert c.json_schema is None
+
+    @pytest.mark.unit
+    def test_custom_values(self):
+        c = AgentNodeConfig(
+            agent_type=AgentType.REACT,
+            llm_name="gpt-4",
+            tools=["search"],
+            sources=["src1"],
+            chunks="5",
+            model_id="m1",
+            json_schema={"type": "object"},
+        )
+        assert c.agent_type == AgentType.REACT
+        assert c.llm_name == "gpt-4"
+        assert c.tools == ["search"]
+        assert c.sources == ["src1"]
+        assert c.chunks == "5"
+        assert c.model_id == "m1"
+        assert c.json_schema == {"type": "object"}
+
+    @pytest.mark.unit
+    def test_extra_fields_allowed(self):
+        c = AgentNodeConfig(custom_field="value")
+        assert c.custom_field == "value"
+
+
+# ── ConditionCase / ConditionNodeConfig ──────────────────────────────────────
+
+
+class TestConditionCase:
+    @pytest.mark.unit
+    def test_alias(self):
+        c = ConditionCase(expression="x > 1", sourceHandle="handle-1")
+        assert c.source_handle == "handle-1"
+
+    @pytest.mark.unit
+    def test_defaults(self):
+        c = ConditionCase(sourceHandle="h")
+        assert c.name is None
+        assert c.expression == ""
+
+    @pytest.mark.unit
+    def test_extra_forbidden(self):
+        with pytest.raises(ValidationError):
+            ConditionCase(sourceHandle="h", extra="nope")
+
+
+class TestConditionNodeConfig:
+    @pytest.mark.unit
+    def test_defaults(self):
+        c = ConditionNodeConfig()
+        assert c.mode == "simple"
+        assert c.cases == []
+
+    @pytest.mark.unit
+    def test_with_cases(self):
+        c = ConditionNodeConfig(
+            mode="advanced",
+            cases=[{"expression": "x > 1", "sourceHandle": "h1"}],
+        )
+        assert c.mode == "advanced"
+        assert len(c.cases) == 1
+        assert c.cases[0].source_handle == "h1"
+
+
+# ── StateOperation ───────────────────────────────────────────────────────────
+
+
+class TestStateOperation:
+    @pytest.mark.unit
+    def test_defaults(self):
+        s = StateOperation()
+        assert s.expression == ""
+        assert s.target_variable == ""
+
+    @pytest.mark.unit
+    def test_extra_forbidden(self):
+        with pytest.raises(ValidationError):
+            StateOperation(expression="a", target_variable="b", extra="no")
+
+
+# ── WorkflowEdgeCreate / WorkflowEdge ────────────────────────────────────────
+
+
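+# Edge payloads carry camelCase aliases (sourceHandle/targetHandle) that map
+# onto snake_case model fields, as the alias tests below demonstrate.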
+class TestWorkflowEdgeCreate:
+    @pytest.mark.unit
+    def test_aliases(self):
+        e = WorkflowEdgeCreate(
+            id="e1",
+            workflow_id="w1",
+            source="n1",
+            target="n2",
+            sourceHandle="sh",
+            targetHandle="th",
+        )
+        assert e.source_id == "n1"
+        assert e.target_id == "n2"
+        assert e.source_handle == "sh"
+        assert e.target_handle == "th"
+
+    @pytest.mark.unit
+    def test_optional_handles_default_none(self):
+        e = WorkflowEdgeCreate(id="e1", workflow_id="w1", source="n1", target="n2")
+        assert e.source_handle is None
+        assert e.target_handle is None
+
+
+class TestWorkflowEdge:
+    @pytest.mark.unit
+    def test_objectid_conversion(self):
+        oid = ObjectId()
+        e = WorkflowEdge(
+            _id=oid,
+            id="e1",
+            workflow_id="w1",
+            source="n1",
+            target="n2",
+        )
+        assert e.mongo_id == str(oid)
+
+    @pytest.mark.unit
+    def test_string_id_passthrough(self):
+        e = WorkflowEdge(
+            _id="string-id",
+            id="e1",
+            workflow_id="w1",
+            source="n1",
+            target="n2",
+        )
+        assert e.mongo_id == "string-id"
+
+    @pytest.mark.unit
+    def test_none_id(self):
+        e = WorkflowEdge(id="e1", workflow_id="w1", source="n1", target="n2")
+        assert e.mongo_id is None
+
+    @pytest.mark.unit
+    def test_to_mongo_doc(self):
+        e = WorkflowEdge(
+            id="e1",
+            workflow_id="w1",
+            source="n1",
+            target="n2",
+            sourceHandle="sh",
+            targetHandle="th",
+        )
+        doc = e.to_mongo_doc()
+        assert doc == {
+            "id": "e1",
+            "workflow_id": "w1",
+            "source_id": "n1",
+            "target_id": "n2",
+            "source_handle": "sh",
+            "target_handle": "th",
+        }
+
+
+# ── WorkflowNodeCreate / WorkflowNode ────────────────────────────────────────
+
+
+class TestWorkflowNodeCreate:
+    @pytest.mark.unit
+    def test_defaults(self):
+        n = WorkflowNodeCreate(id="n1", workflow_id="w1", type=NodeType.AGENT)
+        assert n.title == "Node"
+        assert n.description is None
+        assert n.position.x == 0.0
+        assert n.position.y == 0.0
+        assert n.config == {}
+
+    @pytest.mark.unit
+    def test_position_from_dict(self):
+        n = WorkflowNodeCreate(
+            id="n1",
+            workflow_id="w1",
+            type=NodeType.START,
+            position={"x": 100, "y": 200},
+        )
+        assert isinstance(n.position, Position)
+        assert n.position.x == 100
+        assert n.position.y == 200
+
+    @pytest.mark.unit
+    def test_position_from_position_object(self):
+        pos = Position(x=5, y=10)
+        n = WorkflowNodeCreate(
+            id="n1", workflow_id="w1", type=NodeType.END, position=pos
+        )
+        assert n.position is pos
+
+
+class TestWorkflowNode:
+    @pytest.mark.unit
+    def test_objectid_conversion(self):
+        oid = ObjectId()
+        n = WorkflowNode(
+            _id=oid, id="n1", workflow_id="w1", type=NodeType.AGENT
+        )
+        assert n.mongo_id == str(oid)
+
+    @pytest.mark.unit
+    def test_to_mongo_doc(self):
+        n = WorkflowNode(
+            id="n1",
+            workflow_id="w1",
+            type=NodeType.AGENT,
+            title="My Node",
+            description="desc",
+            position={"x": 10, "y": 20},
+            config={"key": "val"},
+        )
+        doc = n.to_mongo_doc()
+        assert doc == {
+            "id": "n1",
+            "workflow_id": "w1",
+            "type": "agent",
+            "title": "My Node",
+            "description": "desc",
+            "position": {"x": 10.0, "y": 20.0},
+            "config": {"key": "val"},
+        }
+
+
+# ── WorkflowCreate / Workflow ────────────────────────────────────────────────
+
+
+class TestWorkflowCreate:
+    @pytest.mark.unit
+    def test_defaults(self):
+        w = WorkflowCreate()
+        assert w.name == "New Workflow"
+        assert w.description is None
+        assert w.user is None
+
+    @pytest.mark.unit
+    def test_custom_values(self):
+        w = WorkflowCreate(name="Test", description="d", user="u1")
+        assert w.name == "Test"
+        assert w.description == "d"
+        assert w.user == "u1"
+
+
+class TestWorkflow:
+    @pytest.mark.unit
+    def test_objectid_conversion(self):
+        oid = ObjectId()
+        w = Workflow(_id=oid)
+        assert w.id == str(oid)
+
+    @pytest.mark.unit
+    def test_string_id(self):
+        w = Workflow(_id="abc")
+        assert w.id == "abc"
+
+    @pytest.mark.unit
+    def test_none_id(self):
+        w = Workflow()
+        assert w.id is None
+
+    @pytest.mark.unit
+    def test_datetime_defaults(self):
+        before = datetime.now(timezone.utc)
+        w = Workflow()
+        after = datetime.now(timezone.utc)
+        assert before <= w.created_at <= after
+        assert before <= w.updated_at <= after
+
+    @pytest.mark.unit
+    def test_to_mongo_doc(self):
+        w = Workflow(name="W", description="d", user="u1")
+        doc = w.to_mongo_doc()
+        assert doc["name"] == "W"
+        assert doc["description"] == "d"
+        assert doc["user"] == "u1"
+        assert "created_at" in doc
+        assert "updated_at" in doc
+
+
+# ── WorkflowGraph ────────────────────────────────────────────────────────────
+
+
+class TestWorkflowGraph:
+    @pytest.fixture
+    def graph(self):
+        workflow = Workflow(name="test")
+        nodes = [
+            WorkflowNode(id="start", workflow_id="w1", type=NodeType.START),
+            WorkflowNode(id="agent1", workflow_id="w1", type=NodeType.AGENT),
+            WorkflowNode(id="end", workflow_id="w1", type=NodeType.END),
+        ]
+        edges = [
+            WorkflowEdge(
+                id="e1", workflow_id="w1", source="start", target="agent1"
+            ),
+            WorkflowEdge(
+                id="e2", workflow_id="w1", source="agent1", target="end"
+            ),
+        ]
+        return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
+
+    @pytest.mark.unit
+    def test_get_node_by_id_found(self, graph):
+        node = graph.get_node_by_id("agent1")
+        assert node is not None
+        assert node.id == "agent1"
+        assert node.type == NodeType.AGENT
+
+    @pytest.mark.unit
+    def test_get_node_by_id_not_found(self, graph):
+        assert graph.get_node_by_id("nonexistent") is None
+
+    @pytest.mark.unit
+    def test_get_start_node(self, graph):
+        start = graph.get_start_node()
+        assert start is not None
+        assert start.id == "start"
+        assert start.type == NodeType.START
+
+    @pytest.mark.unit
+    def test_get_start_node_missing(self):
+        g = WorkflowGraph(
+            workflow=Workflow(),
+            nodes=[
+                WorkflowNode(id="a", workflow_id="w", type=NodeType.AGENT),
+            ],
+        )
+        assert g.get_start_node() is None
+
+    @pytest.mark.unit
+    def test_get_outgoing_edges(self, graph):
+        edges = graph.get_outgoing_edges("start")
+        assert len(edges) == 1
+        assert edges[0].target_id == "agent1"
+
+    @pytest.mark.unit
+    def test_get_outgoing_edges_none(self, graph):
+        edges = graph.get_outgoing_edges("end")
+        assert edges == []
+
+    @pytest.mark.unit
+    def test_empty_graph(self):
+        g = WorkflowGraph(workflow=Workflow())
+        assert g.nodes == []
+        assert g.edges == []
+        assert g.get_start_node() is None
+
+
+# ── NodeExecutionLog ─────────────────────────────────────────────────────────
+
+
+class TestNodeExecutionLog:
+    @pytest.mark.unit
+    def test_required_fields(self):
+        now = datetime.now(timezone.utc)
+        log = NodeExecutionLog(
+            node_id="n1",
+            node_type="agent",
+            status=ExecutionStatus.RUNNING,
+            started_at=now,
+        )
+        assert log.node_id == "n1"
+        assert log.completed_at is None
+        assert log.error is None
+        assert log.state_snapshot == {}
+
+    @pytest.mark.unit
+    def test_full_log(self):
+        started = datetime.now(timezone.utc)
+        completed = datetime.now(timezone.utc)
+        log = NodeExecutionLog(
+            node_id="n1",
+            node_type="agent",
+            status=ExecutionStatus.COMPLETED,
+            started_at=started,
+            completed_at=completed,
+            error=None,
+            state_snapshot={"key": "value"},
+        )
+        assert log.completed_at == completed
+        assert log.state_snapshot == {"key": "value"}
+
+    @pytest.mark.unit
+    def test_extra_forbidden(self):
+        with pytest.raises(ValidationError):
+            NodeExecutionLog(
+                node_id="n",
+                node_type="agent",
+                status=ExecutionStatus.PENDING,
+                started_at=datetime.now(timezone.utc),
+                extra="no",
+            )
+
+
+# ── WorkflowRunCreate / WorkflowRun ──────────────────────────────────────────
+
+
+class TestWorkflowRunCreate:
+    @pytest.mark.unit
+    def test_defaults(self):
+        r = WorkflowRunCreate(workflow_id="w1")
+        assert r.workflow_id == "w1"
+        assert r.inputs == {}
+
+
+class TestWorkflowRun:
+    @pytest.mark.unit
+    def test_defaults(self):
+        r = WorkflowRun(workflow_id="w1")
+        assert r.status == ExecutionStatus.PENDING
+        assert r.inputs == {}
+        assert r.outputs == {}
+        assert r.steps == []
+        assert r.completed_at is None
+
+    @pytest.mark.unit
+    def test_objectid_conversion(self):
+        oid = ObjectId()
+        r = WorkflowRun(_id=oid, workflow_id="w1")
+        assert r.id == str(oid)
+
+    @pytest.mark.unit
+    def test_to_mongo_doc(self):
+        now = datetime.now(timezone.utc)
+        log = NodeExecutionLog(
+            node_id="n1",
+            node_type="agent",
+            status=ExecutionStatus.COMPLETED,
+            started_at=now,
+        )
+        r = WorkflowRun(
+            workflow_id="w1",
+            status=ExecutionStatus.RUNNING,
+            inputs={"q": "hello"},
+            outputs={"a": "world"},
+            steps=[log],
+        )
+        doc = r.to_mongo_doc()
+        assert doc["workflow_id"] == "w1"
+        assert doc["status"] == "running"
+        assert doc["inputs"] == {"q": "hello"}
+        assert doc["outputs"] == {"a": "world"}
+        assert len(doc["steps"]) == 1
+        assert doc["steps"][0]["node_id"] == "n1"
+        assert doc["completed_at"] is None
diff --git a/tests/api/user/test_tasks.py b/tests/api/user/test_tasks.py
new file mode 100644
index 00000000..a628af3b
--- /dev/null
+++ b/tests/api/user/test_tasks.py
@@ -0,0 +1,240 @@
+from datetime import timedelta
+from unittest.mock import ANY, MagicMock, patch
+
+import pytest
+
+
+class TestIngestTask:
+    @pytest.mark.unit
+    @patch("application.api.user.tasks.ingest_worker")
+    def test_calls_ingest_worker(self, mock_worker):
+        from application.api.user.tasks import ingest
+
+        mock_worker.return_value = {"status": "ok"}
+
+        result = ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf")
+
+        mock_worker.assert_called_once_with(
+            ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
+            file_name_map=None,
+        )
+        assert result == {"status": "ok"}
+
+    @pytest.mark.unit
+    @patch("application.api.user.tasks.ingest_worker")
+    def test_passes_file_name_map(self, mock_worker):
+        from application.api.user.tasks import ingest
+
+        mock_worker.return_value = {"status": "ok"}
+        name_map = {"a.pdf": "b.pdf"}
+
+        ingest("dir", ["pdf"], "job1", "user1", "/path", "file.pdf",
+               file_name_map=name_map)
+
+        mock_worker.assert_called_once_with(
+            ANY, "dir", ["pdf"], "job1", "/path", "file.pdf", "user1",
+            file_name_map=name_map,
+        )
+
+
+class TestIngestRemoteTask:
+    @pytest.mark.unit
+    @patch("application.api.user.tasks.remote_worker")
+    def test_calls_remote_worker(self, mock_worker):
+        from application.api.user.tasks import ingest_remote
+
+        mock_worker.return_value = {"status": "ok"}
+
+        result = ingest_remote({"url": "http://x"}, "job1", "user1", "web")
+
+        mock_worker.assert_called_once_with(
+            ANY, {"url": "http://x"}, "job1", "user1", "web"
+        )
+        assert result == {"status": "ok"}
+
+
+class TestReingestSourceTask:
+    @pytest.mark.unit
+    @patch("application.worker.reingest_source_worker")
+    def test_calls_reingest_worker(self, mock_worker):
+        from application.api.user.tasks import reingest_source_task
+
+        mock_worker.return_value = {"status": "ok"}
+
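+        # ANY stands in for the first positional argument, presumably the
+        # bound Celery task instance.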
reingest_source_task("source123", "user1") + + mock_worker.assert_called_once_with(ANY, "source123", "user1") + assert result == {"status": "ok"} + + +class TestScheduleSyncsTask: + @pytest.mark.unit + @patch("application.api.user.tasks.sync_worker") + def test_calls_sync_worker(self, mock_worker): + from application.api.user.tasks import schedule_syncs + + mock_worker.return_value = {"status": "ok"} + + result = schedule_syncs("daily") + + mock_worker.assert_called_once_with(ANY, "daily") + assert result == {"status": "ok"} + + +class TestSyncSourceTask: + @pytest.mark.unit + @patch("application.api.user.tasks.sync") + def test_calls_sync(self, mock_sync): + from application.api.user.tasks import sync_source + + mock_sync.return_value = {"status": "ok"} + + result = sync_source( + {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1" + ) + + mock_sync.assert_called_once_with( + ANY, {"data": 1}, "job1", "user1", "web", "daily", "classic", "doc1" + ) + assert result == {"status": "ok"} + + +class TestStoreAttachmentTask: + @pytest.mark.unit + @patch("application.api.user.tasks.attachment_worker") + def test_calls_attachment_worker(self, mock_worker): + from application.api.user.tasks import store_attachment + + mock_worker.return_value = {"status": "ok"} + + result = store_attachment({"file": "info"}, "user1") + + mock_worker.assert_called_once_with(ANY, {"file": "info"}, "user1") + assert result == {"status": "ok"} + + +class TestProcessAgentWebhookTask: + @pytest.mark.unit + @patch("application.api.user.tasks.agent_webhook_worker") + def test_calls_agent_webhook_worker(self, mock_worker): + from application.api.user.tasks import process_agent_webhook + + mock_worker.return_value = {"status": "ok"} + + result = process_agent_webhook("agent123", {"event": "test"}) + + mock_worker.assert_called_once_with(ANY, "agent123", {"event": "test"}) + assert result == {"status": "ok"} + + +class TestIngestConnectorTask: + @pytest.mark.unit + @patch("application.worker.ingest_connector") + def test_calls_ingest_connector_defaults(self, mock_worker): + from application.api.user.tasks import ingest_connector_task + + mock_worker.return_value = {"status": "ok"} + + result = ingest_connector_task("job1", "user1", "gdrive") + + mock_worker.assert_called_once_with( + ANY, + "job1", + "user1", + "gdrive", + session_token=None, + file_ids=None, + folder_ids=None, + recursive=True, + retriever="classic", + operation_mode="upload", + doc_id=None, + sync_frequency="never", + ) + assert result == {"status": "ok"} + + @pytest.mark.unit + @patch("application.worker.ingest_connector") + def test_calls_ingest_connector_custom(self, mock_worker): + from application.api.user.tasks import ingest_connector_task + + mock_worker.return_value = {"status": "ok"} + + result = ingest_connector_task( + "job1", + "user1", + "sharepoint", + session_token="tok", + file_ids=["f1"], + folder_ids=["d1"], + recursive=False, + retriever="duckdb", + operation_mode="sync", + doc_id="doc1", + sync_frequency="daily", + ) + + mock_worker.assert_called_once_with( + ANY, + "job1", + "user1", + "sharepoint", + session_token="tok", + file_ids=["f1"], + folder_ids=["d1"], + recursive=False, + retriever="duckdb", + operation_mode="sync", + doc_id="doc1", + sync_frequency="daily", + ) + assert result == {"status": "ok"} + + +class TestSetupPeriodicTasks: + @pytest.mark.unit + def test_registers_periodic_tasks(self): + from application.api.user.tasks import setup_periodic_tasks + + sender = MagicMock() + + setup_periodic_tasks(sender) + + 
+        assert sender.add_periodic_task.call_count == 3
+
+        calls = sender.add_periodic_task.call_args_list
+
+        # daily
+        assert calls[0][0][0] == timedelta(days=1)
+        # weekly
+        assert calls[1][0][0] == timedelta(weeks=1)
+        # monthly
+        assert calls[2][0][0] == timedelta(days=30)
+
+
+class TestMcpOauthTask:
+    @pytest.mark.unit
+    @patch("application.api.user.tasks.mcp_oauth")
+    def test_calls_mcp_oauth(self, mock_worker):
+        from application.api.user.tasks import mcp_oauth_task
+
+        mock_worker.return_value = {"url": "http://auth"}
+
+        result = mcp_oauth_task({"server": "mcp"}, "user1")
+
+        mock_worker.assert_called_once_with(ANY, {"server": "mcp"}, "user1")
+        assert result == {"url": "http://auth"}
+
+
+class TestMcpOauthStatusTask:
+    @pytest.mark.unit
+    @patch("application.api.user.tasks.mcp_oauth_status")
+    def test_calls_mcp_oauth_status(self, mock_worker):
+        from application.api.user.tasks import mcp_oauth_status_task
+
+        mock_worker.return_value = {"status": "authorized"}
+
+        result = mcp_oauth_status_task("task123")
+
+        mock_worker.assert_called_once_with(ANY, "task123")
+        assert result == {"status": "authorized"}
diff --git a/tests/core/test_model_utils.py b/tests/core/test_model_utils.py
new file mode 100644
index 00000000..7aefeeb5
--- /dev/null
+++ b/tests/core/test_model_utils.py
@@ -0,0 +1,382 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from application.core.model_settings import (
+    AvailableModel,
+    ModelCapabilities,
+    ModelProvider,
+    ModelRegistry,
+)
+
+
+@pytest.fixture(autouse=True)
+def _reset_registry():
+    """Reset ModelRegistry singleton between tests."""
+    ModelRegistry._instance = None
+    ModelRegistry._initialized = False
+    yield
+    ModelRegistry._instance = None
+    ModelRegistry._initialized = False
+
+
+def _make_model(
+    model_id="test-model",
+    provider=ModelProvider.OPENAI,
+    display_name="Test Model",
+    context_window=128000,
+    supports_tools=True,
+    supports_structured_output=False,
+    supported_attachment_types=None,
+    enabled=True,
+    base_url=None,
+):
+    return AvailableModel(
+        id=model_id,
+        provider=provider,
+        display_name=display_name,
+        capabilities=ModelCapabilities(
+            supports_tools=supports_tools,
+            supports_structured_output=supports_structured_output,
+            supported_attachment_types=supported_attachment_types or [],
+            context_window=context_window,
+        ),
+        enabled=enabled,
+        base_url=base_url,
+    )
+
+
+# ── get_api_key_for_provider ─────────────────────────────────────────────────
+
+
+class TestGetApiKeyForProvider:
+    """settings is lazily imported inside the function body, so we patch
+    at application.core.settings.settings (the actual module attribute)."""
+
+    @pytest.mark.unit
+    def test_openai_key(self):
+        with patch("application.core.settings.settings") as mock_settings:
+            mock_settings.OPENAI_API_KEY = "sk-openai"
+            mock_settings.API_KEY = "sk-fallback"
+            mock_settings.OPEN_ROUTER_API_KEY = None
+            mock_settings.NOVITA_API_KEY = None
+            mock_settings.ANTHROPIC_API_KEY = None
+            mock_settings.GOOGLE_API_KEY = None
+            mock_settings.GROQ_API_KEY = None
+            mock_settings.HUGGINGFACE_API_KEY = None
+
+            from application.core.model_utils import get_api_key_for_provider
+
+            assert get_api_key_for_provider("openai") == "sk-openai"
+
+    @pytest.mark.unit
+    def test_anthropic_key(self):
+        with patch("application.core.settings.settings") as mock_settings:
+            mock_settings.ANTHROPIC_API_KEY = "sk-anthropic"
+            mock_settings.API_KEY = "sk-fallback"
+
+            from application.core.model_utils import get_api_key_for_provider
+
get_api_key_for_provider("anthropic") == "sk-anthropic" + + @pytest.mark.unit + def test_google_key(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.GOOGLE_API_KEY = "sk-google" + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("google") == "sk-google" + + @pytest.mark.unit + def test_groq_key(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.GROQ_API_KEY = "sk-groq" + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("groq") == "sk-groq" + + @pytest.mark.unit + def test_openrouter_key(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.OPEN_ROUTER_API_KEY = "sk-or" + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("openrouter") == "sk-or" + + @pytest.mark.unit + def test_novita_key(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.NOVITA_API_KEY = "sk-novita" + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("novita") == "sk-novita" + + @pytest.mark.unit + def test_huggingface_key(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.HUGGINGFACE_API_KEY = "hf-key" + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("huggingface") == "hf-key" + + @pytest.mark.unit + def test_docsgpt_returns_fallback(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("docsgpt") == "sk-fallback" + + @pytest.mark.unit + def test_llama_cpp_returns_fallback(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("llama.cpp") == "sk-fallback" + + @pytest.mark.unit + def test_unknown_provider_returns_fallback(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.API_KEY = "sk-fallback" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("unknown_provider") == "sk-fallback" + + @pytest.mark.unit + def test_azure_openai_key(self): + with patch("application.core.settings.settings") as mock_settings: + mock_settings.API_KEY = "sk-azure" + + from application.core.model_utils import get_api_key_for_provider + + assert get_api_key_for_provider("azure_openai") == "sk-azure" + + +# ── get_all_available_models ───────────────────────────────────────────────── + + +class TestGetAllAvailableModels: + @pytest.mark.unit + @patch("application.core.model_utils.ModelRegistry.get_instance") + def test_returns_enabled_models_as_dict(self, mock_get_instance): + model_a = _make_model("model-a", display_name="Model A") + model_b = _make_model("model-b", display_name="Model B") + mock_registry = MagicMock() + mock_registry.get_enabled_models.return_value = [model_a, model_b] + mock_get_instance.return_value = mock_registry + + from 
+        from application.core.model_utils import get_all_available_models
+
+        result = get_all_available_models()
+
+        assert "model-a" in result
+        assert "model-b" in result
+        assert result["model-a"]["display_name"] == "Model A"
+        assert result["model-b"]["display_name"] == "Model B"
+
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_empty_registry(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.get_enabled_models.return_value = []
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_all_available_models
+
+        assert get_all_available_models() == {}
+
+
+# ── validate_model_id ────────────────────────────────────────────────────────
+
+
+class TestValidateModelId:
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_exists(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.model_exists.return_value = True
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import validate_model_id
+
+        assert validate_model_id("gpt-4") is True
+
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_not_exists(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.model_exists.return_value = False
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import validate_model_id
+
+        assert validate_model_id("nonexistent") is False
+
+
+# ── get_model_capabilities ───────────────────────────────────────────────────
+
+
+class TestGetModelCapabilities:
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_found(self, mock_get_instance):
+        model = _make_model(
+            "gpt-4",
+            context_window=8192,
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=["image/png"],
+        )
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = model
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_model_capabilities
+
+        caps = get_model_capabilities("gpt-4")
+
+        assert caps is not None
+        assert caps["supported_attachment_types"] == ["image/png"]
+        assert caps["supports_tools"] is True
+        assert caps["supports_structured_output"] is True
+        assert caps["context_window"] == 8192
+
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_not_found(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = None
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_model_capabilities
+
+        assert get_model_capabilities("nonexistent") is None
+
+
+# ── get_default_model_id ─────────────────────────────────────────────────────
+
+
+class TestGetDefaultModelId:
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_returns_default(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.default_model_id = "gpt-4"
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_default_model_id
+
+        assert get_default_model_id() == "gpt-4"
+
+
+# ── get_provider_from_model_id ───────────────────────────────────────────────
+
+
+class TestGetProviderFromModelId:
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_found(self, mock_get_instance):
+        model = _make_model("gpt-4", provider=ModelProvider.OPENAI)
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = model
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_provider_from_model_id
+
+        assert get_provider_from_model_id("gpt-4") == "openai"
+
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_not_found(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = None
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_provider_from_model_id
+
+        assert get_provider_from_model_id("nonexistent") is None
+
+
+# ── get_token_limit ──────────────────────────────────────────────────────────
+
+
+class TestGetTokenLimit:
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_found(self, mock_get_instance):
+        model = _make_model("gpt-4", context_window=8192)
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = model
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_token_limit
+
+        assert get_token_limit("gpt-4") == 8192
+
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_not_found_returns_default(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = None
+        mock_get_instance.return_value = mock_registry
+
+        with patch("application.core.settings.settings") as mock_settings:
+            mock_settings.DEFAULT_LLM_TOKEN_LIMIT = 128000
+
+            from application.core.model_utils import get_token_limit
+
+            assert get_token_limit("nonexistent") == 128000
+
+
+# ── get_base_url_for_model ───────────────────────────────────────────────────
+
+
+class TestGetBaseUrlForModel:
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_with_base_url(self, mock_get_instance):
+        model = _make_model("custom-model", base_url="http://localhost:8080")
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = model
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_base_url_for_model
+
+        assert get_base_url_for_model("custom-model") == "http://localhost:8080"
+
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_without_base_url(self, mock_get_instance):
+        model = _make_model("gpt-4", base_url=None)
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = model
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_base_url_for_model
+
+        assert get_base_url_for_model("gpt-4") is None
+
+    @pytest.mark.unit
+    @patch("application.core.model_utils.ModelRegistry.get_instance")
+    def test_model_not_found(self, mock_get_instance):
+        mock_registry = MagicMock()
+        mock_registry.get_model.return_value = None
+        mock_get_instance.return_value = mock_registry
+
+        from application.core.model_utils import get_base_url_for_model
+
+        assert get_base_url_for_model("nonexistent") is None
diff --git a/tests/llm/handlers/test_llm_handlers.py b/tests/llm/handlers/test_llm_handlers.py
index a44b2de5..304d89fb 100644
--- a/tests/llm/handlers/test_llm_handlers.py
+++ b/tests/llm/handlers/test_llm_handlers.py
@@ -1,5 +1,7 @@
 from typing import Any, Dict, Generator
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, MagicMock, patch
+
+import pytest
 
 from application.llm.handlers.base import LLMHandler, LLMResponse, ToolCall
@@ -35,6 +37,16 @@ class TestToolCall:
         assert tool_call.arguments == {}
         assert tool_call.index is None
 
+    def test_tool_call_thought_signature(self):
+        tc = ToolCall(
+            id="1", name="fn", arguments={}, thought_signature="sig123"
+        )
+        assert tc.thought_signature == "sig123"
+
+    def test_tool_call_thought_signature_default_none(self):
+        tc = ToolCall(id="1", name="fn", arguments={})
+        assert tc.thought_signature is None
+
 
 class TestLLMResponse:
     def test_llm_response_creation(self):
@@ -211,3 +223,1360 @@ class TestLLMHandler:
         chunks = list(result)
         assert chunks == ["chunk1", "chunk2"]
+
+
+# ---------------------------------------------------------------------------
+# _append_unsupported_attachments
+# ---------------------------------------------------------------------------
+
+
+class TestAppendUnsupportedAttachments:
+
+    def test_with_content(self):
+        handler = ConcreteHandler()
+        messages = [{"role": "system", "content": "You are helpful."}]
+        attachments = [{"id": "a1", "content": "File contents here"}]
+
+        result = handler._append_unsupported_attachments(messages, attachments)
+        assert "File contents here" in result[0]["content"]
+
+    def test_without_content(self):
+        handler = ConcreteHandler()
+        messages = [{"role": "system", "content": "sys"}]
+        attachments = [{"id": "a1", "mime_type": "text/plain"}]
+
+        result = handler._append_unsupported_attachments(messages, attachments)
+        # No content key → no change to system prompt
+        assert result[0]["content"] == "sys"
+
+    def test_no_system_message_creates_one(self):
+        handler = ConcreteHandler()
+        messages = [{"role": "user", "content": "hello"}]
+        attachments = [{"id": "a1", "content": "data"}]
+
+        result = handler._append_unsupported_attachments(messages, attachments)
+        system_msgs = [m for m in result if m["role"] == "system"]
+        assert len(system_msgs) == 1
+        assert "data" in system_msgs[0]["content"]
+
+    def test_multiple_attachments(self):
+        handler = ConcreteHandler()
+        messages = [{"role": "system", "content": "base"}]
+        attachments = [
+            {"id": "a1", "content": "content1"},
+            {"id": "a2", "content": "content2"},
+        ]
+
+        result = handler._append_unsupported_attachments(messages, attachments)
+        assert "content1" in result[0]["content"]
+        assert "content2" in result[0]["content"]
+
+    def test_original_messages_not_mutated(self):
+        handler = ConcreteHandler()
+        original = [{"role": "system", "content": "sys"}]
+        handler._append_unsupported_attachments(
+            original, [{"id": "a", "content": "x"}]
+        )
+        # The shallow copy means the dict inside IS mutated, but the list is not
+        assert len(original) == 1
+
+
+# ---------------------------------------------------------------------------
+# _prune_messages_minimal
+# ---------------------------------------------------------------------------
+
+
+class TestPruneMessagesMinimal:
+
+    def test_normal_case(self):
+        handler = ConcreteHandler()
+        messages = [
+            {"role": "system", "content": "sys prompt"},
+            {"role": "user", "content": "first question"},
+            {"role": "assistant", "content": "first answer"},
+            {"role": "user", "content": "second question"},
+        ]
+        result = handler._prune_messages_minimal(messages)
+        assert result is not None
+        assert len(result) == 2
+        assert result[0]["role"] == "system"
+        assert result[1]["role"] == "user"
+        assert result[1]["content"] == "second question"
+
+    def test_no_system_message(self):
+        handler = ConcreteHandler()
= [{"role": "user", "content": "hi"}] + result = handler._prune_messages_minimal(messages) + assert result is None + + def test_no_user_message(self): + handler = ConcreteHandler() + messages = [{"role": "system", "content": "sys"}] + result = handler._prune_messages_minimal(messages) + assert result is None + + def test_falls_back_to_non_user_role(self): + handler = ConcreteHandler() + messages = [ + {"role": "system", "content": "sys"}, + {"role": "assistant", "content": "response"}, + ] + result = handler._prune_messages_minimal(messages) + assert result is not None + assert result[1]["role"] == "assistant" + + +# --------------------------------------------------------------------------- +# _extract_text_from_content +# --------------------------------------------------------------------------- + + +class TestExtractTextFromContent: + + def test_string(self): + handler = ConcreteHandler() + assert handler._extract_text_from_content("hello") == "hello" + + def test_list_with_text(self): + handler = ConcreteHandler() + content = [{"text": "part1"}, {"text": "part2"}] + result = handler._extract_text_from_content(content) + assert "part1" in result + assert "part2" in result + + def test_list_with_function_call(self): + handler = ConcreteHandler() + content = [{"function_call": {"name": "search", "args": {}}}] + result = handler._extract_text_from_content(content) + assert "function_call" in result + + def test_list_with_function_response(self): + handler = ConcreteHandler() + content = [{"function_response": {"name": "search", "response": "ok"}}] + result = handler._extract_text_from_content(content) + assert "function_response" in result + + def test_list_with_files(self): + handler = ConcreteHandler() + content = [{"files": ["/tmp/a.txt"]}] + result = handler._extract_text_from_content(content) + assert "files" in result + + def test_list_with_none_text(self): + handler = ConcreteHandler() + content = [{"text": None}] + result = handler._extract_text_from_content(content) + assert result == "" + + def test_empty_list(self): + handler = ConcreteHandler() + assert handler._extract_text_from_content([]) == "" + + def test_none_returns_empty(self): + handler = ConcreteHandler() + assert handler._extract_text_from_content(None) == "" + + def test_integer_returns_empty(self): + handler = ConcreteHandler() + assert handler._extract_text_from_content(42) == "" + + +# --------------------------------------------------------------------------- +# _build_conversation_from_messages +# --------------------------------------------------------------------------- + + +class TestBuildConversationFromMessages: + + def test_basic_conversation(self): + handler = ConcreteHandler() + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi there"}, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + assert len(result["queries"]) == 1 + assert result["queries"][0]["prompt"] == "hello" + assert result["queries"][0]["response"] == "hi there" + + def test_with_tool_calls(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "search for X"}, + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": "search", + "args": {"q": "X"}, + "call_id": "c1", + } + } + ], + }, + { + "role": "tool", + "content": [ + { + "function_response": { + "name": "search", + "response": {"result": "found"}, + "call_id": "c1", + } + } + ], + }, + {"role": "assistant", 
"content": "I found X"}, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + queries = result["queries"] + assert len(queries) == 1 + assert queries[0]["prompt"] == "search for X" + assert queries[0]["response"] == "I found X" + + def test_empty_messages(self): + handler = ConcreteHandler() + result = handler._build_conversation_from_messages([]) + assert result is None + + def test_system_only(self): + handler = ConcreteHandler() + result = handler._build_conversation_from_messages( + [{"role": "system", "content": "sys"}] + ) + assert result is None + + def test_unfinished_prompt_committed(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "question"}, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + assert result["queries"][0]["prompt"] == "question" + assert result["queries"][0]["response"] == "" + + def test_tool_response_without_matching_call(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "q"}, + {"role": "assistant", "content": "a"}, + {"role": "tool", "content": "tool output"}, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + # Tool output appended to last query + assert len(result["queries"]) >= 1 + + def test_compression_metadata_present(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "q"}, + {"role": "assistant", "content": "a"}, + ] + result = handler._build_conversation_from_messages(messages) + assert "compression_metadata" in result + assert result["compression_metadata"]["is_compressed"] is False + + def test_model_role_treated_as_assistant(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "hello"}, + {"role": "model", "content": "hi from model"}, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + assert result["queries"][0]["response"] == "hi from model" + + +# --------------------------------------------------------------------------- +# _rebuild_messages_after_compression +# --------------------------------------------------------------------------- + + +class TestRebuildMessagesAfterCompression: + + def test_delegates_to_message_builder(self): + handler = ConcreteHandler() + messages = [{"role": "system", "content": "sys"}] + + with patch( + "application.api.answer.services.compression.message_builder.MessageBuilder.rebuild_messages_after_compression", + return_value=[{"role": "system", "content": "rebuilt"}], + ) as mock_rebuild: + result = handler._rebuild_messages_after_compression( + messages, "summary", [{"prompt": "q", "response": "a"}] + ) + mock_rebuild.assert_called_once() + assert result == [{"role": "system", "content": "rebuilt"}] + + +# --------------------------------------------------------------------------- +# _convert_pdf_to_images +# --------------------------------------------------------------------------- + + +class TestConvertPdfToImages: + + def test_no_path_raises(self): + handler = ConcreteHandler() + with pytest.raises(ValueError, match="No file path"): + handler._convert_pdf_to_images({"mime_type": "application/pdf"}) + + def test_delegates_to_utils(self): + handler = ConcreteHandler() + mock_storage = Mock() + expected = [{"data": "img1", "mime_type": "image/png", "page": 1}] + + with patch( + "application.storage.storage_creator.StorageCreator.get_storage", + return_value=mock_storage, + ), patch( + 
"application.utils.convert_pdf_to_images", + return_value=expected, + ) as mock_convert: + result = handler._convert_pdf_to_images( + {"path": "/tmp/doc.pdf", "mime_type": "application/pdf"} + ) + mock_convert.assert_called_once_with( + file_path="/tmp/doc.pdf", + storage=mock_storage, + max_pages=20, + dpi=150, + ) + assert result == expected + + +# --------------------------------------------------------------------------- +# prepare_messages — synthetic PDF support +# --------------------------------------------------------------------------- + + +class TestPrepareMessagesSyntheticPDF: + + def test_pdf_converted_when_images_supported_not_pdf(self): + handler = ConcreteHandler() + messages = [{"role": "user", "content": "analyse"}] + attachments = [{"mime_type": "application/pdf", "path": "/tmp/doc.pdf"}] + + mock_agent = Mock() + mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"] + mock_agent.llm.prepare_messages_with_attachments.return_value = messages + + converted = [{"data": "b64", "mime_type": "image/png", "page": 1}] + with patch.object(handler, "_convert_pdf_to_images", return_value=converted): + handler.prepare_messages(mock_agent, messages, attachments) + mock_agent.llm.prepare_messages_with_attachments.assert_called_once_with( + messages, converted + ) + + def test_pdf_conversion_failure_falls_back(self): + handler = ConcreteHandler() + messages = [{"role": "user", "content": "analyse"}] + attachments = [{"mime_type": "application/pdf", "path": "/tmp/doc.pdf"}] + + mock_agent = Mock() + mock_agent.llm.get_supported_attachment_types.return_value = ["image/png"] + + with patch.object( + handler, + "_convert_pdf_to_images", + side_effect=RuntimeError("conversion failed"), + ), patch.object( + handler, "_append_unsupported_attachments", return_value=messages + ) as mock_append: + handler.prepare_messages(mock_agent, messages, attachments) + mock_append.assert_called_once() + + def test_pdf_not_converted_when_natively_supported(self): + handler = ConcreteHandler() + messages = [{"role": "user", "content": "analyse"}] + attachments = [{"mime_type": "application/pdf", "path": "/tmp/doc.pdf"}] + + mock_agent = Mock() + mock_agent.llm.get_supported_attachment_types.return_value = [ + "image/png", + "application/pdf", + ] + mock_agent.llm.prepare_messages_with_attachments.return_value = messages + + with patch.object(handler, "_convert_pdf_to_images") as mock_convert: + handler.prepare_messages(mock_agent, messages, attachments) + mock_convert.assert_not_called() + + +# --------------------------------------------------------------------------- +# handle_tool_calls +# --------------------------------------------------------------------------- + + +class TestHandleToolCalls: + + def _make_agent(self): + agent = Mock() + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("tool result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + return agent + + def test_single_tool_call(self): + handler = ConcreteHandler() + agent = self._make_agent() + call = ToolCall(id="c1", name="action_1", arguments="{}") + tools_dict = {"1": {"name": "tool"}} + + gen = handler.handle_tool_calls(agent, [call], tools_dict, []) + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration as e: + messages = e.value + + assert any(e.get("type") == 
"tool_call" for e in events) + assert len(messages) >= 2 # function_call + tool_message + + def test_context_limit_skips_remaining(self): + handler = ConcreteHandler() + agent = self._make_agent() + agent._check_context_limit = Mock(return_value=True) + + with patch("application.core.settings.settings") as mock_settings: + mock_settings.ENABLE_CONVERSATION_COMPRESSION = False + + calls = [ + ToolCall(id="c1", name="a_1", arguments="{}"), + ToolCall(id="c2", name="b_1", arguments="{}"), + ] + gen = handler.handle_tool_calls(agent, calls, {}, []) + events = list(gen) + + skip_events = [ + e for e in events + if isinstance(e, dict) + and e.get("type") == "tool_call" + and e.get("data", {}).get("status") == "skipped" + ] + assert len(skip_events) == 2 + assert agent.context_limit_reached is True + + def test_tool_execution_error(self): + handler = ConcreteHandler() + agent = Mock() + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent._execute_tool_action = Mock(side_effect=RuntimeError("exec error")) + + call = ToolCall(id="c1", name="action_1", arguments="{}") + gen = handler.handle_tool_calls(agent, [call], {"1": {"name": "t"}}, []) + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration: + pass + + error_events = [ + e for e in events + if isinstance(e, dict) and e.get("data", {}).get("status") == "error" + ] + assert len(error_events) == 1 + + def test_thought_signature_preserved(self): + handler = ConcreteHandler() + agent = self._make_agent() + call = ToolCall( + id="c1", name="action_1", arguments="{}", thought_signature="sig" + ) + + gen = handler.handle_tool_calls(agent, [call], {"1": {"name": "t"}}, []) + try: + while True: + next(gen) + except StopIteration as e: + messages = e.value + + assistant_msgs = [ + m for m in messages + if m.get("role") == "assistant" + and isinstance(m.get("content"), list) + ] + assert any( + "thought_signature" in item + for m in assistant_msgs + for item in m["content"] + if isinstance(item, dict) + ) + + +# --------------------------------------------------------------------------- +# handle_non_streaming +# --------------------------------------------------------------------------- + + +class TestHandleNonStreaming: + + def test_no_tool_calls(self): + handler = ConcreteHandler() + agent = Mock() + agent.llm = Mock() + response = "simple answer" + + gen = handler.handle_non_streaming(agent, response, {}, []) + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration as e: + final = e.value + + assert final == "simple answer" + + def test_with_tool_calls_loop(self): + handler = ConcreteHandler() + agent = Mock() + agent.llm = Mock() + agent.model_id = "test" + agent.tools = [] + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + # First response requires tool call, second is final + call_count = {"n": 0} + + def fake_parse(response): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="", + tool_calls=[ToolCall(id="c1", name="fn_1", arguments="{}")], + finish_reason="tool_calls", + raw_response=response, + ) + return LLMResponse( + content="final", + tool_calls=[], + finish_reason="stop", + raw_response=response, + ) + + handler.parse_response = fake_parse + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("result", call.id) + + agent._execute_tool_action = 
Mock(side_effect=fake_execute) + agent.llm.gen = Mock(return_value="second_response") + + gen = handler.handle_non_streaming(agent, "first_response", {"1": {"name": "t"}}, []) + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration as e: + final = e.value + + assert final == "final" + assert agent.llm.gen.called + + +# --------------------------------------------------------------------------- +# handle_streaming +# --------------------------------------------------------------------------- + + +class TestHandleStreaming: + + def test_text_chunks(self): + handler = ConcreteHandler() + agent = Mock() + agent.llm = Mock() + + # Stream yields parsed responses with content + chunks = [ + LLMResponse(content="hello ", tool_calls=[], finish_reason="", raw_response={}), + LLMResponse(content="world", tool_calls=[], finish_reason="stop", raw_response={}), + ] + handler.parse_response = lambda c: c + + def fake_iterate(response): + yield from response + + handler._iterate_stream = fake_iterate + + gen = handler.handle_streaming(agent, chunks, {}, []) + results = list(gen) + assert "hello " in results + assert "world" in results + + def test_thought_chunks_passed_through(self): + handler = ConcreteHandler() + agent = Mock() + + def fake_iterate(response): + yield {"type": "thought", "content": "thinking..."} + + handler._iterate_stream = fake_iterate + + gen = handler.handle_streaming(agent, "response", {}, []) + results = list(gen) + assert results[0] == {"type": "thought", "content": "thinking..."} + + def test_string_chunks_passed_through(self): + handler = ConcreteHandler() + agent = Mock() + + def fake_iterate(response): + yield "raw string" + + handler._iterate_stream = fake_iterate + + gen = handler.handle_streaming(agent, "response", {}, []) + results = list(gen) + assert results[0] == "raw string" + + def test_tool_calls_accumulated_across_chunks(self): + handler = ConcreteHandler() + agent = Mock() + agent.llm = Mock() + agent.model_id = "test" + agent.tools = [] + agent._check_context_limit = Mock(return_value=False) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + # First chunk has partial tool call, second completes it + chunk1 = LLMResponse( + content="", + tool_calls=[ToolCall(id="c1", name="search", arguments='{"q":', index=0)], + finish_reason="", + raw_response={}, + ) + chunk2 = LLMResponse( + content="", + tool_calls=[ToolCall(id="", name="", arguments='"test"}', index=0)], + finish_reason="tool_calls", + raw_response={}, + ) + + handler.parse_response = lambda c: c + + def fake_iterate(response): + yield from response + + handler._iterate_stream = fake_iterate + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + + # After tool calls, return final streaming response + final_chunk = LLMResponse( + content="done", tool_calls=[], finish_reason="stop", raw_response={} + ) + agent.llm.gen_stream = Mock(return_value=[final_chunk]) + + gen = handler.handle_streaming(agent, [chunk1, chunk2], {"1": {"name": "t"}}, []) + list(gen) + + # The accumulated arguments should be concatenated + tool_call_args = agent._execute_tool_action.call_args + executed_call = tool_call_args[0][1] + assert executed_call.arguments == '{"q":"test"}' + + def test_context_limit_adds_system_message(self): + handler = ConcreteHandler() + agent = Mock() + agent.llm = Mock() + agent.model_id = "test" + 
agent.tools = [{"type": "function"}] + agent.context_limit_reached = True + agent._check_context_limit = Mock(return_value=True) + agent.llm.__class__.__name__ = "MockLLM" + + # Chunk finishes with tool_calls + chunk = LLMResponse( + content="", + tool_calls=[ToolCall(id="c1", name="fn_1", arguments="{}", index=0)], + finish_reason="tool_calls", + raw_response={}, + ) + handler.parse_response = lambda c: c + + def fake_iterate(response): + yield from response + + handler._iterate_stream = fake_iterate + + with patch("application.core.settings.settings") as mock_settings: + mock_settings.ENABLE_CONVERSATION_COMPRESSION = False + + # handle_tool_calls yields skip events and sets context_limit_reached + def fake_handle_tool_calls(agent, calls, tools_dict, messages): + agent.context_limit_reached = True + yield {"type": "tool_call", "data": {"status": "skipped"}} + return messages + + handler.handle_tool_calls = fake_handle_tool_calls + + final_chunk = LLMResponse( + content="wrapping up", tool_calls=[], finish_reason="stop", raw_response={} + ) + agent.llm.gen_stream = Mock(return_value=[final_chunk]) + + gen = handler.handle_streaming(agent, [chunk], {"1": {"name": "t"}}, []) + list(gen) + + # Should have called gen_stream with tools=None (context limit) + gen_stream_kwargs = agent.llm.gen_stream.call_args[1] + assert gen_stream_kwargs.get("tools") is None + + +# --------------------------------------------------------------------------- +# _perform_mid_execution_compression +# --------------------------------------------------------------------------- + + +class TestPerformMidExecutionCompression: + + def test_success_path(self): + handler = ConcreteHandler() + agent = Mock() + agent.conversation_id = "conv1" + agent.initial_user_id = "user1" + agent.model_id = "gpt-4" + agent.decoded_token = {} + agent.context_limit_reached = False + agent.current_token_count = 0 + + mock_metadata = Mock() + mock_metadata.compressed_token_count = 100 + mock_metadata.original_token_count = 1000 + mock_metadata.compression_ratio = 10.0 + mock_metadata.to_dict.return_value = {"ratio": 10.0} + + mock_result = Mock() + mock_result.success = True + mock_result.compression_performed = True + mock_result.compressed_summary = "summary" + mock_result.recent_queries = [] + mock_result.metadata = mock_metadata + mock_result.error = None + + mock_conversation = {"queries": []} + mock_conv_service = Mock() + mock_conv_service.get_conversation.return_value = mock_conversation + + mock_orchestrator = Mock() + mock_orchestrator.compress_mid_execution.return_value = mock_result + + with patch( + "application.api.answer.services.compression.CompressionOrchestrator", + return_value=mock_orchestrator, + ), patch( + "application.api.answer.services.conversation_service.ConversationService", + return_value=mock_conv_service, + ), patch.object( + handler, "_build_conversation_from_messages", return_value={"queries": []} + ), patch.object( + handler, + "_rebuild_messages_after_compression", + return_value=[{"role": "system", "content": "rebuilt"}], + ): + success, messages = handler._perform_mid_execution_compression( + agent, [{"role": "user", "content": "hi"}] + ) + + assert success is True + assert messages is not None + assert agent.compressed_summary == "summary" + + def test_failure_falls_back_to_pruning(self): + handler = ConcreteHandler() + agent = Mock() + agent.conversation_id = "conv1" + agent.initial_user_id = "user1" + agent.model_id = "gpt-4" + agent.decoded_token = {} + agent.context_limit_reached = False + 
agent.current_token_count = 0 + + mock_result = Mock() + mock_result.success = False + mock_result.error = "failed" + + mock_conv_service = Mock() + mock_conv_service.get_conversation.return_value = {"queries": []} + + mock_orchestrator = Mock() + mock_orchestrator.compress_mid_execution.return_value = mock_result + + with patch( + "application.api.answer.services.compression.CompressionOrchestrator", + return_value=mock_orchestrator, + ), patch( + "application.api.answer.services.conversation_service.ConversationService", + return_value=mock_conv_service, + ), patch.object( + handler, "_build_conversation_from_messages", return_value={"queries": []} + ), patch.object( + handler, + "_prune_messages_minimal", + return_value=[{"role": "system", "content": "pruned"}], + ): + success, messages = handler._perform_mid_execution_compression( + agent, [{"role": "system", "content": "sys"}, {"role": "user", "content": "hi"}] + ) + + assert success is True + assert messages is not None + + def test_exception_returns_false(self): + handler = ConcreteHandler() + agent = Mock() + agent.conversation_id = "conv1" + agent.initial_user_id = "user1" + agent.model_id = "gpt-4" + agent.decoded_token = {} + + with patch( + "application.api.answer.services.compression.CompressionOrchestrator", + side_effect=RuntimeError("import error"), + ), patch( + "application.api.answer.services.conversation_service.ConversationService", + return_value=Mock(), + ): + success, messages = handler._perform_mid_execution_compression(agent, []) + + assert success is False + assert messages is None + + +# --------------------------------------------------------------------------- +# _perform_in_memory_compression +# --------------------------------------------------------------------------- + + +class TestPerformInMemoryCompression: + + def test_no_conversation_returns_false(self): + handler = ConcreteHandler() + agent = Mock() + + with patch.object( + handler, "_build_conversation_from_messages", return_value=None + ): + success, messages = handler._perform_in_memory_compression(agent, []) + + assert success is False + assert messages is None + + def test_compression_doesnt_reduce_falls_back_to_prune(self): + handler = ConcreteHandler() + agent = Mock() + agent.model_id = "gpt-4" + agent.user_api_key = None + agent.decoded_token = {} + agent.agent_id = None + agent.context_limit_reached = False + agent.current_token_count = 0 + + mock_metadata = Mock() + mock_metadata.compressed_token_count = 1000 + mock_metadata.original_token_count = 900 # compressed (1000) exceeds original (900): compression made things worse
+ + mock_service = Mock() + mock_service.compress_conversation.return_value = mock_metadata + + with patch.object( + handler, + "_build_conversation_from_messages", + return_value={"queries": [{"prompt": "q", "response": "a"}]}, + ), patch( + "application.core.model_utils.get_provider_from_model_id", + return_value="openai", + ), patch( + "application.core.model_utils.get_api_key_for_provider", + return_value="key", + ), patch( + "application.llm.llm_creator.LLMCreator.create_llm", + return_value=Mock(), + ), patch( + "application.api.answer.services.compression.service.CompressionService", + return_value=mock_service, + ), patch.object( + handler, + "_prune_messages_minimal", + return_value=[{"role": "system", "content": "pruned"}], + ), patch( + "application.core.settings.settings", + MagicMock(COMPRESSION_MODEL_OVERRIDE=None), + ): + success, messages = handler._perform_in_memory_compression( + agent, [{"role": "user", "content": "hi"}] + ) + + assert success is True + + def test_exception_returns_false(self): + handler = ConcreteHandler() + agent = Mock() + + with patch.object( + handler, + "_build_conversation_from_messages", + side_effect=RuntimeError("boom"), + ): + success, messages = handler._perform_in_memory_compression(agent, []) + + assert success is False + assert messages is None + + def test_not_enough_queries(self): + handler = ConcreteHandler() + agent = Mock() + agent.model_id = "gpt-4" + agent.user_api_key = None + agent.decoded_token = {} + agent.agent_id = None + + with patch.object( + handler, + "_build_conversation_from_messages", + return_value={"queries": []}, + ), patch( + "application.core.model_utils.get_provider_from_model_id", + return_value="openai", + ), patch( + "application.core.model_utils.get_api_key_for_provider", + return_value="key", + ), patch( + "application.llm.llm_creator.LLMCreator.create_llm", + return_value=Mock(), + ), patch( + "application.api.answer.services.compression.service.CompressionService", + return_value=Mock(), + ), patch( + "application.core.settings.settings", + MagicMock(COMPRESSION_MODEL_OVERRIDE=None), + ): + success, messages = handler._perform_in_memory_compression( + agent, [{"role": "user", "content": "hi"}] + ) + + assert success is False + assert messages is None + + def test_success_path(self): + handler = ConcreteHandler() + agent = Mock() + agent.model_id = "gpt-4" + agent.user_api_key = None + agent.decoded_token = {} + agent.agent_id = None + agent.context_limit_reached = False + agent.current_token_count = 0 + + mock_metadata = Mock() + mock_metadata.compressed_token_count = 100 + mock_metadata.original_token_count = 1000 + mock_metadata.compression_ratio = 10.0 + mock_metadata.to_dict.return_value = {"ratio": 10.0} + + mock_service = Mock() + mock_service.compress_conversation.return_value = mock_metadata + mock_service.get_compressed_context.return_value = ("summary", [{"prompt": "q", "response": "a"}]) + + with patch.object( + handler, + "_build_conversation_from_messages", + return_value={"queries": [{"prompt": "q", "response": "a"}]}, + ), patch( + "application.core.model_utils.get_provider_from_model_id", + return_value="openai", + ), patch( + "application.core.model_utils.get_api_key_for_provider", + return_value="key", + ), patch( + "application.llm.llm_creator.LLMCreator.create_llm", + return_value=Mock(), + ), patch( + "application.api.answer.services.compression.service.CompressionService", + return_value=mock_service, + ), patch.object( + handler, + "_rebuild_messages_after_compression", + 
return_value=[{"role": "system", "content": "rebuilt"}], + ), patch( + "application.core.settings.settings", + MagicMock(COMPRESSION_MODEL_OVERRIDE=None), + ): + success, messages = handler._perform_in_memory_compression( + agent, [{"role": "user", "content": "hi"}] + ) + + assert success is True + assert messages is not None + assert agent.compressed_summary == "summary" + + +# --------------------------------------------------------------------------- +# _perform_mid_execution_compression — additional edge cases +# --------------------------------------------------------------------------- + + +class TestPerformMidExecutionCompressionEdgeCases: + + def test_no_conversation_falls_back_to_in_memory(self): + handler = ConcreteHandler() + agent = Mock() + agent.conversation_id = "conv1" + agent.initial_user_id = "user1" + agent.model_id = "gpt-4" + agent.decoded_token = {} + + mock_conv_service = Mock() + mock_conv_service.get_conversation.return_value = None + + with patch( + "application.api.answer.services.compression.CompressionOrchestrator", + return_value=Mock(), + ), patch( + "application.api.answer.services.conversation_service.ConversationService", + return_value=mock_conv_service, + ), patch.object( + handler, + "_perform_in_memory_compression", + return_value=(True, [{"role": "system", "content": "ok"}]), + ) as mock_in_memory: + success, messages = handler._perform_mid_execution_compression( + agent, [{"role": "user", "content": "hi"}] + ) + + mock_in_memory.assert_called_once() + assert success is True + + def test_compression_not_performed(self): + handler = ConcreteHandler() + agent = Mock() + agent.conversation_id = "conv1" + agent.initial_user_id = "user1" + agent.model_id = "gpt-4" + agent.decoded_token = {} + + mock_result = Mock() + mock_result.success = True + mock_result.compression_performed = False + + mock_conv_service = Mock() + mock_conv_service.get_conversation.return_value = {"queries": []} + + mock_orchestrator = Mock() + mock_orchestrator.compress_mid_execution.return_value = mock_result + + with patch( + "application.api.answer.services.compression.CompressionOrchestrator", + return_value=mock_orchestrator, + ), patch( + "application.api.answer.services.conversation_service.ConversationService", + return_value=mock_conv_service, + ), patch.object( + handler, "_build_conversation_from_messages", return_value={"queries": []} + ): + success, messages = handler._perform_mid_execution_compression( + agent, [{"role": "user", "content": "hi"}] + ) + + assert success is False + assert messages is None + + def test_compression_didnt_reduce_tokens(self): + handler = ConcreteHandler() + agent = Mock() + agent.conversation_id = "conv1" + agent.initial_user_id = "user1" + agent.model_id = "gpt-4" + agent.decoded_token = {} + agent.context_limit_reached = False + agent.current_token_count = 0 + + mock_metadata = Mock() + mock_metadata.compressed_token_count = 1000 + mock_metadata.original_token_count = 900 + + mock_result = Mock() + mock_result.success = True + mock_result.compression_performed = True + mock_result.metadata = mock_metadata + + mock_conv_service = Mock() + mock_conv_service.get_conversation.return_value = {"queries": []} + + mock_orchestrator = Mock() + mock_orchestrator.compress_mid_execution.return_value = mock_result + + with patch( + "application.api.answer.services.compression.CompressionOrchestrator", + return_value=mock_orchestrator, + ), patch( + "application.api.answer.services.conversation_service.ConversationService", + 
return_value=mock_conv_service, + ), patch.object( + handler, "_build_conversation_from_messages", return_value={"queries": []} + ), patch.object( + handler, + "_prune_messages_minimal", + return_value=[{"role": "system", "content": "pruned"}], + ): + success, messages = handler._perform_mid_execution_compression( + agent, [{"role": "system", "content": "sys"}, {"role": "user", "content": "hi"}] + ) + + assert success is True + + def test_rebuild_returns_none(self): + handler = ConcreteHandler() + agent = Mock() + agent.conversation_id = "conv1" + agent.initial_user_id = "user1" + agent.model_id = "gpt-4" + agent.decoded_token = {} + agent.context_limit_reached = False + agent.current_token_count = 0 + + mock_metadata = Mock() + mock_metadata.compressed_token_count = 100 + mock_metadata.original_token_count = 1000 + mock_metadata.compression_ratio = 10.0 + mock_metadata.to_dict.return_value = {} + + mock_result = Mock() + mock_result.success = True + mock_result.compression_performed = True + mock_result.compressed_summary = "summary" + mock_result.recent_queries = [] + mock_result.metadata = mock_metadata + + mock_conv_service = Mock() + mock_conv_service.get_conversation.return_value = {"queries": []} + + mock_orchestrator = Mock() + mock_orchestrator.compress_mid_execution.return_value = mock_result + + with patch( + "application.api.answer.services.compression.CompressionOrchestrator", + return_value=mock_orchestrator, + ), patch( + "application.api.answer.services.conversation_service.ConversationService", + return_value=mock_conv_service, + ), patch.object( + handler, "_build_conversation_from_messages", return_value={"queries": []} + ), patch.object( + handler, "_rebuild_messages_after_compression", return_value=None + ): + success, messages = handler._perform_mid_execution_compression( + agent, [{"role": "user", "content": "hi"}] + ) + + assert success is False + assert messages is None + + +# --------------------------------------------------------------------------- +# _build_conversation_from_messages — additional edge cases +# --------------------------------------------------------------------------- + + +class TestBuildConversationEdgeCases: + + def test_function_call_without_call_id(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "search"}, + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": "search", + "args": {"q": "X"}, + } + } + ], + }, + {"role": "assistant", "content": "done"}, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + queries = result["queries"] + # The tool call should still be tracked + assert len(queries) >= 1 + + def test_function_response_in_assistant_content(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "q"}, + { + "role": "assistant", + "content": [ + { + "function_response": { + "name": "search", + "response": {"result": "found"}, + "call_id": "c1", + } + } + ], + }, + {"role": "assistant", "content": "final"}, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + + def test_tool_role_without_function_response_format(self): + """Tool message with plain string content, no matching call_id.""" + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "q"}, + {"role": "assistant", "content": "calling tool"}, + {"role": "tool", "content": "tool output text"}, + {"role": "assistant", "content": "done"}, + ] + result = 
handler._build_conversation_from_messages(messages) + assert result is not None + + def test_pending_tool_calls_committed(self): + handler = ConcreteHandler() + messages = [ + {"role": "user", "content": "q"}, + { + "role": "assistant", + "content": [ + { + "function_call": { + "name": "search", + "args": {}, + "call_id": "c1", + } + } + ], + }, + ] + result = handler._build_conversation_from_messages(messages) + assert result is not None + assert len(result["queries"]) == 1 + assert len(result["queries"][0]["tool_calls"]) == 1 + + +# --------------------------------------------------------------------------- +# handle_tool_calls — compression success path +# --------------------------------------------------------------------------- + + +class TestHandleToolCallsCompressionSuccess: + + def test_compression_success_continues(self): + handler = ConcreteHandler() + agent = Mock() + call_count = {"n": 0} + + def check_limit(messages): + call_count["n"] += 1 + return call_count["n"] == 1 # Only trigger on first call + + agent._check_context_limit = Mock(side_effect=check_limit) + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + def fake_execute(tools_dict, call): + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("tool result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + + with patch( + "application.core.settings.settings" + ) as mock_settings: + mock_settings.ENABLE_CONVERSATION_COMPRESSION = True + + with patch.object( + handler, + "_perform_mid_execution_compression", + return_value=(True, [{"role": "system", "content": "compressed"}]), + ): + calls = [ToolCall(id="c1", name="a_1", arguments="{}")] + gen = handler.handle_tool_calls(agent, calls, {"1": {"name": "t"}}, []) + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration: + pass + + info_events = [ + e for e in events + if isinstance(e, dict) and e.get("type") == "info" + ] + assert len(info_events) == 1 + + def test_compression_failure_after_some_tools(self): + handler = ConcreteHandler() + agent = Mock() + agent.context_limit_reached = False + agent.llm.__class__.__name__ = "MockLLM" + + exec_count = {"n": 0} + + def check_limit(messages): + return exec_count["n"] >= 1 + + agent._check_context_limit = Mock(side_effect=check_limit) + + def fake_execute(tools_dict, call): + exec_count["n"] += 1 + yield {"type": "tool_call", "data": {"status": "pending"}} + return ("tool result", call.id) + + agent._execute_tool_action = Mock(side_effect=fake_execute) + + with patch( + "application.core.settings.settings" + ) as mock_settings: + mock_settings.ENABLE_CONVERSATION_COMPRESSION = True + + with patch.object( + handler, + "_perform_mid_execution_compression", + return_value=(False, None), + ): + calls = [ + ToolCall(id="c1", name="a_1", arguments="{}"), + ToolCall(id="c2", name="b_1", arguments="{}"), + ] + gen = handler.handle_tool_calls(agent, calls, {"1": {"name": "t"}}, []) + events = [] + try: + while True: + events.append(next(gen)) + except StopIteration: + pass + + skip_events = [ + e for e in events + if isinstance(e, dict) and e.get("data", {}).get("status") == "skipped" + ] + assert len(skip_events) == 1 # Only second call skipped diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py new file mode 100644 index 00000000..7176ab2a --- /dev/null +++ b/tests/llm/test_base_llm.py @@ -0,0 +1,204 @@ +"""Unit tests for application/llm/base.py — BaseLLM. 
+ +Covers initialisation, static helpers, supports_* introspection, +structured-output defaults, and attachment-type defaults. +Fallback behaviour is covered separately in test_fallback.py. +""" + +from unittest.mock import MagicMock, Mock + +import pytest + +from application.llm.base import BaseLLM + + +# --------------------------------------------------------------------------- +# Concrete stub so we can instantiate the abstract base +# --------------------------------------------------------------------------- + + +class StubLLM(BaseLLM): + """Minimal concrete BaseLLM for unit-testing non-abstract members.""" + + def _raw_gen(self, baseself, model, messages, stream, tools=None, **kw): + return "raw_gen_result" + + def _raw_gen_stream(self, baseself, model, messages, stream, tools=None, **kw): + yield "chunk" + + +# --------------------------------------------------------------------------- +# __init__ +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestBaseLLMInit: + + def test_defaults(self): + llm = StubLLM() + assert llm.decoded_token is None + assert llm.agent_id is None + assert llm.model_id is None + assert llm.base_url is None + assert llm.token_usage == {"prompt_tokens": 0, "generated_tokens": 0} + assert llm._backup_models == [] + assert llm._fallback_llm is None + + def test_agent_id_cast_to_str(self): + llm = StubLLM(agent_id=42) + assert llm.agent_id == "42" + + def test_agent_id_none_stays_none(self): + llm = StubLLM(agent_id=None) + assert llm.agent_id is None + + def test_custom_params(self): + token = {"sub": "u1"} + llm = StubLLM( + decoded_token=token, + agent_id="abc", + model_id="gpt-4", + base_url="http://x", + backup_models=["m1", "m2"], + ) + assert llm.decoded_token is token + assert llm.agent_id == "abc" + assert llm.model_id == "gpt-4" + assert llm.base_url == "http://x" + assert llm._backup_models == ["m1", "m2"] + + +# --------------------------------------------------------------------------- +# _remove_null_values +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestRemoveNullValues: + + def test_removes_none_values(self): + result = BaseLLM._remove_null_values({"a": 1, "b": None, "c": "x"}) + assert result == {"a": 1, "c": "x"} + + def test_keeps_falsy_non_none(self): + result = BaseLLM._remove_null_values({"a": 0, "b": "", "c": False, "d": []}) + assert result == {"a": 0, "b": "", "c": False, "d": []} + + def test_non_dict_passthrough(self): + assert BaseLLM._remove_null_values("hello") == "hello" + assert BaseLLM._remove_null_values(42) == 42 + assert BaseLLM._remove_null_values([1, 2]) == [1, 2] + + def test_empty_dict(self): + assert BaseLLM._remove_null_values({}) == {} + + def test_all_none(self): + assert BaseLLM._remove_null_values({"a": None, "b": None}) == {} + + +# --------------------------------------------------------------------------- +# supports_tools / _supports_tools +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestSupportsTools: + + def test_supports_tools_true_when_callable(self): + llm = StubLLM() + assert llm.supports_tools() is True + + def test_supports_tools_false_when_not_callable(self): + llm = StubLLM() + llm._supports_tools = "not_callable" + assert llm.supports_tools() is False + + def test_default_supports_tools_raises(self): + llm = StubLLM() + with pytest.raises(NotImplementedError): + llm._supports_tools() + + +# 
--------------------------------------------------------------------------- +# supports_structured_output / _supports_structured_output +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestSupportsStructuredOutput: + + def test_supports_structured_output_true(self): + llm = StubLLM() + assert llm.supports_structured_output() is True + + def test_default_supports_structured_output_returns_false(self): + llm = StubLLM() + assert llm._supports_structured_output() is False + + +# --------------------------------------------------------------------------- +# prepare_structured_output_format +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestPrepareStructuredOutputFormat: + + def test_returns_none_by_default(self): + llm = StubLLM() + assert llm.prepare_structured_output_format({"type": "object"}) is None + + +# --------------------------------------------------------------------------- +# get_supported_attachment_types +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestGetSupportedAttachmentTypes: + + def test_returns_empty_list(self): + llm = StubLLM() + assert llm.get_supported_attachment_types() == [] + + +# --------------------------------------------------------------------------- +# fallback_llm — caching +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestFallbackLLMCaching: + + def test_returns_cached_instance(self, monkeypatch): + """Once resolved, the same fallback instance is returned.""" + sentinel = StubLLM() + llm = StubLLM() + llm._fallback_llm = sentinel + assert llm.fallback_llm is sentinel + + def test_none_when_no_backup_and_no_global(self, monkeypatch): + monkeypatch.setattr( + "application.llm.base.settings", + MagicMock(FALLBACK_LLM_PROVIDER=None), + ) + llm = StubLLM(backup_models=[]) + assert llm.fallback_llm is None + + def test_global_fallback_init_failure_returns_none(self, monkeypatch): + monkeypatch.setattr( + "application.llm.base.settings", + MagicMock( + FALLBACK_LLM_PROVIDER="openai", + FALLBACK_LLM_NAME="gpt-4", + FALLBACK_LLM_API_KEY="k", + API_KEY="k", + ), + ) + monkeypatch.setattr( + "application.llm.llm_creator.LLMCreator.create_llm", + Mock(side_effect=RuntimeError("boom")), + ) + llm = StubLLM(backup_models=[]) + assert llm.fallback_llm is None
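# ---------------------------------------------------------------------------
# Reference sketches (editorial appendix; not part of the patch above)
# ---------------------------------------------------------------------------
# The assertions in TestAppendUnsupportedAttachments pin down a contract:
# attachment text is folded into the system prompt, a system message is
# created when one is missing, attachments without a "content" key are
# ignored, and only the list is copied (the system dict mutates in place).
# The sketch below is a reconstruction inferred from those assertions, not
# the shipped implementation; the exact separator text is an assumption.


def append_unsupported_attachments_sketch(messages, attachments):
    result = list(messages)  # shallow copy: the caller's list length is preserved
    system = next((m for m in result if m.get("role") == "system"), None)
    if system is None:
        system = {"role": "system", "content": ""}
        result.insert(0, system)
    for attachment in attachments:
        if "content" in attachment:  # attachments without content are skipped
            system["content"] += f"\n\nAttached file:\n{attachment['content']}"
    return result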
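# TestPruneMessagesMinimal fixes the pruning contract: keep the system
# prompt plus the most recent user turn, fall back to the most recent
# non-system message, and return None when either anchor is missing.
# A minimal sketch of that contract, assuming list-of-dict messages:


def prune_messages_minimal_sketch(messages):
    system = next((m for m in messages if m.get("role") == "system"), None)
    if system is None:
        return None  # nothing to anchor a pruned context on
    last = next((m for m in reversed(messages) if m.get("role") == "user"), None)
    if last is None:  # no user turn: fall back to any non-system message
        last = next((m for m in reversed(messages) if m.get("role") != "system"), None)
    if last is None:
        return None  # only a system prompt was present
    return [system, last]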
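# TestExtractTextFromContent implies a tolerant flattener: strings pass
# through, part-lists are joined with markers for function_call /
# function_response / files entries, and None, numbers, or empty lists
# collapse to "". Sketch only; the marker wording is an assumption.


def extract_text_from_content_sketch(content):
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""  # None, ints, and other scalars carry no text
    parts = []
    for item in content:
        if not isinstance(item, dict):
            continue
        if item.get("text"):  # a None text part contributes nothing
            parts.append(item["text"])
        elif "function_call" in item:
            parts.append(f"[function_call: {item['function_call'].get('name')}]")
        elif "function_response" in item:
            parts.append(f"[function_response: {item['function_response'].get('name')}]")
        elif "files" in item:
            parts.append(f"[files: {item['files']}]")
    return " ".join(parts)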
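# test_tool_calls_accumulated_across_chunks relies on index-keyed merging
# of streamed tool-call deltas: id and name arrive on the first delta for
# an index, later deltas contribute only argument fragments, and the
# fragments concatenate into one JSON string ('{"q":' + '"test"}').
# A sketch of that merge rule, assuming ToolCall accepts these kwargs
# (the tests construct it with id, name, arguments, and index):

from application.llm.handlers.base import ToolCall


def accumulate_tool_call_deltas_sketch(deltas):
    slots = {}
    for d in deltas:
        slot = slots.setdefault(d.index, {"id": "", "name": "", "arguments": ""})
        slot["id"] = slot["id"] or d.id        # first non-empty id wins
        slot["name"] = slot["name"] or d.name  # likewise for the name
        slot["arguments"] += d.arguments or ""  # JSON fragments concatenate
    return [ToolCall(index=i, **s) for i, s in slots.items()]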
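# TestRemoveNullValues pins BaseLLM._remove_null_values exactly: dict
# entries whose value is None are dropped, falsy-but-not-None values
# (0, "", False, []) survive, and non-dict inputs pass through untouched.
# Equivalent behaviour in one comprehension:


def remove_null_values_sketch(params):
    if not isinstance(params, dict):
        return params  # strings, ints, lists are returned as-is
    return {k: v for k, v in params.items() if v is not None}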