diff --git a/tests/agents/test_cel_evaluator.py b/tests/agents/test_cel_evaluator.py new file mode 100644 index 00000000..70209b56 --- /dev/null +++ b/tests/agents/test_cel_evaluator.py @@ -0,0 +1,169 @@ +"""Tests for application/agents/workflows/cel_evaluator.py""" + +import pytest + +from application.agents.workflows.cel_evaluator import ( + CelEvaluationError, + _convert_value, + build_activation, + cel_to_python, + evaluate_cel, +) +import celpy.celtypes + + +class TestConvertValue: + + @pytest.mark.unit + def test_bool_true(self): + result = _convert_value(True) + assert isinstance(result, celpy.celtypes.BoolType) + assert bool(result) is True + + @pytest.mark.unit + def test_bool_false(self): + result = _convert_value(False) + assert isinstance(result, celpy.celtypes.BoolType) + assert bool(result) is False + + @pytest.mark.unit + def test_int(self): + result = _convert_value(42) + assert isinstance(result, celpy.celtypes.IntType) + assert int(result) == 42 + + @pytest.mark.unit + def test_float(self): + result = _convert_value(3.14) + assert isinstance(result, celpy.celtypes.DoubleType) + assert float(result) == pytest.approx(3.14) + + @pytest.mark.unit + def test_string(self): + result = _convert_value("hello") + assert isinstance(result, celpy.celtypes.StringType) + assert str(result) == "hello" + + @pytest.mark.unit + def test_list(self): + result = _convert_value([1, "two", 3.0]) + assert isinstance(result, celpy.celtypes.ListType) + + @pytest.mark.unit + def test_dict(self): + result = _convert_value({"key": "value"}) + assert isinstance(result, celpy.celtypes.MapType) + + @pytest.mark.unit + def test_none(self): + result = _convert_value(None) + assert isinstance(result, celpy.celtypes.BoolType) + assert bool(result) is False + + @pytest.mark.unit + def test_other_type_converts_to_string(self): + result = _convert_value(object()) + assert isinstance(result, celpy.celtypes.StringType) + + +class TestBuildActivation: + + @pytest.mark.unit + def 
test_converts_dict_values(self): + state = {"name": "Alice", "age": 30, "active": True} + result = build_activation(state) + assert "name" in result + assert "age" in result + assert "active" in result + + @pytest.mark.unit + def test_empty_state(self): + assert build_activation({}) == {} + + +class TestEvaluateCel: + + @pytest.mark.unit + def test_simple_comparison(self): + assert evaluate_cel("x > 5", {"x": 10}) is True + assert evaluate_cel("x > 5", {"x": 3}) is False + + @pytest.mark.unit + def test_string_comparison(self): + assert evaluate_cel('name == "Alice"', {"name": "Alice"}) is True + assert evaluate_cel('name == "Alice"', {"name": "Bob"}) is False + + @pytest.mark.unit + def test_arithmetic(self): + assert evaluate_cel("x + y", {"x": 3, "y": 4}) == 7 + + @pytest.mark.unit + def test_boolean_logic(self): + assert evaluate_cel("a && b", {"a": True, "b": True}) is True + assert evaluate_cel("a && b", {"a": True, "b": False}) is False + assert evaluate_cel("a || b", {"a": False, "b": True}) is True + + @pytest.mark.unit + def test_empty_expression_raises(self): + with pytest.raises(CelEvaluationError, match="Empty expression"): + evaluate_cel("", {}) + + @pytest.mark.unit + def test_whitespace_expression_raises(self): + with pytest.raises(CelEvaluationError, match="Empty expression"): + evaluate_cel(" ", {}) + + @pytest.mark.unit + def test_invalid_expression_raises(self): + with pytest.raises(CelEvaluationError): + evaluate_cel("invalid!!!", {}) + + @pytest.mark.unit + def test_missing_variable_raises(self): + with pytest.raises(CelEvaluationError): + evaluate_cel("undefined_var > 5", {}) + + +class TestCelToPython: + + @pytest.mark.unit + def test_bool(self): + result = cel_to_python(celpy.celtypes.BoolType(True)) + assert result is True + + @pytest.mark.unit + def test_int(self): + result = cel_to_python(celpy.celtypes.IntType(42)) + assert result == 42 + + @pytest.mark.unit + def test_double(self): + result = 
cel_to_python(celpy.celtypes.DoubleType(3.14)) + assert result == pytest.approx(3.14) + + @pytest.mark.unit + def test_string(self): + result = cel_to_python(celpy.celtypes.StringType("hello")) + assert result == "hello" + + @pytest.mark.unit + def test_list(self): + cel_list = celpy.celtypes.ListType([ + celpy.celtypes.IntType(1), + celpy.celtypes.IntType(2), + ]) + result = cel_to_python(cel_list) + assert result == [1, 2] + + @pytest.mark.unit + def test_map(self): + cel_map = celpy.celtypes.MapType({ + celpy.celtypes.StringType("key"): celpy.celtypes.StringType("value"), + }) + result = cel_to_python(cel_map) + assert result == {"key": "value"} + + @pytest.mark.unit + def test_unknown_type_passthrough(self): + result = cel_to_python("raw_value") + assert result == "raw_value" diff --git a/tests/agents/test_workflow_agent_coverage.py b/tests/agents/test_workflow_agent_coverage.py new file mode 100644 index 00000000..590f6b5b --- /dev/null +++ b/tests/agents/test_workflow_agent_coverage.py @@ -0,0 +1,433 @@ +"""Tests for WorkflowAgent - covering _parse_embedded_workflow, _load_from_database, +_save_workflow_run, _determine_run_status, _serialize_state, and gen flow.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from application.agents.workflows.schemas import ( + ExecutionStatus, + WorkflowGraph, +) + + +def _make_agent(**overrides): + """Create a WorkflowAgent with mocked base class dependencies.""" + defaults = { + "endpoint": "https://api.example.com", + "llm_name": "openai", + "model_id": "gpt-4", + "api_key": "test_key", + "user_api_key": None, + "prompt": "You are helpful.", + "chat_history": [], + "decoded_token": {"sub": "user1"}, + "attachments": [], + "json_schema": None, + } + defaults.update(overrides) + + with patch("application.agents.workflow_agent.log_activity", lambda **kw: lambda f: f): + from application.agents.workflow_agent import WorkflowAgent + agent = WorkflowAgent(**defaults) 
+ return agent + + +class TestWorkflowAgentInit: + + @pytest.mark.unit + def test_sets_attributes(self): + agent = _make_agent(workflow_id="wf1", workflow_owner="owner1") + assert agent.workflow_id == "wf1" + assert agent.workflow_owner == "owner1" + assert agent._engine is None + + @pytest.mark.unit + def test_embedded_workflow(self): + wf_data = {"nodes": [], "edges": [], "name": "Test"} + agent = _make_agent(workflow=wf_data) + assert agent._workflow_data == wf_data + + +class TestParseEmbeddedWorkflow: + + @pytest.mark.unit + def test_parses_valid_workflow(self): + wf_data = { + "name": "Test Workflow", + "description": "A test", + "nodes": [ + {"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}}, + {"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}}, + ], + "edges": [ + {"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"}, + ], + } + agent = _make_agent(workflow=wf_data, workflow_id="wf1") + graph = agent._parse_embedded_workflow() + assert graph is not None + assert len(graph.nodes) == 2 + assert len(graph.edges) == 1 + assert graph.workflow.name == "Test Workflow" + + @pytest.mark.unit + def test_edge_source_id_alias(self): + wf_data = { + "nodes": [{"id": "n1", "type": "start", "data": {}}], + "edges": [{"id": "e1", "source_id": "n1", "target_id": "n2", "source_handle": "out", "target_handle": "in"}], + } + agent = _make_agent(workflow=wf_data) + graph = agent._parse_embedded_workflow() + assert graph is not None + assert graph.edges[0].source_id == "n1" + + @pytest.mark.unit + def test_invalid_data_returns_none(self): + agent = _make_agent(workflow={"nodes": [{"bad": "data"}], "edges": []}) + graph = agent._parse_embedded_workflow() + assert graph is None + + +class TestLoadWorkflowGraph: + + @pytest.mark.unit + def test_uses_embedded_when_available(self): + agent = _make_agent(workflow={"nodes": [], "edges": [], "name": "E"}) + 
agent._parse_embedded_workflow = MagicMock(return_value="parsed_graph") + result = agent._load_workflow_graph() + assert result == "parsed_graph" + + @pytest.mark.unit + def test_uses_database_when_workflow_id(self): + agent = _make_agent(workflow_id="wf1") + agent._load_from_database = MagicMock(return_value="db_graph") + result = agent._load_workflow_graph() + assert result == "db_graph" + + @pytest.mark.unit + def test_returns_none_when_nothing(self): + agent = _make_agent() + result = agent._load_workflow_graph() + assert result is None + + +class TestLoadFromDatabase: + + @pytest.mark.unit + def test_invalid_workflow_id_returns_none(self): + agent = _make_agent(workflow_id="invalid!") + result = agent._load_from_database() + assert result is None + + @pytest.mark.unit + def test_no_owner_returns_none(self): + agent = _make_agent(workflow_id="507f1f77bcf86cd799439011", decoded_token={}) + agent.workflow_owner = None + result = agent._load_from_database() + assert result is None + + @pytest.mark.unit + def test_uses_decoded_token_sub(self): + agent = _make_agent( + workflow_id="507f1f77bcf86cd799439011", + decoded_token={"sub": "user1"}, + ) + agent.workflow_owner = None + + mock_collection = MagicMock() + mock_collection.find_one.return_value = None + mock_db = MagicMock() + mock_db.__getitem__ = MagicMock(return_value=mock_collection) + + with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \ + patch("application.agents.workflow_agent.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "test_db" + MockMongo.get_client.return_value = {"test_db": mock_db} + result = agent._load_from_database() + assert result is None # workflow_doc not found + + @pytest.mark.unit + def test_successful_load(self): + agent = _make_agent( + workflow_id="507f1f77bcf86cd799439011", + workflow_owner="owner1", + ) + + mock_wf_coll = MagicMock() + mock_wf_coll.find_one.return_value = { + "_id": "507f1f77bcf86cd799439011", + "name": "Test WF", + "user": 
"owner1", + "current_graph_version": 1, + } + + mock_nodes_coll = MagicMock() + mock_nodes_coll.find.return_value = [ + {"id": "n1", "workflow_id": "507f1f77bcf86cd799439011", "type": "start", + "title": "Start", "position": {"x": 0, "y": 0}, "config": {}}, + ] + + mock_edges_coll = MagicMock() + mock_edges_coll.find.return_value = [] + + def getitem(name): + return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name] + + mock_db = MagicMock() + mock_db.__getitem__ = MagicMock(side_effect=getitem) + + with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \ + patch("application.agents.workflow_agent.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "test_db" + MockMongo.get_client.return_value = {"test_db": mock_db} + result = agent._load_from_database() + + assert result is not None + assert len(result.nodes) == 1 + + @pytest.mark.unit + def test_invalid_graph_version(self): + agent = _make_agent( + workflow_id="507f1f77bcf86cd799439011", + workflow_owner="owner1", + ) + + mock_wf_coll = MagicMock() + mock_wf_coll.find_one.return_value = { + "_id": "507f1f77bcf86cd799439011", + "name": "WF", + "user": "owner1", + "current_graph_version": "bad", + } + + mock_nodes_coll = MagicMock() + mock_nodes_coll.find.return_value = [] + mock_edges_coll = MagicMock() + mock_edges_coll.find.return_value = [] + + def getitem(name): + return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name] + + mock_db = MagicMock() + mock_db.__getitem__ = MagicMock(side_effect=getitem) + + with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \ + patch("application.agents.workflow_agent.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "test_db" + MockMongo.get_client.return_value = {"test_db": mock_db} + result = agent._load_from_database() + assert result is not None # Defaults to version 1 + + @pytest.mark.unit + def 
test_fallback_nodes_without_version(self): + """When graph_version=1 finds no nodes, falls back to nodes without version field.""" + agent = _make_agent( + workflow_id="507f1f77bcf86cd799439011", + workflow_owner="owner1", + ) + + mock_wf_coll = MagicMock() + mock_wf_coll.find_one.return_value = { + "_id": "507f1f77bcf86cd799439011", + "name": "WF", + "user": "owner1", + "current_graph_version": 1, + } + + call_count = [0] + def nodes_find(query): + call_count[0] += 1 + if call_count[0] == 1: + return [] # No versioned nodes + return [{"id": "n1", "workflow_id": "wf", "type": "start", + "title": "S", "position": {"x": 0, "y": 0}, "config": {}}] + + mock_nodes_coll = MagicMock() + mock_nodes_coll.find.side_effect = nodes_find + + edge_call = [0] + def edges_find(query): + edge_call[0] += 1 + if edge_call[0] == 1: + return [] + return [] + + mock_edges_coll = MagicMock() + mock_edges_coll.find.side_effect = edges_find + + def getitem(name): + return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name] + + mock_db = MagicMock() + mock_db.__getitem__ = MagicMock(side_effect=getitem) + + with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \ + patch("application.agents.workflow_agent.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "test_db" + MockMongo.get_client.return_value = {"test_db": mock_db} + result = agent._load_from_database() + assert result is not None + assert len(result.nodes) == 1 + + @pytest.mark.unit + def test_exception_returns_none(self): + agent = _make_agent( + workflow_id="507f1f77bcf86cd799439011", + workflow_owner="owner1", + ) + with patch("application.agents.workflow_agent.MongoDB") as MockMongo: + MockMongo.get_client.side_effect = Exception("db error") + result = agent._load_from_database() + assert result is None + + +class TestGenInner: + + @pytest.mark.unit + def test_no_graph_yields_error(self): + agent = _make_agent() + agent._load_workflow_graph = 
MagicMock(return_value=None) + events = list(agent._gen_inner("query", None)) + assert any(e.get("type") == "error" for e in events) + + @pytest.mark.unit + def test_successful_execution(self): + agent = _make_agent(workflow_id="wf1") + mock_graph = MagicMock(spec=WorkflowGraph) + agent._load_workflow_graph = MagicMock(return_value=mock_graph) + agent._save_workflow_run = MagicMock() + + mock_engine = MagicMock() + mock_engine.execute.return_value = iter([{"answer": "result"}]) + + with patch("application.agents.workflow_agent.WorkflowEngine", return_value=mock_engine): + events = list(agent._gen_inner("query", None)) + assert len(events) == 1 + agent._save_workflow_run.assert_called_once_with("query") + + +class TestSaveWorkflowRun: + + @pytest.mark.unit + def test_no_engine_returns_early(self): + agent = _make_agent() + agent._engine = None + agent._save_workflow_run("query") # Should not raise + + @pytest.mark.unit + def test_saves_to_mongo(self): + agent = _make_agent(workflow_id="wf1") + mock_engine = MagicMock() + mock_engine.state = {"query": "test"} + mock_engine.execution_log = [] + mock_engine.get_execution_summary.return_value = [] + agent._engine = mock_engine + + mock_collection = MagicMock() + mock_db = MagicMock() + mock_db.__getitem__ = MagicMock(return_value=mock_collection) + + with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \ + patch("application.agents.workflow_agent.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "test_db" + MockMongo.get_client.return_value = {"test_db": mock_db} + agent._save_workflow_run("query") + + mock_collection.insert_one.assert_called_once() + + @pytest.mark.unit + def test_exception_does_not_propagate(self): + agent = _make_agent(workflow_id="wf1") + mock_engine = MagicMock() + mock_engine.state = {} + mock_engine.execution_log = [] + mock_engine.get_execution_summary.return_value = [] + agent._engine = mock_engine + + with patch("application.agents.workflow_agent.MongoDB") as 
MockMongo: + MockMongo.get_client.side_effect = Exception("db fail") + agent._save_workflow_run("query") # Should not raise + + +class TestDetermineRunStatus: + + @pytest.mark.unit + def test_no_engine_returns_completed(self): + agent = _make_agent() + agent._engine = None + assert agent._determine_run_status() == ExecutionStatus.COMPLETED + + @pytest.mark.unit + def test_empty_log_returns_completed(self): + agent = _make_agent() + agent._engine = MagicMock() + agent._engine.execution_log = [] + assert agent._determine_run_status() == ExecutionStatus.COMPLETED + + @pytest.mark.unit + def test_failed_log_returns_failed(self): + agent = _make_agent() + agent._engine = MagicMock() + agent._engine.execution_log = [ + {"status": "completed"}, + {"status": "failed"}, + ] + assert agent._determine_run_status() == ExecutionStatus.FAILED + + @pytest.mark.unit + def test_all_completed_returns_completed(self): + agent = _make_agent() + agent._engine = MagicMock() + agent._engine.execution_log = [ + {"status": "completed"}, + {"status": "completed"}, + ] + assert agent._determine_run_status() == ExecutionStatus.COMPLETED + + +class TestSerializeState: + + @pytest.mark.unit + def test_serializes_primitives(self): + agent = _make_agent() + state = {"str": "hello", "int": 42, "float": 3.14, "bool": True, "none": None} + result = agent._serialize_state(state) + assert result == state + + @pytest.mark.unit + def test_serializes_nested_dict(self): + agent = _make_agent() + state = {"nested": {"key": "value"}} + result = agent._serialize_state(state) + assert result["nested"]["key"] == "value" + + @pytest.mark.unit + def test_serializes_list(self): + agent = _make_agent() + state = {"items": [1, 2, "three"]} + result = agent._serialize_state(state) + assert result["items"] == [1, 2, "three"] + + @pytest.mark.unit + def test_serializes_tuple(self): + agent = _make_agent() + state = {"tup": (1, 2)} + result = agent._serialize_state(state) + assert result["tup"] == [1, 2] + + 
@pytest.mark.unit + def test_serializes_datetime(self): + agent = _make_agent() + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + state = {"time": dt} + result = agent._serialize_state(state) + assert "2025-01-01" in result["time"] + + @pytest.mark.unit + def test_serializes_unknown_to_str(self): + agent = _make_agent() + state = {"obj": object()} + result = agent._serialize_state(state) + assert isinstance(result["obj"], str) diff --git a/tests/agents/test_workflow_engine_coverage.py b/tests/agents/test_workflow_engine_coverage.py new file mode 100644 index 00000000..2a5b8773 --- /dev/null +++ b/tests/agents/test_workflow_engine_coverage.py @@ -0,0 +1,573 @@ +"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes, +template context, source data, structured output parsing, get_execution_summary.""" + +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch + +import pytest + +from application.agents.workflows.schemas import ( + ExecutionStatus, + NodeType, + WorkflowEdge, + WorkflowGraph, + WorkflowNode, + Workflow, +) +from application.agents.workflows.workflow_engine import WorkflowEngine + + +def _make_graph(nodes, edges): + wf = Workflow(name="Test", description="test workflow") + return WorkflowGraph(workflow=wf, nodes=nodes, edges=edges) + + +def _make_node(id, type, title="Node", config=None, position=None): + return WorkflowNode( + id=id, + workflow_id="wf1", + type=type, + title=title, + position=position or {"x": 0, "y": 0}, + config=config or {}, + ) + + +def _make_edge(id, source, target, source_handle=None, target_handle=None): + return WorkflowEdge( + id=id, + workflow_id="wf1", + source=source, + target=target, + sourceHandle=source_handle, + targetHandle=target_handle, + ) + + +def _make_agent(): + agent = MagicMock() + agent.chat_history = [] + agent.endpoint = "https://api.example.com" + agent.llm_name = "openai" + agent.model_id = "gpt-4" + agent.api_key = "key" + 
agent.decoded_token = {"sub": "user1"} + agent.retrieved_docs = None + return agent + + +class TestExecuteLoop: + + @pytest.mark.unit + def test_no_start_node_yields_error(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + events = list(engine.execute({}, "query")) + assert any(e.get("type") == "error" and "start node" in e.get("error", "") for e in events) + + @pytest.mark.unit + def test_start_to_end(self): + nodes = [ + _make_node("n1", NodeType.START, "Start"), + _make_node("n2", NodeType.END, "End", config={"config": {}}), + ] + edges = [_make_edge("e1", "n1", "n2")] + graph = _make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + events = list(engine.execute({}, "hello")) + step_events = [e for e in events if e.get("type") == "workflow_step"] + assert len(step_events) >= 2 # At least start + end + + @pytest.mark.unit + def test_node_not_found_yields_error(self): + nodes = [_make_node("n1", NodeType.START)] + edges = [_make_edge("e1", "n1", "nonexistent")] + graph = _make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + events = list(engine.execute({}, "q")) + assert any("not found" in e.get("error", "") for e in events) + + @pytest.mark.unit + def test_node_execution_error_yields_error(self): + nodes = [ + _make_node("n1", NodeType.START), + _make_node("n2", NodeType.STATE, "State", config={"config": {"operations": [{"expression": "bad!!!", "target_variable": "x"}]}}), + ] + edges = [_make_edge("e1", "n1", "n2")] + graph = _make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + events = list(engine.execute({}, "q")) + failed_events = [e for e in events if e.get("status") == "failed"] + assert len(failed_events) >= 1 + + @pytest.mark.unit + def test_max_steps_limit(self): + # Create a cycle: start -> state -> state (loop) + nodes = [ + _make_node("n1", NodeType.START), + _make_node("n2", NodeType.NOTE, "Note"), + ] + edges = [_make_edge("e1", "n1", "n2"), 
_make_edge("e2", "n2", "n2")] + graph = _make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + engine.MAX_EXECUTION_STEPS = 5 + events = list(engine.execute({}, "q")) + # Should not run forever + assert len(events) > 0 + + @pytest.mark.unit + def test_branch_ends_without_end_node(self): + nodes = [ + _make_node("n1", NodeType.START), + _make_node("n2", NodeType.NOTE, "Note"), + ] + edges = [_make_edge("e1", "n1", "n2")] # n2 has no outgoing edges + graph = _make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + events = list(engine.execute({}, "q")) + assert len(events) > 0 + + +class TestInitializeState: + + @pytest.mark.unit + def test_sets_query_and_history(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.chat_history = [{"prompt": "hi", "response": "hey"}] + engine = WorkflowEngine(graph, agent) + engine._initialize_state({"custom": "value"}, "test query") + assert engine.state["query"] == "test query" + assert "custom" in engine.state + assert engine.state["chat_history"] is not None + + +class TestGetNextNodeId: + + @pytest.mark.unit + def test_no_edges_returns_none(self): + nodes = [_make_node("n1", NodeType.START)] + graph = _make_graph(nodes, []) + engine = WorkflowEngine(graph, _make_agent()) + assert engine._get_next_node_id("n1") is None + + @pytest.mark.unit + def test_returns_first_edge_target(self): + nodes = [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)] + edges = [_make_edge("e1", "n1", "n2")] + graph = _make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + assert engine._get_next_node_id("n1") == "n2" + + @pytest.mark.unit + def test_condition_uses_matched_handle(self): + nodes = [ + _make_node("n1", NodeType.CONDITION), + _make_node("n2", NodeType.END, "Yes End"), + _make_node("n3", NodeType.END, "No End"), + ] + edges = [ + _make_edge("e1", "n1", "n2", source_handle="yes"), + _make_edge("e2", "n1", "n3", source_handle="no"), + ] + graph = 
_make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + engine._condition_result = "no" + assert engine._get_next_node_id("n1") == "n3" + assert engine._condition_result is None # Cleared after use + + @pytest.mark.unit + def test_condition_no_matching_handle_returns_none(self): + nodes = [_make_node("n1", NodeType.CONDITION)] + edges = [_make_edge("e1", "n1", "n2", source_handle="yes")] + graph = _make_graph(nodes, edges) + engine = WorkflowEngine(graph, _make_agent()) + engine._condition_result = "nonexistent" + assert engine._get_next_node_id("n1") is None + + +class TestExecuteStateNode: + + @pytest.mark.unit + def test_evaluates_operations(self): + node = _make_node("n1", NodeType.STATE, config={ + "config": { + "operations": [ + {"expression": "x + 1", "target_variable": "result"}, + ] + } + }) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {"x": 5} + list(engine._execute_state_node(node)) + assert engine.state["result"] == 6 + + @pytest.mark.unit + def test_skips_empty_expression(self): + node = _make_node("n1", NodeType.STATE, config={ + "config": { + "operations": [ + {"expression": "", "target_variable": "result"}, + ] + } + }) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {} + list(engine._execute_state_node(node)) + assert "result" not in engine.state + + +class TestExecuteConditionNode: + + @pytest.mark.unit + def test_matches_first_true_case(self): + node = _make_node("n1", NodeType.CONDITION, config={ + "config": { + "cases": [ + {"expression": "x > 10", "source_handle": "high"}, + {"expression": "x > 5", "source_handle": "medium"}, + ] + } + }) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {"x": 7} + list(engine._execute_condition_node(node)) + assert engine._condition_result == "medium" + + @pytest.mark.unit + def test_falls_through_to_else(self): + node = _make_node("n1", 
NodeType.CONDITION, config={ + "config": { + "cases": [ + {"expression": "x > 100", "source_handle": "high"}, + ] + } + }) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {"x": 1} + list(engine._execute_condition_node(node)) + assert engine._condition_result == "else" + + @pytest.mark.unit + def test_skips_empty_expression(self): + node = _make_node("n1", NodeType.CONDITION, config={ + "config": { + "cases": [ + {"expression": " ", "source_handle": "a"}, + {"expression": "true", "source_handle": "b"}, + ] + } + }) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {} + list(engine._execute_condition_node(node)) + assert engine._condition_result == "b" + + @pytest.mark.unit + def test_cel_error_continues(self): + node = _make_node("n1", NodeType.CONDITION, config={ + "config": { + "cases": [ + {"expression": "bad!!!", "source_handle": "a"}, + {"expression": "true", "source_handle": "b"}, + ] + } + }) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {} + list(engine._execute_condition_node(node)) + assert engine._condition_result == "b" + + +class TestExecuteEndNode: + + @pytest.mark.unit + def test_with_output_template(self): + node = _make_node("n1", NodeType.END, config={ + "config": {"output_template": "Result: {{ query }}"} + }) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {"query": "hello"} + engine._format_template = MagicMock(return_value="Result: hello") + events = list(engine._execute_end_node(node)) + assert len(events) == 1 + assert events[0]["answer"] == "Result: hello" + + @pytest.mark.unit + def test_without_output_template(self): + node = _make_node("n1", NodeType.END, config={"config": {}}) + graph = _make_graph([node], []) + engine = WorkflowEngine(graph, _make_agent()) + events = list(engine._execute_end_node(node)) + assert len(events) == 0 + + 
+class TestParseStructuredOutput: + + @pytest.mark.unit + def test_valid_json(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + success, data = engine._parse_structured_output('{"key": "value"}') + assert success is True + assert data == {"key": "value"} + + @pytest.mark.unit + def test_invalid_json(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + success, data = engine._parse_structured_output("not json") + assert success is False + assert data is None + + @pytest.mark.unit + def test_empty_string(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + success, data = engine._parse_structured_output("") + assert success is False + assert data is None + + @pytest.mark.unit + def test_whitespace_only(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + success, data = engine._parse_structured_output(" ") + assert success is False + assert data is None + + +class TestNormalizeNodeJsonSchema: + + @pytest.mark.unit + def test_none_returns_none(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + assert engine._normalize_node_json_schema(None, "Node") is None + + @pytest.mark.unit + def test_valid_schema(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + result = engine._normalize_node_json_schema(schema, "Node") + assert result is not None + + @pytest.mark.unit + def test_invalid_schema_raises(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + with patch("application.agents.workflows.workflow_engine.normalize_json_schema_payload") as mock_norm: + from application.core.json_schema_utils import JsonSchemaValidationError + mock_norm.side_effect = JsonSchemaValidationError("bad schema") + with pytest.raises(ValueError, match="Invalid JSON schema"): + 
engine._normalize_node_json_schema({"bad": True}, "TestNode") + + +class TestValidateStructuredOutput: + + @pytest.mark.unit + def test_valid_output_passes(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + engine._validate_structured_output(schema, {"name": "Alice"}) # Should not raise + + @pytest.mark.unit + def test_invalid_output_raises(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]} + with pytest.raises(ValueError, match="did not match schema"): + engine._validate_structured_output(schema, {}) + + @pytest.mark.unit + def test_no_jsonschema_module(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + with patch("application.agents.workflows.workflow_engine.jsonschema", None): + engine._validate_structured_output({"type": "object"}, {}) # Should not raise + + +class TestFormatTemplate: + + @pytest.mark.unit + def test_renders_template(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + engine.state = {"query": "hello"} + engine._build_template_context = MagicMock(return_value={"query": "hello"}) + engine._template_engine = MagicMock() + engine._template_engine.render.return_value = "hello world" + result = engine._format_template("{{ query }} world") + assert result == "hello world" + + @pytest.mark.unit + def test_render_error_returns_raw(self): + from application.templates.template_engine import TemplateRenderError + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + engine._build_template_context = MagicMock(return_value={}) + engine._template_engine = MagicMock() + engine._template_engine.render.side_effect = TemplateRenderError("fail") + result = engine._format_template("{{ bad }}") + assert result == "{{ bad }}" + + +class 
TestBuildTemplateContext: + + @pytest.mark.unit + def test_includes_state_variables(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = None + engine = WorkflowEngine(graph, agent) + engine.state = {"query": "hello", "custom_var": "value"} + context = engine._build_template_context() + assert context["agent"]["query"] == "hello" + assert "custom_var" in context + + @pytest.mark.unit + def test_reserved_namespace_gets_prefixed(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = None + engine = WorkflowEngine(graph, agent) + engine.state = {"source": "my_source_val"} + context = engine._build_template_context() + assert context.get("agent_source") == "my_source_val" + + @pytest.mark.unit + def test_passthrough_data(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = None + engine = WorkflowEngine(graph, agent) + engine.state = {"passthrough": {"key": "val"}} + context = engine._build_template_context() + assert "passthrough" in context or "agent_passthrough" in context + + @pytest.mark.unit + def test_tools_data(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = None + engine = WorkflowEngine(graph, agent) + engine.state = {"tools": {"tool1": "result"}} + context = engine._build_template_context() + assert "agent" in context + + +class TestGetSourceTemplateData: + + @pytest.mark.unit + def test_no_docs_returns_none(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = None + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert docs is None + assert together is None + + @pytest.mark.unit + def test_empty_docs_returns_none(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = [] + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert docs is None + + @pytest.mark.unit + def 
test_docs_with_filename(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = [{"text": "content", "filename": "doc.txt"}] + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert docs is not None + assert "doc.txt" in together + assert "content" in together + + @pytest.mark.unit + def test_docs_without_filename(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = [{"text": "content only"}] + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert together == "content only" + + @pytest.mark.unit + def test_skips_non_dict_docs(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = ["not a dict", {"text": "ok"}] + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert together == "ok" + + @pytest.mark.unit + def test_skips_non_string_text(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = [{"text": 123}] + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert together is None + + @pytest.mark.unit + def test_doc_with_title_fallback(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = [{"text": "content", "title": "doc_title"}] + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert "doc_title" in together + + @pytest.mark.unit + def test_doc_with_source_fallback(self): + graph = _make_graph([], []) + agent = _make_agent() + agent.retrieved_docs = [{"text": "content", "source": "src"}] + engine = WorkflowEngine(graph, agent) + docs, together = engine._get_source_template_data() + assert "src" in together + + +class TestGetExecutionSummary: + + @pytest.mark.unit + def test_returns_log_entries(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + now 
= datetime.now(timezone.utc) + engine.execution_log = [ + { + "node_id": "n1", + "node_type": "start", + "status": "completed", + "started_at": now, + "completed_at": now, + "error": None, + "state_snapshot": {}, + } + ] + summary = engine.get_execution_summary() + assert len(summary) == 1 + assert summary[0].node_id == "n1" + assert summary[0].status == ExecutionStatus.COMPLETED + + @pytest.mark.unit + def test_empty_log(self): + graph = _make_graph([], []) + engine = WorkflowEngine(graph, _make_agent()) + assert engine.get_execution_summary() == [] diff --git a/tests/api/answer/test_stream_processor.py b/tests/api/answer/test_stream_processor.py new file mode 100644 index 00000000..1ba4a130 --- /dev/null +++ b/tests/api/answer/test_stream_processor.py @@ -0,0 +1,331 @@ +"""Tests for application/api/answer/services/stream_processor.py — get_prompt and helpers.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from application.api.answer.services.stream_processor import get_prompt + + +class TestGetPrompt: + + @pytest.mark.unit + def test_default_preset(self): + prompt = get_prompt("default") + assert isinstance(prompt, str) + assert len(prompt) > 0 + + @pytest.mark.unit + def test_creative_preset(self): + prompt = get_prompt("creative") + assert isinstance(prompt, str) + + @pytest.mark.unit + def test_strict_preset(self): + prompt = get_prompt("strict") + assert isinstance(prompt, str) + + @pytest.mark.unit + def test_reduce_preset(self): + prompt = get_prompt("reduce") + assert isinstance(prompt, str) + + @pytest.mark.unit + def test_agentic_default_preset(self): + prompt = get_prompt("agentic_default") + assert isinstance(prompt, str) + + @pytest.mark.unit + def test_agentic_creative_preset(self): + prompt = get_prompt("agentic_creative") + assert isinstance(prompt, str) + + @pytest.mark.unit + def test_agentic_strict_preset(self): + prompt = get_prompt("agentic_strict") + assert isinstance(prompt, str) + + @pytest.mark.unit + def 
test_mongo_prompt_by_id(self): + mock_collection = MagicMock() + mock_collection.find_one.return_value = {"_id": "abc", "content": "Custom prompt"} + prompt = get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection) + assert prompt == "Custom prompt" + + @pytest.mark.unit + def test_mongo_prompt_not_found_raises(self): + mock_collection = MagicMock() + mock_collection.find_one.return_value = None + with pytest.raises(ValueError, match="Invalid prompt ID"): + get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection) + + @pytest.mark.unit + def test_invalid_id_raises(self): + mock_collection = MagicMock() + mock_collection.find_one.side_effect = Exception("bad id") + with pytest.raises(ValueError, match="Invalid prompt ID"): + get_prompt("not-an-objectid", prompts_collection=mock_collection) + + @pytest.mark.unit + def test_mongo_fallback_when_no_collection(self): + """When no collection passed, it reads from MongoDB.""" + mock_collection = MagicMock() + mock_collection.find_one.return_value = {"content": "From DB"} + mock_db = MagicMock() + mock_db.__getitem__ = MagicMock(return_value=mock_collection) + + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "test_db" + MockMongo.get_client.return_value = {"test_db": mock_db} + prompt = get_prompt("507f1f77bcf86cd799439011") + assert prompt == "From DB" + + +class TestStreamProcessorInit: + + @pytest.mark.unit + def test_init_sets_attributes(self): + mock_db = MagicMock() + mock_client = {"docsgpt": mock_db} + + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = mock_client + + from application.api.answer.services.stream_processor 
import StreamProcessor + sp = StreamProcessor( + request_data={"conversation_id": "conv1", "agent_id": "a1"}, + decoded_token={"sub": "user1"}, + ) + assert sp.conversation_id == "conv1" + assert sp.initial_user_id == "user1" + assert sp.agent_id == "a1" + assert sp.history == [] + assert sp.attachments == [] + + @pytest.mark.unit + def test_init_no_token(self): + mock_db = MagicMock() + mock_client = {"docsgpt": mock_db} + + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = mock_client + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token=None) + assert sp.initial_user_id is None + + +class TestGetAttachmentsContent: + + @pytest.mark.unit + def test_empty_ids_returns_empty(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + result = sp._get_attachments_content([], "u") + assert result == [] + + @pytest.mark.unit + def test_returns_matching_attachments(self): + mock_db = MagicMock() + mock_attachments = MagicMock() + mock_attachments.find_one.return_value = {"_id": "att1", "content": "data"} + mock_db.__getitem__ = MagicMock(return_value=mock_attachments) + + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + 
MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + result = sp._get_attachments_content(["507f1f77bcf86cd799439011"], "u") + assert len(result) == 1 + + @pytest.mark.unit + def test_invalid_attachment_id_continues(self): + mock_db = MagicMock() + mock_attachments = MagicMock() + mock_attachments.find_one.side_effect = Exception("bad id") + mock_db.__getitem__ = MagicMock(return_value=mock_attachments) + + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + result = sp._get_attachments_content(["bad"], "u") + assert result == [] + + +class TestResolveAgentId: + + @pytest.mark.unit + def test_from_request_data(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor( + request_data={"agent_id": "agent_123"}, + decoded_token={"sub": "u"}, + ) + assert sp._resolve_agent_id() == "agent_123" + + @pytest.mark.unit + def test_no_agent_no_conversation(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + 
MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + assert sp._resolve_agent_id() is None + + @pytest.mark.unit + def test_from_conversation(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor( + request_data={"conversation_id": "conv1"}, + decoded_token={"sub": "u"}, + ) + sp.conversation_service = MagicMock() + sp.conversation_service.get_conversation.return_value = {"agent_id": "from_conv"} + assert sp._resolve_agent_id() == "from_conv" + + @pytest.mark.unit + def test_conversation_not_found(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor( + request_data={"conversation_id": "conv1"}, + decoded_token={"sub": "u"}, + ) + sp.conversation_service = MagicMock() + sp.conversation_service.get_conversation.return_value = None + assert sp._resolve_agent_id() is None + + @pytest.mark.unit + def test_conversation_lookup_exception(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value 
= {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor( + request_data={"conversation_id": "conv1"}, + decoded_token={"sub": "u"}, + ) + sp.conversation_service = MagicMock() + sp.conversation_service.get_conversation.side_effect = Exception("db error") + assert sp._resolve_agent_id() is None + + +class TestGetPromptContent: + + @pytest.mark.unit + def test_caches_result(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + sp.agent_config = {"prompt_id": "default"} + result1 = sp._get_prompt_content() + result2 = sp._get_prompt_content() + assert result1 == result2 + assert result1 is not None + + @pytest.mark.unit + def test_no_prompt_id(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + sp.agent_config = {} + assert sp._get_prompt_content() is None + + @pytest.mark.unit + def test_invalid_prompt_id_returns_none(self): + mock_db = MagicMock() + mock_prompts = MagicMock() + mock_prompts.find_one.side_effect = Exception("bad") + mock_db.__getitem__ = MagicMock(return_value=mock_prompts) + + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + 
patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + sp.agent_config = {"prompt_id": "bad_id"} + assert sp._get_prompt_content() is None + + +class TestGetRequiredToolActions: + + @pytest.mark.unit + def test_no_prompt_returns_none(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + sp.agent_config = {} + assert sp._get_required_tool_actions() is None + + @pytest.mark.unit + def test_no_template_syntax_returns_empty(self): + mock_db = MagicMock() + with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \ + patch("application.api.answer.services.stream_processor.settings") as mock_settings: + mock_settings.MONGO_DB_NAME = "docsgpt" + MockMongo.get_client.return_value = {"docsgpt": mock_db} + + from application.api.answer.services.stream_processor import StreamProcessor + sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"}) + sp.agent_config = {"prompt_id": "default"} + sp._prompt_content = "No template syntax here" + result = sp._get_required_tool_actions() + assert result == {} diff --git a/tests/api/test_connector_routes.py b/tests/api/test_connector_routes.py new file mode 100644 index 00000000..9a3ccbbb --- /dev/null +++ b/tests/api/test_connector_routes.py @@ -0,0 +1,333 @@ +"""Tests for application/api/connector/routes.py""" + +import json 
+from unittest.mock import MagicMock, patch + +import mongomock +import pytest + + +@pytest.fixture +def app(): + with patch("application.app.handle_auth", return_value={"sub": "test_user"}): + from application.app import app as flask_app + flask_app.config["TESTING"] = True + yield flask_app + + +@pytest.fixture +def client(app): + return app.test_client() + + +@pytest.fixture +def mock_sessions(monkeypatch): + mock_client = mongomock.MongoClient() + mock_db = mock_client["docsgpt"] + sessions = mock_db["connector_sessions"] + sources = mock_db["sources"] + monkeypatch.setattr("application.api.connector.routes.sessions_collection", sessions) + monkeypatch.setattr("application.api.connector.routes.sources_collection", sources) + return {"sessions": sessions, "sources": sources} + + +class TestConnectorAuth: + + @pytest.mark.unit + def test_missing_provider(self, client): + resp = client.get("/api/connectors/auth") + assert resp.status_code == 400 + + @pytest.mark.unit + def test_unsupported_provider(self, client): + resp = client.get("/api/connectors/auth?provider=dropbox") + assert resp.status_code == 400 + + @pytest.mark.unit + def test_unauthorized(self, client, app): + with patch("application.app.handle_auth", return_value=None): + resp = client.get("/api/connectors/auth?provider=google_drive") + data = json.loads(resp.data) + # decoded_token is None -> 401 + assert resp.status_code == 401 or data.get("error") == "Unauthorized" + + @pytest.mark.unit + def test_success(self, client, mock_sessions): + with patch("application.api.connector.routes.ConnectorCreator") as MockCC: + MockCC.is_supported.return_value = True + mock_auth = MagicMock() + mock_auth.get_authorization_url.return_value = "https://oauth.example.com/auth" + MockCC.create_auth.return_value = mock_auth + + resp = client.get("/api/connectors/auth?provider=google_drive") + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data["success"] is True + assert "authorization_url" in 
data + + @pytest.mark.unit + def test_exception_returns_500(self, client, mock_sessions): + with patch("application.api.connector.routes.ConnectorCreator") as MockCC: + MockCC.is_supported.return_value = True + MockCC.create_auth.side_effect = Exception("oauth fail") + resp = client.get("/api/connectors/auth?provider=google_drive") + assert resp.status_code == 500 + + +class TestConnectorFiles: + + @pytest.mark.unit + def test_missing_params(self, client): + resp = client.post("/api/connectors/files", json={"provider": "google_drive"}) + assert resp.status_code == 400 + + @pytest.mark.unit + def test_invalid_session(self, client, mock_sessions): + resp = client.post("/api/connectors/files", json={ + "provider": "google_drive", + "session_token": "bad_token", + }) + assert resp.status_code == 401 + + @pytest.mark.unit + def test_success(self, client, mock_sessions): + mock_sessions["sessions"].insert_one({ + "session_token": "valid_tok", + "user": "test_user", + "provider": "google_drive", + }) + + mock_doc = MagicMock() + mock_doc.doc_id = "f1" + mock_doc.extra_info = { + "file_name": "test.pdf", + "mime_type": "application/pdf", + "size": 1024, + "modified_time": "2025-01-01T12:00:00.000Z", + "is_folder": False, + } + mock_loader = MagicMock() + mock_loader.load_data.return_value = [mock_doc] + mock_loader.next_page_token = None + + with patch("application.api.connector.routes.ConnectorCreator") as MockCC: + MockCC.create_connector.return_value = mock_loader + resp = client.post("/api/connectors/files", json={ + "provider": "google_drive", + "session_token": "valid_tok", + }) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data["success"] is True + assert len(data["files"]) == 1 + + @pytest.mark.unit + def test_no_modified_time(self, client, mock_sessions): + mock_sessions["sessions"].insert_one({ + "session_token": "tok2", + "user": "test_user", + "provider": "google_drive", + }) + mock_doc = MagicMock() + mock_doc.doc_id = "f1" + 
mock_doc.extra_info = {"file_name": "test.pdf", "mime_type": "application/pdf"} + mock_loader = MagicMock() + mock_loader.load_data.return_value = [mock_doc] + mock_loader.next_page_token = None + + with patch("application.api.connector.routes.ConnectorCreator") as MockCC: + MockCC.create_connector.return_value = mock_loader + resp = client.post("/api/connectors/files", json={ + "provider": "google_drive", "session_token": "tok2", + }) + assert resp.status_code == 200 + + +class TestConnectorValidateSession: + + @pytest.mark.unit + def test_missing_params(self, client): + resp = client.post("/api/connectors/validate-session", json={"provider": "google_drive"}) + assert resp.status_code == 400 + + @pytest.mark.unit + def test_invalid_session(self, client, mock_sessions): + resp = client.post("/api/connectors/validate-session", json={ + "provider": "google_drive", "session_token": "bad", + }) + assert resp.status_code == 401 + + @pytest.mark.unit + def test_valid_non_expired(self, client, mock_sessions): + mock_sessions["sessions"].insert_one({ + "session_token": "valid", + "user": "test_user", + "provider": "google_drive", + "token_info": {"access_token": "at", "refresh_token": "rt", "expiry": None}, + "user_email": "user@example.com", + }) + with patch("application.api.connector.routes.ConnectorCreator") as MockCC: + mock_auth = MagicMock() + mock_auth.is_token_expired.return_value = False + MockCC.create_auth.return_value = mock_auth + resp = client.post("/api/connectors/validate-session", json={ + "provider": "google_drive", "session_token": "valid", + }) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data["success"] is True + assert data["expired"] is False + + @pytest.mark.unit + def test_expired_with_refresh(self, client, mock_sessions): + mock_sessions["sessions"].insert_one({ + "session_token": "expired_tok", + "user": "test_user", + "provider": "google_drive", + "token_info": {"access_token": "old_at", "refresh_token": "rt", 
"expiry": 100}, + }) + with patch("application.api.connector.routes.ConnectorCreator") as MockCC: + mock_auth = MagicMock() + mock_auth.is_token_expired.return_value = True + mock_auth.refresh_access_token.return_value = {"access_token": "new_at", "refresh_token": "rt"} + mock_auth.sanitize_token_info.return_value = {"access_token": "new_at", "refresh_token": "rt"} + MockCC.create_auth.return_value = mock_auth + resp = client.post("/api/connectors/validate-session", json={ + "provider": "google_drive", "session_token": "expired_tok", + }) + assert resp.status_code == 200 + + @pytest.mark.unit + def test_expired_no_refresh(self, client, mock_sessions): + mock_sessions["sessions"].insert_one({ + "session_token": "exp_no_ref", + "user": "test_user", + "token_info": {"access_token": "at", "expiry": 100}, + }) + with patch("application.api.connector.routes.ConnectorCreator") as MockCC: + mock_auth = MagicMock() + mock_auth.is_token_expired.return_value = True + MockCC.create_auth.return_value = mock_auth + resp = client.post("/api/connectors/validate-session", json={ + "provider": "google_drive", "session_token": "exp_no_ref", + }) + assert resp.status_code == 401 + + +class TestConnectorDisconnect: + + @pytest.mark.unit + def test_missing_provider(self, client): + resp = client.post("/api/connectors/disconnect", json={}) + assert resp.status_code == 400 + + @pytest.mark.unit + def test_success_with_session(self, client, mock_sessions): + mock_sessions["sessions"].insert_one({"session_token": "del_me", "provider": "google_drive"}) + resp = client.post("/api/connectors/disconnect", json={ + "provider": "google_drive", "session_token": "del_me", + }) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data["success"] is True + + @pytest.mark.unit + def test_success_without_session(self, client, mock_sessions): + resp = client.post("/api/connectors/disconnect", json={"provider": "google_drive"}) + assert resp.status_code == 200 + + +class 
TestConnectorSync: + + @pytest.mark.unit + def test_missing_params(self, client, mock_sessions): + resp = client.post("/api/connectors/sync", json={"source_id": "abc"}) + assert resp.status_code == 400 + + @pytest.mark.unit + def test_source_not_found(self, client, mock_sessions): + from bson.objectid import ObjectId + resp = client.post("/api/connectors/sync", json={ + "source_id": str(ObjectId()), "session_token": "tok", + }) + assert resp.status_code == 404 + + @pytest.mark.unit + def test_unauthorized_source(self, client, mock_sessions): + sid = mock_sessions["sources"].insert_one({"user": "other_user", "name": "src"}).inserted_id + resp = client.post("/api/connectors/sync", json={ + "source_id": str(sid), "session_token": "tok", + }) + assert resp.status_code == 403 + + @pytest.mark.unit + def test_missing_provider_in_remote_data(self, client, mock_sessions): + sid = mock_sessions["sources"].insert_one({ + "user": "test_user", "name": "src", "remote_data": json.dumps({}), + }).inserted_id + resp = client.post("/api/connectors/sync", json={ + "source_id": str(sid), "session_token": "tok", + }) + assert resp.status_code == 400 + + @pytest.mark.unit + def test_success(self, client, mock_sessions): + sid = mock_sessions["sources"].insert_one({ + "user": "test_user", + "name": "src", + "remote_data": json.dumps({"provider": "google_drive", "file_ids": ["f1"]}), + }).inserted_id + mock_task = MagicMock() + mock_task.id = "task_123" + with patch("application.api.connector.routes.ingest_connector_task") as mock_ingest: + mock_ingest.delay.return_value = mock_task + resp = client.post("/api/connectors/sync", json={ + "source_id": str(sid), "session_token": "tok", + }) + assert resp.status_code == 200 + data = json.loads(resp.data) + assert data["task_id"] == "task_123" + + +class TestConnectorCallbackStatus: + + @pytest.mark.unit + def test_success_status(self, client): + resp = 
client.get("/api/connectors/callback-status?status=success&message=OK&provider=google_drive&session_token=tok&user_email=u@e.com") + assert resp.status_code == 200 + assert b"success" in resp.data + + @pytest.mark.unit + def test_error_status(self, client): + resp = client.get("/api/connectors/callback-status?status=error&message=Failed") + assert resp.status_code == 200 + assert b"error" in resp.data + + @pytest.mark.unit + def test_cancelled_status(self, client): + resp = client.get("/api/connectors/callback-status?status=cancelled&message=Cancelled&provider=google_drive") + assert resp.status_code == 200 + assert b"cancelled" in resp.data + + @pytest.mark.unit + def test_unknown_status_defaults_to_error(self, client): + resp = client.get("/api/connectors/callback-status?status=badvalue") + assert resp.status_code == 200 + assert b"error" in resp.data + + @pytest.mark.unit + def test_html_escaping(self, client): + resp = client.get('/api/connectors/callback-status?status=error&message=<script>alert(1)</script>') + assert resp.status_code == 200 + assert b"<script>" not in resp.data  # raw "<script>" payload must be HTML-escaped in the page + + +class TestBuildCallbackRedirect: + + @pytest.mark.unit + def test_builds_url(self): + from application.api.connector.routes import build_callback_redirect + url = build_callback_redirect({"status": "success", "message": "OK"}) + assert url.startswith("/api/connectors/callback-status?") + assert "status=success" in url diff --git a/tests/core/test_model_settings.py b/tests/core/test_model_settings.py new file mode 100644 index 00000000..7257df4a --- /dev/null +++ b/tests/core/test_model_settings.py @@ -0,0 +1,339 @@ +"""Tests for application/core/model_settings.py""" + +from unittest.mock import MagicMock, patch + +import pytest + +from application.core.model_settings import ( + AvailableModel, + ModelCapabilities, + ModelProvider, + ModelRegistry, +) + + +class TestModelProvider: + + @pytest.mark.unit + def test_all_providers_exist(self): + assert ModelProvider.OPENAI == "openai" + assert ModelProvider.ANTHROPIC 
== "anthropic" + assert ModelProvider.GOOGLE == "google" + assert ModelProvider.GROQ == "groq" + assert ModelProvider.DOCSGPT == "docsgpt" + assert ModelProvider.HUGGINGFACE == "huggingface" + assert ModelProvider.NOVITA == "novita" + assert ModelProvider.OPENROUTER == "openrouter" + assert ModelProvider.SAGEMAKER == "sagemaker" + assert ModelProvider.PREMAI == "premai" + assert ModelProvider.LLAMA_CPP == "llama.cpp" + assert ModelProvider.AZURE_OPENAI == "azure_openai" + + +class TestModelCapabilities: + + @pytest.mark.unit + def test_defaults(self): + caps = ModelCapabilities() + assert caps.supports_tools is False + assert caps.supports_structured_output is False + assert caps.supports_streaming is True + assert caps.supported_attachment_types == [] + assert caps.context_window == 128000 + assert caps.input_cost_per_token is None + assert caps.output_cost_per_token is None + + @pytest.mark.unit + def test_custom_values(self): + caps = ModelCapabilities( + supports_tools=True, + supports_structured_output=True, + context_window=32000, + input_cost_per_token=0.001, + ) + assert caps.supports_tools is True + assert caps.context_window == 32000 + + +class TestAvailableModel: + + @pytest.mark.unit + def test_to_dict_basic(self): + model = AvailableModel( + id="gpt-4", + provider=ModelProvider.OPENAI, + display_name="GPT-4", + description="OpenAI GPT-4", + ) + d = model.to_dict() + assert d["id"] == "gpt-4" + assert d["provider"] == "openai" + assert d["display_name"] == "GPT-4" + assert d["enabled"] is True + assert "base_url" not in d + + @pytest.mark.unit + def test_to_dict_with_base_url(self): + model = AvailableModel( + id="local-model", + provider=ModelProvider.OPENAI, + display_name="Local", + base_url="http://localhost:11434", + ) + d = model.to_dict() + assert d["base_url"] == "http://localhost:11434" + + @pytest.mark.unit + def test_to_dict_includes_capabilities(self): + caps = ModelCapabilities(supports_tools=True, context_window=64000) + model = 
AvailableModel( + id="m1", + provider=ModelProvider.ANTHROPIC, + display_name="M1", + capabilities=caps, + ) + d = model.to_dict() + assert d["supports_tools"] is True + assert d["context_window"] == 64000 + + +class TestModelRegistry: + + @pytest.fixture(autouse=True) + def _reset_singleton(self): + """Reset singleton between tests.""" + ModelRegistry._instance = None + ModelRegistry._initialized = False + yield + ModelRegistry._instance = None + ModelRegistry._initialized = False + + @pytest.mark.unit + def test_singleton(self): + with patch.object(ModelRegistry, "_load_models"): + r1 = ModelRegistry() + r2 = ModelRegistry() + assert r1 is r2 + + @pytest.mark.unit + def test_get_instance(self): + with patch.object(ModelRegistry, "_load_models"): + r = ModelRegistry.get_instance() + assert isinstance(r, ModelRegistry) + + @pytest.mark.unit + def test_get_model(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test") + reg.models["test"] = model + assert reg.get_model("test") is model + assert reg.get_model("nonexistent") is None + + @pytest.mark.unit + def test_get_all_models(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1") + reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2") + assert len(reg.get_all_models()) == 2 + + @pytest.mark.unit + def test_get_enabled_models(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True) + reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False) + enabled = reg.get_enabled_models() + assert len(enabled) == 1 + assert enabled[0].id == "m1" + + @pytest.mark.unit + 
def test_model_exists(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1") + assert reg.model_exists("m1") is True + assert reg.model_exists("m2") is False + + @pytest.mark.unit + def test_parse_model_names(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + assert reg._parse_model_names("model1,model2") == ["model1", "model2"] + assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"] + assert reg._parse_model_names("single") == ["single"] + assert reg._parse_model_names("") == [] + assert reg._parse_model_names(None) == [] + + @pytest.mark.unit + def test_add_docsgpt_models(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + reg._add_docsgpt_models(mock_settings) + assert "docsgpt-local" in reg.models + + @pytest.mark.unit + def test_add_huggingface_models(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + reg._add_huggingface_models(mock_settings) + assert "huggingface-local" in reg.models + + @pytest.mark.unit + def test_load_models_with_openai_key(self): + mock_settings = MagicMock() + mock_settings.OPENAI_BASE_URL = None + mock_settings.OPENAI_API_KEY = "sk-test" + mock_settings.OPENAI_API_BASE = None + mock_settings.ANTHROPIC_API_KEY = None + mock_settings.GOOGLE_API_KEY = None + mock_settings.GROQ_API_KEY = None + mock_settings.OPEN_ROUTER_API_KEY = None + mock_settings.NOVITA_API_KEY = None + mock_settings.HUGGINGFACE_API_KEY = None + mock_settings.LLM_PROVIDER = "openai" + mock_settings.LLM_NAME = "" + mock_settings.API_KEY = None + + with patch("application.core.settings.settings", mock_settings): + reg = ModelRegistry() + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_load_models_custom_openai_base_url(self): 
+ mock_settings = MagicMock() + mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1" + mock_settings.OPENAI_API_KEY = "sk-test" + mock_settings.OPENAI_API_BASE = None + mock_settings.ANTHROPIC_API_KEY = None + mock_settings.GOOGLE_API_KEY = None + mock_settings.GROQ_API_KEY = None + mock_settings.OPEN_ROUTER_API_KEY = None + mock_settings.NOVITA_API_KEY = None + mock_settings.HUGGINGFACE_API_KEY = None + mock_settings.LLM_PROVIDER = "openai" + mock_settings.LLM_NAME = "llama3,gemma" + mock_settings.API_KEY = None + + with patch("application.core.settings.settings", mock_settings): + reg = ModelRegistry() + assert "llama3" in reg.models + assert "gemma" in reg.models + + @pytest.mark.unit + def test_default_model_selection_from_llm_name(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")} + reg.default_model_id = "gpt-4" + assert reg.default_model_id == "gpt-4" + + @pytest.mark.unit + def test_add_anthropic_models_with_key(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + mock_settings.ANTHROPIC_API_KEY = "sk-ant-test" + mock_settings.LLM_PROVIDER = "" + mock_settings.LLM_NAME = "" + reg._add_anthropic_models(mock_settings) + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_add_google_models_with_key(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + mock_settings.GOOGLE_API_KEY = "google-test" + mock_settings.LLM_PROVIDER = "" + mock_settings.LLM_NAME = "" + reg._add_google_models(mock_settings) + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_add_groq_models_with_key(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + mock_settings.GROQ_API_KEY = "groq-test" 
+ mock_settings.LLM_PROVIDER = "" + mock_settings.LLM_NAME = "" + reg._add_groq_models(mock_settings) + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_add_openrouter_models_with_key(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + mock_settings.OPEN_ROUTER_API_KEY = "or-test" + mock_settings.LLM_PROVIDER = "" + mock_settings.LLM_NAME = "" + reg._add_openrouter_models(mock_settings) + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_add_novita_models_with_key(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + mock_settings.NOVITA_API_KEY = "novita-test" + mock_settings.LLM_PROVIDER = "" + mock_settings.LLM_NAME = "" + reg._add_novita_models(mock_settings) + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_add_azure_openai_models_specific(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + mock_settings.LLM_PROVIDER = "azure_openai" + mock_settings.LLM_NAME = "nonexistent-model" + reg._add_azure_openai_models(mock_settings) + # Falls through to adding all azure models + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_add_anthropic_models_no_key_with_provider(self): + with patch.object(ModelRegistry, "_load_models"): + reg = ModelRegistry() + reg.models = {} + mock_settings = MagicMock() + mock_settings.ANTHROPIC_API_KEY = None + mock_settings.LLM_PROVIDER = "anthropic" + mock_settings.LLM_NAME = "nonexistent" + reg._add_anthropic_models(mock_settings) + assert len(reg.models) > 0 + + @pytest.mark.unit + def test_default_model_fallback_to_first(self): + mock_settings = MagicMock() + mock_settings.OPENAI_BASE_URL = None + mock_settings.OPENAI_API_KEY = None + mock_settings.OPENAI_API_BASE = None + mock_settings.ANTHROPIC_API_KEY = None + mock_settings.GOOGLE_API_KEY = None + 
mock_settings.GROQ_API_KEY = None + mock_settings.OPEN_ROUTER_API_KEY = None + mock_settings.NOVITA_API_KEY = None + mock_settings.HUGGINGFACE_API_KEY = None + mock_settings.LLM_PROVIDER = "" + mock_settings.LLM_NAME = "" + mock_settings.API_KEY = None + + with patch("application.core.settings.settings", mock_settings): + reg = ModelRegistry() + # Should have at least docsgpt-local + assert reg.default_model_id is not None diff --git a/tests/test_app_routes.py b/tests/test_app_routes.py new file mode 100644 index 00000000..eb5fe26f --- /dev/null +++ b/tests/test_app_routes.py @@ -0,0 +1,115 @@ +"""Tests for application/app.py route handlers.""" + +import json +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def app(): + """Import the Flask app with auth mocked to avoid JWT setup issues.""" + with patch("application.app.handle_auth", return_value={"sub": "test_user"}): + from application.app import app as flask_app + flask_app.config["TESTING"] = True + yield flask_app + + +@pytest.fixture +def client(app): + return app.test_client() + + +class TestHomeRoute: + + @pytest.mark.unit + def test_root_returns_200(self, client): + """Root serves Swagger UI via Flask-RESTX.""" + response = client.get("/") + assert response.status_code == 200 + + +class TestHealthRoute: + + @pytest.mark.unit + def test_returns_ok(self, client): + response = client.get("/api/health") + assert response.status_code == 200 + data = json.loads(response.data) + assert data["status"] == "ok" + + +class TestConfigRoute: + + @pytest.mark.unit + def test_returns_auth_config(self, client): + response = client.get("/api/config") + assert response.status_code == 200 + data = json.loads(response.data) + assert "auth_type" in data + assert "requires_auth" in data + + +class TestGenerateTokenRoute: + + @pytest.mark.unit + def test_session_jwt_generates_token(self, client, app): + with patch("application.app.settings") as mock_settings: + mock_settings.AUTH_TYPE = "session_jwt" + 
mock_settings.JWT_SECRET_KEY = "test_secret" + response = client.get("/api/generate_token") + assert response.status_code == 200 + data = json.loads(response.data) + assert "token" in data + + @pytest.mark.unit + def test_non_session_jwt_returns_error(self, client, app): + with patch("application.app.settings") as mock_settings: + mock_settings.AUTH_TYPE = "none" + response = client.get("/api/generate_token") + assert response.status_code == 400 + + +class TestSttRequestSizeLimits: + + @pytest.mark.unit + def test_non_stt_request_passes(self, client): + response = client.get("/api/health") + assert response.status_code == 200 + + @pytest.mark.unit + def test_oversized_stt_request_rejected(self, client): + with patch("application.app.should_reject_stt_request", return_value=True), \ + patch("application.app.build_stt_file_size_limit_message", return_value="Too large"): + response = client.post("/api/stt/upload", data=b"x" * 100) + assert response.status_code == 413 + + +class TestAuthenticateRequest: + + @pytest.mark.unit + def test_options_returns_200(self, client): + response = client.options("/api/health") + assert response.status_code == 200 + + @pytest.mark.unit + def test_auth_error_returns_401(self, client, app): + with patch("application.app.handle_auth", return_value={"error": "Invalid token"}): + response = client.get("/api/health") + assert response.status_code == 401 + + @pytest.mark.unit + def test_no_token_sets_none(self, client, app): + with patch("application.app.handle_auth", return_value=None): + response = client.get("/api/health") + assert response.status_code == 200 + + +class TestAfterRequest: + + @pytest.mark.unit + def test_cors_headers(self, client): + response = client.get("/api/health") + assert response.headers.get("Access-Control-Allow-Origin") == "*" + assert "Content-Type" in response.headers.get("Access-Control-Allow-Headers", "") + assert "GET" in response.headers.get("Access-Control-Allow-Methods", "") diff --git 
a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..ed1a4977 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,533 @@ +"""Tests for application/utils.py""" + +from unittest.mock import MagicMock, patch + +import pytest + +from application.utils import ( + calculate_compression_threshold, + calculate_doc_token_budget, + check_required_fields, + clean_text_for_tts, + convert_pdf_to_images, + get_encoding, + get_field_validation_errors, + get_gpt_model, + get_hash, + get_missing_fields, + generate_image_url, + limit_chat_history, + num_tokens_from_object_or_list, + num_tokens_from_string, + safe_filename, + validate_function_name, + validate_required_fields, +) + + +class TestGetEncoding: + + @pytest.mark.unit + def test_returns_encoding(self): + enc = get_encoding() + assert enc is not None + + @pytest.mark.unit + def test_returns_same_instance(self): + enc1 = get_encoding() + enc2 = get_encoding() + assert enc1 is enc2 + + +class TestGetGptModel: + + @pytest.mark.unit + def test_returns_llm_name_when_set(self): + with patch("application.utils.settings") as s: + s.LLM_NAME = "my-model" + s.LLM_PROVIDER = "openai" + assert get_gpt_model() == "my-model" + + @pytest.mark.unit + def test_falls_back_to_provider_map(self): + with patch("application.utils.settings") as s: + s.LLM_NAME = "" + s.LLM_PROVIDER = "openai" + assert get_gpt_model() == "gpt-4o-mini" + + @pytest.mark.unit + def test_unknown_provider_returns_empty(self): + with patch("application.utils.settings") as s: + s.LLM_NAME = "" + s.LLM_PROVIDER = "unknown" + assert get_gpt_model() == "" + + +class TestSafeFilename: + + @pytest.mark.unit + def test_normal_filename(self): + assert safe_filename("test.pdf") == "test.pdf" + + @pytest.mark.unit + def test_empty_filename_returns_uuid(self): + result = safe_filename("") + assert len(result) > 10 # UUID + + @pytest.mark.unit + def test_none_filename_returns_uuid(self): + result = safe_filename(None) + assert len(result) > 10 + + 
@pytest.mark.unit + def test_non_latin_filename(self): + result = safe_filename("документ.pdf") + assert result.endswith(".pdf") + + +class TestNumTokens: + + @pytest.mark.unit + def test_string_token_count(self): + count = num_tokens_from_string("hello world") + assert count > 0 + + @pytest.mark.unit + def test_non_string_returns_zero(self): + assert num_tokens_from_string(123) == 0 + + @pytest.mark.unit + def test_empty_string(self): + assert num_tokens_from_string("") == 0 + + +class TestNumTokensFromObjectOrList: + + @pytest.mark.unit + def test_list(self): + result = num_tokens_from_object_or_list(["hello", "world"]) + assert result > 0 + + @pytest.mark.unit + def test_dict(self): + result = num_tokens_from_object_or_list({"key": "value"}) + assert result > 0 + + @pytest.mark.unit + def test_string(self): + result = num_tokens_from_object_or_list("hello") + assert result > 0 + + @pytest.mark.unit + def test_number_returns_zero(self): + assert num_tokens_from_object_or_list(42) == 0 + + @pytest.mark.unit + def test_nested(self): + result = num_tokens_from_object_or_list({"a": ["b", "c"]}) + assert result > 0 + + +class TestCountTokensDocs: + + @pytest.mark.unit + def test_counts_doc_tokens(self): + from application.utils import count_tokens_docs + doc1 = MagicMock() + doc1.page_content = "hello world" + doc2 = MagicMock() + doc2.page_content = " foo bar" + result = count_tokens_docs([doc1, doc2]) + assert result > 0 + + +class TestCalculateDocTokenBudget: + + @pytest.mark.unit + def test_returns_budget(self): + with patch("application.utils.get_token_limit", return_value=128000), \ + patch("application.utils.settings") as s: + s.RESERVED_TOKENS = {"system": 500, "history": 500} + result = calculate_doc_token_budget("gpt-4o") + assert result == 127000 + + @pytest.mark.unit + def test_minimum_budget(self): + with patch("application.utils.get_token_limit", return_value=1000), \ + patch("application.utils.settings") as s: + s.RESERVED_TOKENS = {"system": 500, 
"history": 500} + result = calculate_doc_token_budget("small-model") + assert result == 1000 + + +class TestFieldValidation: + + @pytest.mark.unit + def test_get_missing_fields(self): + assert get_missing_fields({"a": 1}, ["a", "b"]) == ["b"] + assert get_missing_fields({"a": 1, "b": 2}, ["a", "b"]) == [] + + @pytest.mark.unit + def test_check_required_fields_pass(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + result = check_required_fields({"a": 1, "b": 2}, ["a", "b"]) + assert result is None + + @pytest.mark.unit + def test_check_required_fields_fail(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + result = check_required_fields({"a": 1}, ["a", "b"]) + assert result is not None + assert result.status_code == 400 + + @pytest.mark.unit + def test_get_field_validation_errors_none_when_valid(self): + assert get_field_validation_errors({"a": 1}, ["a"]) is None + + @pytest.mark.unit + def test_get_field_validation_errors_missing(self): + result = get_field_validation_errors({}, ["a"]) + assert result["missing_fields"] == ["a"] + + @pytest.mark.unit + def test_get_field_validation_errors_empty(self): + result = get_field_validation_errors({"a": ""}, ["a"]) + assert result["empty_fields"] == ["a"] + + @pytest.mark.unit + def test_validate_required_fields_pass(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + result = validate_required_fields({"a": "v"}, ["a"]) + assert result is None + + @pytest.mark.unit + def test_validate_required_fields_missing(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + result = validate_required_fields({}, ["a"]) + assert result is not None + assert result.status_code == 400 + + @pytest.mark.unit + def test_validate_required_fields_empty(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + result = validate_required_fields({"a": ""}, ["a"]) + assert result is not None + 
+ @pytest.mark.unit + def test_validate_required_fields_both_missing_and_empty(self): + from flask import Flask + app = Flask(__name__) + with app.app_context(): + result = validate_required_fields({"a": ""}, ["a", "b"]) + assert result is not None + + +class TestGetHash: + + @pytest.mark.unit + def test_returns_hex_string(self): + h = get_hash("test") + assert len(h) == 32 + assert all(c in "0123456789abcdef" for c in h) + + @pytest.mark.unit + def test_deterministic(self): + assert get_hash("hello") == get_hash("hello") + + @pytest.mark.unit + def test_different_inputs(self): + assert get_hash("a") != get_hash("b") + + +class TestLimitChatHistory: + + @pytest.mark.unit + def test_empty_history(self): + assert limit_chat_history([]) == [] + + @pytest.mark.unit + def test_none_history(self): + assert limit_chat_history(None) == [] + + @pytest.mark.unit + def test_keeps_recent_messages(self): + history = [ + {"prompt": "q1", "response": "a1"}, + {"prompt": "q2", "response": "a2"}, + ] + result = limit_chat_history(history, max_token_limit=10000) + assert len(result) == 2 + + @pytest.mark.unit + def test_trims_old_messages(self): + history = [ + {"prompt": "x" * 5000, "response": "y" * 5000}, + {"prompt": "q", "response": "a"}, + ] + result = limit_chat_history(history, max_token_limit=100) + assert len(result) <= 2 + + @pytest.mark.unit + def test_handles_tool_calls(self): + history = [ + { + "prompt": "q", + "response": "a", + "tool_calls": [ + {"tool_name": "t", "action_name": "a", "arguments": "{}", "result": "r"} + ], + } + ] + result = limit_chat_history(history, max_token_limit=10000) + assert len(result) == 1 + + +class TestValidateFunctionName: + + @pytest.mark.unit + def test_valid_names(self): + assert validate_function_name("hello") is True + assert validate_function_name("hello_world") is True + assert validate_function_name("hello-world") is True + assert validate_function_name("test123") is True + + @pytest.mark.unit + def test_invalid_names(self): + 
assert validate_function_name("hello world") is False + assert validate_function_name("hello!") is False + assert validate_function_name("") is False + + +class TestGenerateImageUrl: + + @pytest.mark.unit + def test_http_url_passthrough(self): + assert generate_image_url("https://example.com/img.png") == "https://example.com/img.png" + assert generate_image_url("http://example.com/img.png") == "http://example.com/img.png" + + @pytest.mark.unit + def test_s3_strategy(self): + with patch("application.utils.settings") as s: + s.URL_STRATEGY = "s3" + s.S3_BUCKET_NAME = "my-bucket" + s.SAGEMAKER_REGION = "us-west-2" + result = generate_image_url("path/to/img.png") + assert "my-bucket.s3.us-west-2" in result + + @pytest.mark.unit + def test_backend_strategy(self): + with patch("application.utils.settings") as s: + s.URL_STRATEGY = "backend" + s.API_URL = "http://localhost:7091" + result = generate_image_url("path/to/img.png") + assert result == "http://localhost:7091/api/images/path/to/img.png" + + +class TestCalculateCompressionThreshold: + + @pytest.mark.unit + def test_default_threshold(self): + with patch("application.utils.get_token_limit", return_value=100000): + result = calculate_compression_threshold("gpt-4o") + assert result == 80000 + + @pytest.mark.unit + def test_custom_percentage(self): + with patch("application.utils.get_token_limit", return_value=100000): + result = calculate_compression_threshold("gpt-4o", 0.5) + assert result == 50000 + + +class TestConvertPdfToImages: + + @pytest.mark.unit + def test_missing_pdf2image_raises(self): + with patch.dict("sys.modules", {"pdf2image": None}): + # Force re-import to trigger ImportError + # The function handles the import internally + with pytest.raises(ImportError, match="pdf2image"): + convert_pdf_to_images("test.pdf") + + @pytest.mark.unit + def test_converts_from_path(self): + mock_image = MagicMock() + mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"PNG_DATA")) + + mock_module = 
MagicMock() + mock_module.convert_from_path.return_value = [mock_image] + mock_module.convert_from_bytes.return_value = [mock_image] + + original_import = __import__ + + def patched_import(name, *args, **kwargs): + if name == "pdf2image": + return mock_module + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=patched_import): + result = convert_pdf_to_images("/some/file.pdf") + assert len(result) == 1 + assert result[0]["mime_type"] == "image/png" + assert result[0]["page"] == 1 + + @pytest.mark.unit + def test_with_storage(self): + mock_image = MagicMock() + mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"IMG")) + + mock_storage = MagicMock() + mock_file = MagicMock() + mock_file.read.return_value = b"pdf_bytes" + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_storage.get_file.return_value = mock_file + + mock_module = MagicMock() + mock_module.convert_from_bytes.return_value = [mock_image] + + original_import = __import__ + + def patched_import(name, *args, **kwargs): + if name == "pdf2image": + return mock_module + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=patched_import): + result = convert_pdf_to_images("test.pdf", storage=mock_storage) + assert len(result) == 1 + mock_module.convert_from_bytes.assert_called_once() + + @pytest.mark.unit + def test_file_not_found_raises(self): + mock_module = MagicMock() + mock_module.convert_from_path.side_effect = FileNotFoundError("not found") + + # Patch the import inside the function + original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + + def patched_import(name, *args, **kwargs): + if name == "pdf2image": + return mock_module + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=patched_import): + with pytest.raises(FileNotFoundError): + 
convert_pdf_to_images("/nonexistent.pdf") + + @pytest.mark.unit + def test_generic_error_raises(self): + mock_module = MagicMock() + mock_module.convert_from_path.side_effect = RuntimeError("conversion failed") + + original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + + def patched_import(name, *args, **kwargs): + if name == "pdf2image": + return mock_module + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=patched_import): + with pytest.raises(RuntimeError, match="conversion failed"): + convert_pdf_to_images("/some.pdf") + + +class TestCleanTextForTts: + + @pytest.mark.unit + def test_removes_code_blocks(self): + result = clean_text_for_tts("before ```python\ncode\n``` after") + assert "code block" in result + assert "python" not in result + + @pytest.mark.unit + def test_removes_mermaid_blocks(self): + result = clean_text_for_tts("```mermaid\ngraph TD\n```") + assert "flowchart" in result + + @pytest.mark.unit + def test_removes_markdown_links(self): + result = clean_text_for_tts("[click here](https://example.com)") + assert "click here" in result + assert "https" not in result + + @pytest.mark.unit + def test_removes_images(self): + result = clean_text_for_tts("![alt text](image.png)") + assert "image.png" not in result + + @pytest.mark.unit + def test_removes_inline_code(self): + result = clean_text_for_tts("use `foo()` here") + assert "foo()" in result + assert "`" not in result + + @pytest.mark.unit + def test_removes_bold_italic(self): + result = clean_text_for_tts("**bold** and *italic*") + assert "bold" in result + assert "italic" in result + assert "*" not in result + + @pytest.mark.unit + def test_removes_headers(self): + result = clean_text_for_tts("# Header\ntext") + assert "Header" in result + assert "#" not in result + + @pytest.mark.unit + def test_removes_blockquotes(self): + result = clean_text_for_tts("> quoted text") + assert "quoted text" in result + 
assert ">" not in result + + @pytest.mark.unit + def test_removes_html_tags(self): + result = clean_text_for_tts("<div>content</div>") + assert "content" in result + assert "<" not in result + + @pytest.mark.unit + def test_removes_arrows(self): + result = clean_text_for_tts("a --> b <-- c => d") + assert "-->" not in result + assert "<--" not in result + assert "=>" not in result + + @pytest.mark.unit + def test_removes_horizontal_rules(self): + result = clean_text_for_tts("text\n---\nmore") + assert "---" not in result + + @pytest.mark.unit + def test_removes_list_markers(self): + result = clean_text_for_tts("- item1\n* item2\n1. item3") + assert "item1" in result + assert "item2" in result + assert "item3" in result + + @pytest.mark.unit + def test_normalizes_whitespace(self): + result = clean_text_for_tts("  lots   of    spaces  ") + assert "  " not in result + + @pytest.mark.unit + def test_removes_braces(self): + result = clean_text_for_tts("{content} and [more]") + assert "content" in result + assert "more" in result + assert "{" not in result + + @pytest.mark.unit + def test_removes_double_colons(self): + result = clean_text_for_tts("module::function") + assert "::" not in result