chore: utils tests

This commit is contained in:
Alex
2026-03-29 11:49:35 +01:00
parent e06debad5f
commit 126fa01b14
8 changed files with 2826 additions and 0 deletions

View File

@@ -0,0 +1,169 @@
"""Tests for application/agents/workflows/cel_evaluator.py"""
import pytest
from application.agents.workflows.cel_evaluator import (
CelEvaluationError,
_convert_value,
build_activation,
cel_to_python,
evaluate_cel,
)
import celpy.celtypes
class TestConvertValue:
@pytest.mark.unit
def test_bool_true(self):
result = _convert_value(True)
assert isinstance(result, celpy.celtypes.BoolType)
assert bool(result) is True
@pytest.mark.unit
def test_bool_false(self):
result = _convert_value(False)
assert isinstance(result, celpy.celtypes.BoolType)
assert bool(result) is False
@pytest.mark.unit
def test_int(self):
result = _convert_value(42)
assert isinstance(result, celpy.celtypes.IntType)
assert int(result) == 42
@pytest.mark.unit
def test_float(self):
result = _convert_value(3.14)
assert isinstance(result, celpy.celtypes.DoubleType)
assert float(result) == pytest.approx(3.14)
@pytest.mark.unit
def test_string(self):
result = _convert_value("hello")
assert isinstance(result, celpy.celtypes.StringType)
assert str(result) == "hello"
@pytest.mark.unit
def test_list(self):
result = _convert_value([1, "two", 3.0])
assert isinstance(result, celpy.celtypes.ListType)
@pytest.mark.unit
def test_dict(self):
result = _convert_value({"key": "value"})
assert isinstance(result, celpy.celtypes.MapType)
@pytest.mark.unit
def test_none(self):
result = _convert_value(None)
assert isinstance(result, celpy.celtypes.BoolType)
assert bool(result) is False
@pytest.mark.unit
def test_other_type_converts_to_string(self):
result = _convert_value(object())
assert isinstance(result, celpy.celtypes.StringType)
class TestBuildActivation:
@pytest.mark.unit
def test_converts_dict_values(self):
state = {"name": "Alice", "age": 30, "active": True}
result = build_activation(state)
assert "name" in result
assert "age" in result
assert "active" in result
@pytest.mark.unit
def test_empty_state(self):
assert build_activation({}) == {}
class TestEvaluateCel:
@pytest.mark.unit
def test_simple_comparison(self):
assert evaluate_cel("x > 5", {"x": 10}) is True
assert evaluate_cel("x > 5", {"x": 3}) is False
@pytest.mark.unit
def test_string_comparison(self):
assert evaluate_cel('name == "Alice"', {"name": "Alice"}) is True
assert evaluate_cel('name == "Alice"', {"name": "Bob"}) is False
@pytest.mark.unit
def test_arithmetic(self):
assert evaluate_cel("x + y", {"x": 3, "y": 4}) == 7
@pytest.mark.unit
def test_boolean_logic(self):
assert evaluate_cel("a && b", {"a": True, "b": True}) is True
assert evaluate_cel("a && b", {"a": True, "b": False}) is False
assert evaluate_cel("a || b", {"a": False, "b": True}) is True
@pytest.mark.unit
def test_empty_expression_raises(self):
with pytest.raises(CelEvaluationError, match="Empty expression"):
evaluate_cel("", {})
@pytest.mark.unit
def test_whitespace_expression_raises(self):
with pytest.raises(CelEvaluationError, match="Empty expression"):
evaluate_cel(" ", {})
@pytest.mark.unit
def test_invalid_expression_raises(self):
with pytest.raises(CelEvaluationError):
evaluate_cel("invalid!!!", {})
@pytest.mark.unit
def test_missing_variable_raises(self):
with pytest.raises(CelEvaluationError):
evaluate_cel("undefined_var > 5", {})
class TestCelToPython:
@pytest.mark.unit
def test_bool(self):
result = cel_to_python(celpy.celtypes.BoolType(True))
assert result is True
@pytest.mark.unit
def test_int(self):
result = cel_to_python(celpy.celtypes.IntType(42))
assert result == 42
@pytest.mark.unit
def test_double(self):
result = cel_to_python(celpy.celtypes.DoubleType(3.14))
assert result == pytest.approx(3.14)
@pytest.mark.unit
def test_string(self):
result = cel_to_python(celpy.celtypes.StringType("hello"))
assert result == "hello"
@pytest.mark.unit
def test_list(self):
cel_list = celpy.celtypes.ListType([
celpy.celtypes.IntType(1),
celpy.celtypes.IntType(2),
])
result = cel_to_python(cel_list)
assert result == [1, 2]
@pytest.mark.unit
def test_map(self):
cel_map = celpy.celtypes.MapType({
celpy.celtypes.StringType("key"): celpy.celtypes.StringType("value"),
})
result = cel_to_python(cel_map)
assert result == {"key": "value"}
@pytest.mark.unit
def test_unknown_type_passthrough(self):
result = cel_to_python("raw_value")
assert result == "raw_value"

View File

@@ -0,0 +1,433 @@
"""Tests for WorkflowAgent - covering _parse_embedded_workflow, _load_from_database,
_save_workflow_run, _determine_run_status, _serialize_state, and gen flow."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from application.agents.workflows.schemas import (
ExecutionStatus,
WorkflowGraph,
)
def _make_agent(**overrides):
"""Create a WorkflowAgent with mocked base class dependencies."""
defaults = {
"endpoint": "https://api.example.com",
"llm_name": "openai",
"model_id": "gpt-4",
"api_key": "test_key",
"user_api_key": None,
"prompt": "You are helpful.",
"chat_history": [],
"decoded_token": {"sub": "user1"},
"attachments": [],
"json_schema": None,
}
defaults.update(overrides)
with patch("application.agents.workflow_agent.log_activity", lambda **kw: lambda f: f):
from application.agents.workflow_agent import WorkflowAgent
agent = WorkflowAgent(**defaults)
return agent
class TestWorkflowAgentInit:
@pytest.mark.unit
def test_sets_attributes(self):
agent = _make_agent(workflow_id="wf1", workflow_owner="owner1")
assert agent.workflow_id == "wf1"
assert agent.workflow_owner == "owner1"
assert agent._engine is None
@pytest.mark.unit
def test_embedded_workflow(self):
wf_data = {"nodes": [], "edges": [], "name": "Test"}
agent = _make_agent(workflow=wf_data)
assert agent._workflow_data == wf_data
class TestParseEmbeddedWorkflow:
@pytest.mark.unit
def test_parses_valid_workflow(self):
wf_data = {
"name": "Test Workflow",
"description": "A test",
"nodes": [
{"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}},
{"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}},
],
"edges": [
{"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"},
],
}
agent = _make_agent(workflow=wf_data, workflow_id="wf1")
graph = agent._parse_embedded_workflow()
assert graph is not None
assert len(graph.nodes) == 2
assert len(graph.edges) == 1
assert graph.workflow.name == "Test Workflow"
@pytest.mark.unit
def test_edge_source_id_alias(self):
wf_data = {
"nodes": [{"id": "n1", "type": "start", "data": {}}],
"edges": [{"id": "e1", "source_id": "n1", "target_id": "n2", "source_handle": "out", "target_handle": "in"}],
}
agent = _make_agent(workflow=wf_data)
graph = agent._parse_embedded_workflow()
assert graph is not None
assert graph.edges[0].source_id == "n1"
@pytest.mark.unit
def test_invalid_data_returns_none(self):
agent = _make_agent(workflow={"nodes": [{"bad": "data"}], "edges": []})
graph = agent._parse_embedded_workflow()
assert graph is None
class TestLoadWorkflowGraph:
@pytest.mark.unit
def test_uses_embedded_when_available(self):
agent = _make_agent(workflow={"nodes": [], "edges": [], "name": "E"})
agent._parse_embedded_workflow = MagicMock(return_value="parsed_graph")
result = agent._load_workflow_graph()
assert result == "parsed_graph"
@pytest.mark.unit
def test_uses_database_when_workflow_id(self):
agent = _make_agent(workflow_id="wf1")
agent._load_from_database = MagicMock(return_value="db_graph")
result = agent._load_workflow_graph()
assert result == "db_graph"
@pytest.mark.unit
def test_returns_none_when_nothing(self):
agent = _make_agent()
result = agent._load_workflow_graph()
assert result is None
class TestLoadFromDatabase:
@pytest.mark.unit
def test_invalid_workflow_id_returns_none(self):
agent = _make_agent(workflow_id="invalid!")
result = agent._load_from_database()
assert result is None
@pytest.mark.unit
def test_no_owner_returns_none(self):
agent = _make_agent(workflow_id="507f1f77bcf86cd799439011", decoded_token={})
agent.workflow_owner = None
result = agent._load_from_database()
assert result is None
@pytest.mark.unit
def test_uses_decoded_token_sub(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
decoded_token={"sub": "user1"},
)
agent.workflow_owner = None
mock_collection = MagicMock()
mock_collection.find_one.return_value = None
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is None # workflow_doc not found
@pytest.mark.unit
def test_successful_load(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
mock_wf_coll = MagicMock()
mock_wf_coll.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "Test WF",
"user": "owner1",
"current_graph_version": 1,
}
mock_nodes_coll = MagicMock()
mock_nodes_coll.find.return_value = [
{"id": "n1", "workflow_id": "507f1f77bcf86cd799439011", "type": "start",
"title": "Start", "position": {"x": 0, "y": 0}, "config": {}},
]
mock_edges_coll = MagicMock()
mock_edges_coll.find.return_value = []
def getitem(name):
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(side_effect=getitem)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is not None
assert len(result.nodes) == 1
@pytest.mark.unit
def test_invalid_graph_version(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
mock_wf_coll = MagicMock()
mock_wf_coll.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "WF",
"user": "owner1",
"current_graph_version": "bad",
}
mock_nodes_coll = MagicMock()
mock_nodes_coll.find.return_value = []
mock_edges_coll = MagicMock()
mock_edges_coll.find.return_value = []
def getitem(name):
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(side_effect=getitem)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is not None # Defaults to version 1
@pytest.mark.unit
def test_fallback_nodes_without_version(self):
"""When graph_version=1 finds no nodes, falls back to nodes without version field."""
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
mock_wf_coll = MagicMock()
mock_wf_coll.find_one.return_value = {
"_id": "507f1f77bcf86cd799439011",
"name": "WF",
"user": "owner1",
"current_graph_version": 1,
}
call_count = [0]
def nodes_find(query):
call_count[0] += 1
if call_count[0] == 1:
return [] # No versioned nodes
return [{"id": "n1", "workflow_id": "wf", "type": "start",
"title": "S", "position": {"x": 0, "y": 0}, "config": {}}]
mock_nodes_coll = MagicMock()
mock_nodes_coll.find.side_effect = nodes_find
edge_call = [0]
def edges_find(query):
edge_call[0] += 1
if edge_call[0] == 1:
return []
return []
mock_edges_coll = MagicMock()
mock_edges_coll.find.side_effect = edges_find
def getitem(name):
return {"workflows": mock_wf_coll, "workflow_nodes": mock_nodes_coll, "workflow_edges": mock_edges_coll}[name]
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(side_effect=getitem)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
result = agent._load_from_database()
assert result is not None
assert len(result.nodes) == 1
@pytest.mark.unit
def test_exception_returns_none(self):
agent = _make_agent(
workflow_id="507f1f77bcf86cd799439011",
workflow_owner="owner1",
)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
MockMongo.get_client.side_effect = Exception("db error")
result = agent._load_from_database()
assert result is None
class TestGenInner:
@pytest.mark.unit
def test_no_graph_yields_error(self):
agent = _make_agent()
agent._load_workflow_graph = MagicMock(return_value=None)
events = list(agent._gen_inner("query", None))
assert any(e.get("type") == "error" for e in events)
@pytest.mark.unit
def test_successful_execution(self):
agent = _make_agent(workflow_id="wf1")
mock_graph = MagicMock(spec=WorkflowGraph)
agent._load_workflow_graph = MagicMock(return_value=mock_graph)
agent._save_workflow_run = MagicMock()
mock_engine = MagicMock()
mock_engine.execute.return_value = iter([{"answer": "result"}])
with patch("application.agents.workflow_agent.WorkflowEngine", return_value=mock_engine):
events = list(agent._gen_inner("query", None))
assert len(events) == 1
agent._save_workflow_run.assert_called_once_with("query")
class TestSaveWorkflowRun:
@pytest.mark.unit
def test_no_engine_returns_early(self):
agent = _make_agent()
agent._engine = None
agent._save_workflow_run("query") # Should not raise
@pytest.mark.unit
def test_saves_to_mongo(self):
agent = _make_agent(workflow_id="wf1")
mock_engine = MagicMock()
mock_engine.state = {"query": "test"}
mock_engine.execution_log = []
mock_engine.get_execution_summary.return_value = []
agent._engine = mock_engine
mock_collection = MagicMock()
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
patch("application.agents.workflow_agent.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
agent._save_workflow_run("query")
mock_collection.insert_one.assert_called_once()
@pytest.mark.unit
def test_exception_does_not_propagate(self):
agent = _make_agent(workflow_id="wf1")
mock_engine = MagicMock()
mock_engine.state = {}
mock_engine.execution_log = []
mock_engine.get_execution_summary.return_value = []
agent._engine = mock_engine
with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
MockMongo.get_client.side_effect = Exception("db fail")
agent._save_workflow_run("query") # Should not raise
class TestDetermineRunStatus:
@pytest.mark.unit
def test_no_engine_returns_completed(self):
agent = _make_agent()
agent._engine = None
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
@pytest.mark.unit
def test_empty_log_returns_completed(self):
agent = _make_agent()
agent._engine = MagicMock()
agent._engine.execution_log = []
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
@pytest.mark.unit
def test_failed_log_returns_failed(self):
agent = _make_agent()
agent._engine = MagicMock()
agent._engine.execution_log = [
{"status": "completed"},
{"status": "failed"},
]
assert agent._determine_run_status() == ExecutionStatus.FAILED
@pytest.mark.unit
def test_all_completed_returns_completed(self):
agent = _make_agent()
agent._engine = MagicMock()
agent._engine.execution_log = [
{"status": "completed"},
{"status": "completed"},
]
assert agent._determine_run_status() == ExecutionStatus.COMPLETED
class TestSerializeState:
@pytest.mark.unit
def test_serializes_primitives(self):
agent = _make_agent()
state = {"str": "hello", "int": 42, "float": 3.14, "bool": True, "none": None}
result = agent._serialize_state(state)
assert result == state
@pytest.mark.unit
def test_serializes_nested_dict(self):
agent = _make_agent()
state = {"nested": {"key": "value"}}
result = agent._serialize_state(state)
assert result["nested"]["key"] == "value"
@pytest.mark.unit
def test_serializes_list(self):
agent = _make_agent()
state = {"items": [1, 2, "three"]}
result = agent._serialize_state(state)
assert result["items"] == [1, 2, "three"]
@pytest.mark.unit
def test_serializes_tuple(self):
agent = _make_agent()
state = {"tup": (1, 2)}
result = agent._serialize_state(state)
assert result["tup"] == [1, 2]
@pytest.mark.unit
def test_serializes_datetime(self):
agent = _make_agent()
dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
state = {"time": dt}
result = agent._serialize_state(state)
assert "2025-01-01" in result["time"]
@pytest.mark.unit
def test_serializes_unknown_to_str(self):
agent = _make_agent()
state = {"obj": object()}
result = agent._serialize_state(state)
assert isinstance(result["obj"], str)

View File

@@ -0,0 +1,573 @@
"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes,
template context, source data, structured output parsing, get_execution_summary."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from application.agents.workflows.schemas import (
ExecutionStatus,
NodeType,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
Workflow,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
def _make_graph(nodes, edges):
wf = Workflow(name="Test", description="test workflow")
return WorkflowGraph(workflow=wf, nodes=nodes, edges=edges)
def _make_node(id, type, title="Node", config=None, position=None):
return WorkflowNode(
id=id,
workflow_id="wf1",
type=type,
title=title,
position=position or {"x": 0, "y": 0},
config=config or {},
)
def _make_edge(id, source, target, source_handle=None, target_handle=None):
return WorkflowEdge(
id=id,
workflow_id="wf1",
source=source,
target=target,
sourceHandle=source_handle,
targetHandle=target_handle,
)
def _make_agent():
agent = MagicMock()
agent.chat_history = []
agent.endpoint = "https://api.example.com"
agent.llm_name = "openai"
agent.model_id = "gpt-4"
agent.api_key = "key"
agent.decoded_token = {"sub": "user1"}
agent.retrieved_docs = None
return agent
class TestExecuteLoop:
@pytest.mark.unit
def test_no_start_node_yields_error(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "query"))
assert any(e.get("type") == "error" and "start node" in e.get("error", "") for e in events)
@pytest.mark.unit
def test_start_to_end(self):
nodes = [
_make_node("n1", NodeType.START, "Start"),
_make_node("n2", NodeType.END, "End", config={"config": {}}),
]
edges = [_make_edge("e1", "n1", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "hello"))
step_events = [e for e in events if e.get("type") == "workflow_step"]
assert len(step_events) >= 2 # At least start + end
@pytest.mark.unit
def test_node_not_found_yields_error(self):
nodes = [_make_node("n1", NodeType.START)]
edges = [_make_edge("e1", "n1", "nonexistent")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "q"))
assert any("not found" in e.get("error", "") for e in events)
@pytest.mark.unit
def test_node_execution_error_yields_error(self):
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.STATE, "State", config={"config": {"operations": [{"expression": "bad!!!", "target_variable": "x"}]}}),
]
edges = [_make_edge("e1", "n1", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "q"))
failed_events = [e for e in events if e.get("status") == "failed"]
assert len(failed_events) >= 1
@pytest.mark.unit
def test_max_steps_limit(self):
# Create a cycle: start -> state -> state (loop)
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.NOTE, "Note"),
]
edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
engine.MAX_EXECUTION_STEPS = 5
events = list(engine.execute({}, "q"))
# Should not run forever
assert len(events) > 0
@pytest.mark.unit
def test_branch_ends_without_end_node(self):
nodes = [
_make_node("n1", NodeType.START),
_make_node("n2", NodeType.NOTE, "Note"),
]
edges = [_make_edge("e1", "n1", "n2")] # n2 has no outgoing edges
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
events = list(engine.execute({}, "q"))
assert len(events) > 0
class TestInitializeState:
@pytest.mark.unit
def test_sets_query_and_history(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.chat_history = [{"prompt": "hi", "response": "hey"}]
engine = WorkflowEngine(graph, agent)
engine._initialize_state({"custom": "value"}, "test query")
assert engine.state["query"] == "test query"
assert "custom" in engine.state
assert engine.state["chat_history"] is not None
class TestGetNextNodeId:
@pytest.mark.unit
def test_no_edges_returns_none(self):
nodes = [_make_node("n1", NodeType.START)]
graph = _make_graph(nodes, [])
engine = WorkflowEngine(graph, _make_agent())
assert engine._get_next_node_id("n1") is None
@pytest.mark.unit
def test_returns_first_edge_target(self):
nodes = [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)]
edges = [_make_edge("e1", "n1", "n2")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
assert engine._get_next_node_id("n1") == "n2"
@pytest.mark.unit
def test_condition_uses_matched_handle(self):
nodes = [
_make_node("n1", NodeType.CONDITION),
_make_node("n2", NodeType.END, "Yes End"),
_make_node("n3", NodeType.END, "No End"),
]
edges = [
_make_edge("e1", "n1", "n2", source_handle="yes"),
_make_edge("e2", "n1", "n3", source_handle="no"),
]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
engine._condition_result = "no"
assert engine._get_next_node_id("n1") == "n3"
assert engine._condition_result is None # Cleared after use
@pytest.mark.unit
def test_condition_no_matching_handle_returns_none(self):
nodes = [_make_node("n1", NodeType.CONDITION)]
edges = [_make_edge("e1", "n1", "n2", source_handle="yes")]
graph = _make_graph(nodes, edges)
engine = WorkflowEngine(graph, _make_agent())
engine._condition_result = "nonexistent"
assert engine._get_next_node_id("n1") is None
class TestExecuteStateNode:
@pytest.mark.unit
def test_evaluates_operations(self):
node = _make_node("n1", NodeType.STATE, config={
"config": {
"operations": [
{"expression": "x + 1", "target_variable": "result"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"x": 5}
list(engine._execute_state_node(node))
assert engine.state["result"] == 6
@pytest.mark.unit
def test_skips_empty_expression(self):
node = _make_node("n1", NodeType.STATE, config={
"config": {
"operations": [
{"expression": "", "target_variable": "result"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {}
list(engine._execute_state_node(node))
assert "result" not in engine.state
class TestExecuteConditionNode:
@pytest.mark.unit
def test_matches_first_true_case(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": "x > 10", "source_handle": "high"},
{"expression": "x > 5", "source_handle": "medium"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"x": 7}
list(engine._execute_condition_node(node))
assert engine._condition_result == "medium"
@pytest.mark.unit
def test_falls_through_to_else(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": "x > 100", "source_handle": "high"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"x": 1}
list(engine._execute_condition_node(node))
assert engine._condition_result == "else"
@pytest.mark.unit
def test_skips_empty_expression(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": " ", "source_handle": "a"},
{"expression": "true", "source_handle": "b"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {}
list(engine._execute_condition_node(node))
assert engine._condition_result == "b"
@pytest.mark.unit
def test_cel_error_continues(self):
node = _make_node("n1", NodeType.CONDITION, config={
"config": {
"cases": [
{"expression": "bad!!!", "source_handle": "a"},
{"expression": "true", "source_handle": "b"},
]
}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {}
list(engine._execute_condition_node(node))
assert engine._condition_result == "b"
class TestExecuteEndNode:
@pytest.mark.unit
def test_with_output_template(self):
node = _make_node("n1", NodeType.END, config={
"config": {"output_template": "Result: {{ query }}"}
})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "hello"}
engine._format_template = MagicMock(return_value="Result: hello")
events = list(engine._execute_end_node(node))
assert len(events) == 1
assert events[0]["answer"] == "Result: hello"
@pytest.mark.unit
def test_without_output_template(self):
node = _make_node("n1", NodeType.END, config={"config": {}})
graph = _make_graph([node], [])
engine = WorkflowEngine(graph, _make_agent())
events = list(engine._execute_end_node(node))
assert len(events) == 0
class TestParseStructuredOutput:
@pytest.mark.unit
def test_valid_json(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output('{"key": "value"}')
assert success is True
assert data == {"key": "value"}
@pytest.mark.unit
def test_invalid_json(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output("not json")
assert success is False
assert data is None
@pytest.mark.unit
def test_empty_string(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output("")
assert success is False
assert data is None
@pytest.mark.unit
def test_whitespace_only(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
success, data = engine._parse_structured_output(" ")
assert success is False
assert data is None
class TestNormalizeNodeJsonSchema:
@pytest.mark.unit
def test_none_returns_none(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
assert engine._normalize_node_json_schema(None, "Node") is None
@pytest.mark.unit
def test_valid_schema(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
result = engine._normalize_node_json_schema(schema, "Node")
assert result is not None
@pytest.mark.unit
def test_invalid_schema_raises(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
with patch("application.agents.workflows.workflow_engine.normalize_json_schema_payload") as mock_norm:
from application.core.json_schema_utils import JsonSchemaValidationError
mock_norm.side_effect = JsonSchemaValidationError("bad schema")
with pytest.raises(ValueError, match="Invalid JSON schema"):
engine._normalize_node_json_schema({"bad": True}, "TestNode")
class TestValidateStructuredOutput:
@pytest.mark.unit
def test_valid_output_passes(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
engine._validate_structured_output(schema, {"name": "Alice"}) # Should not raise
@pytest.mark.unit
def test_invalid_output_raises(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
with pytest.raises(ValueError, match="did not match schema"):
engine._validate_structured_output(schema, {})
@pytest.mark.unit
def test_no_jsonschema_module(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
with patch("application.agents.workflows.workflow_engine.jsonschema", None):
engine._validate_structured_output({"type": "object"}, {}) # Should not raise
class TestFormatTemplate:
@pytest.mark.unit
def test_renders_template(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
engine.state = {"query": "hello"}
engine._build_template_context = MagicMock(return_value={"query": "hello"})
engine._template_engine = MagicMock()
engine._template_engine.render.return_value = "hello world"
result = engine._format_template("{{ query }} world")
assert result == "hello world"
@pytest.mark.unit
def test_render_error_returns_raw(self):
from application.templates.template_engine import TemplateRenderError
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
engine._build_template_context = MagicMock(return_value={})
engine._template_engine = MagicMock()
engine._template_engine.render.side_effect = TemplateRenderError("fail")
result = engine._format_template("{{ bad }}")
assert result == "{{ bad }}"
class TestBuildTemplateContext:
@pytest.mark.unit
def test_includes_state_variables(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"query": "hello", "custom_var": "value"}
context = engine._build_template_context()
assert context["agent"]["query"] == "hello"
assert "custom_var" in context
@pytest.mark.unit
def test_reserved_namespace_gets_prefixed(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"source": "my_source_val"}
context = engine._build_template_context()
assert context.get("agent_source") == "my_source_val"
@pytest.mark.unit
def test_passthrough_data(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"passthrough": {"key": "val"}}
context = engine._build_template_context()
assert "passthrough" in context or "agent_passthrough" in context
@pytest.mark.unit
def test_tools_data(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
engine.state = {"tools": {"tool1": "result"}}
context = engine._build_template_context()
assert "agent" in context
class TestGetSourceTemplateData:
@pytest.mark.unit
def test_no_docs_returns_none(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = None
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert docs is None
assert together is None
@pytest.mark.unit
def test_empty_docs_returns_none(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = []
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert docs is None
@pytest.mark.unit
def test_docs_with_filename(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content", "filename": "doc.txt"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert docs is not None
assert "doc.txt" in together
assert "content" in together
@pytest.mark.unit
def test_docs_without_filename(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content only"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert together == "content only"
@pytest.mark.unit
def test_skips_non_dict_docs(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = ["not a dict", {"text": "ok"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert together == "ok"
@pytest.mark.unit
def test_skips_non_string_text(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": 123}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert together is None
@pytest.mark.unit
def test_doc_with_title_fallback(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content", "title": "doc_title"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert "doc_title" in together
@pytest.mark.unit
def test_doc_with_source_fallback(self):
graph = _make_graph([], [])
agent = _make_agent()
agent.retrieved_docs = [{"text": "content", "source": "src"}]
engine = WorkflowEngine(graph, agent)
docs, together = engine._get_source_template_data()
assert "src" in together
class TestGetExecutionSummary:
@pytest.mark.unit
def test_returns_log_entries(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
now = datetime.now(timezone.utc)
engine.execution_log = [
{
"node_id": "n1",
"node_type": "start",
"status": "completed",
"started_at": now,
"completed_at": now,
"error": None,
"state_snapshot": {},
}
]
summary = engine.get_execution_summary()
assert len(summary) == 1
assert summary[0].node_id == "n1"
assert summary[0].status == ExecutionStatus.COMPLETED
@pytest.mark.unit
def test_empty_log(self):
graph = _make_graph([], [])
engine = WorkflowEngine(graph, _make_agent())
assert engine.get_execution_summary() == []

View File

@@ -0,0 +1,331 @@
"""Tests for application/api/answer/services/stream_processor.py — get_prompt and helpers."""
from unittest.mock import MagicMock, patch
import pytest
from application.api.answer.services.stream_processor import get_prompt
class TestGetPrompt:
@pytest.mark.unit
def test_default_preset(self):
prompt = get_prompt("default")
assert isinstance(prompt, str)
assert len(prompt) > 0
@pytest.mark.unit
def test_creative_preset(self):
prompt = get_prompt("creative")
assert isinstance(prompt, str)
@pytest.mark.unit
def test_strict_preset(self):
prompt = get_prompt("strict")
assert isinstance(prompt, str)
@pytest.mark.unit
def test_reduce_preset(self):
prompt = get_prompt("reduce")
assert isinstance(prompt, str)
@pytest.mark.unit
def test_agentic_default_preset(self):
prompt = get_prompt("agentic_default")
assert isinstance(prompt, str)
@pytest.mark.unit
def test_agentic_creative_preset(self):
prompt = get_prompt("agentic_creative")
assert isinstance(prompt, str)
@pytest.mark.unit
def test_agentic_strict_preset(self):
prompt = get_prompt("agentic_strict")
assert isinstance(prompt, str)
@pytest.mark.unit
def test_mongo_prompt_by_id(self):
mock_collection = MagicMock()
mock_collection.find_one.return_value = {"_id": "abc", "content": "Custom prompt"}
prompt = get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection)
assert prompt == "Custom prompt"
@pytest.mark.unit
def test_mongo_prompt_not_found_raises(self):
mock_collection = MagicMock()
mock_collection.find_one.return_value = None
with pytest.raises(ValueError, match="Invalid prompt ID"):
get_prompt("507f1f77bcf86cd799439011", prompts_collection=mock_collection)
@pytest.mark.unit
def test_invalid_id_raises(self):
mock_collection = MagicMock()
mock_collection.find_one.side_effect = Exception("bad id")
with pytest.raises(ValueError, match="Invalid prompt ID"):
get_prompt("not-an-objectid", prompts_collection=mock_collection)
@pytest.mark.unit
def test_mongo_fallback_when_no_collection(self):
"""When no collection passed, it reads from MongoDB."""
mock_collection = MagicMock()
mock_collection.find_one.return_value = {"content": "From DB"}
mock_db = MagicMock()
mock_db.__getitem__ = MagicMock(return_value=mock_collection)
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "test_db"
MockMongo.get_client.return_value = {"test_db": mock_db}
prompt = get_prompt("507f1f77bcf86cd799439011")
assert prompt == "From DB"
class TestStreamProcessorInit:
@pytest.mark.unit
def test_init_sets_attributes(self):
mock_db = MagicMock()
mock_client = {"docsgpt": mock_db}
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = mock_client
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(
request_data={"conversation_id": "conv1", "agent_id": "a1"},
decoded_token={"sub": "user1"},
)
assert sp.conversation_id == "conv1"
assert sp.initial_user_id == "user1"
assert sp.agent_id == "a1"
assert sp.history == []
assert sp.attachments == []
@pytest.mark.unit
def test_init_no_token(self):
mock_db = MagicMock()
mock_client = {"docsgpt": mock_db}
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = mock_client
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token=None)
assert sp.initial_user_id is None
class TestGetAttachmentsContent:
@pytest.mark.unit
def test_empty_ids_returns_empty(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
result = sp._get_attachments_content([], "u")
assert result == []
@pytest.mark.unit
def test_returns_matching_attachments(self):
mock_db = MagicMock()
mock_attachments = MagicMock()
mock_attachments.find_one.return_value = {"_id": "att1", "content": "data"}
mock_db.__getitem__ = MagicMock(return_value=mock_attachments)
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
result = sp._get_attachments_content(["507f1f77bcf86cd799439011"], "u")
assert len(result) == 1
@pytest.mark.unit
def test_invalid_attachment_id_continues(self):
mock_db = MagicMock()
mock_attachments = MagicMock()
mock_attachments.find_one.side_effect = Exception("bad id")
mock_db.__getitem__ = MagicMock(return_value=mock_attachments)
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
result = sp._get_attachments_content(["bad"], "u")
assert result == []
class TestResolveAgentId:
@pytest.mark.unit
def test_from_request_data(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(
request_data={"agent_id": "agent_123"},
decoded_token={"sub": "u"},
)
assert sp._resolve_agent_id() == "agent_123"
@pytest.mark.unit
def test_no_agent_no_conversation(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
assert sp._resolve_agent_id() is None
@pytest.mark.unit
def test_from_conversation(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(
request_data={"conversation_id": "conv1"},
decoded_token={"sub": "u"},
)
sp.conversation_service = MagicMock()
sp.conversation_service.get_conversation.return_value = {"agent_id": "from_conv"}
assert sp._resolve_agent_id() == "from_conv"
@pytest.mark.unit
def test_conversation_not_found(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(
request_data={"conversation_id": "conv1"},
decoded_token={"sub": "u"},
)
sp.conversation_service = MagicMock()
sp.conversation_service.get_conversation.return_value = None
assert sp._resolve_agent_id() is None
@pytest.mark.unit
def test_conversation_lookup_exception(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(
request_data={"conversation_id": "conv1"},
decoded_token={"sub": "u"},
)
sp.conversation_service = MagicMock()
sp.conversation_service.get_conversation.side_effect = Exception("db error")
assert sp._resolve_agent_id() is None
class TestGetPromptContent:
@pytest.mark.unit
def test_caches_result(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
sp.agent_config = {"prompt_id": "default"}
result1 = sp._get_prompt_content()
result2 = sp._get_prompt_content()
assert result1 == result2
assert result1 is not None
@pytest.mark.unit
def test_no_prompt_id(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
sp.agent_config = {}
assert sp._get_prompt_content() is None
@pytest.mark.unit
def test_invalid_prompt_id_returns_none(self):
mock_db = MagicMock()
mock_prompts = MagicMock()
mock_prompts.find_one.side_effect = Exception("bad")
mock_db.__getitem__ = MagicMock(return_value=mock_prompts)
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
sp.agent_config = {"prompt_id": "bad_id"}
assert sp._get_prompt_content() is None
class TestGetRequiredToolActions:
@pytest.mark.unit
def test_no_prompt_returns_none(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
sp.agent_config = {}
assert sp._get_required_tool_actions() is None
@pytest.mark.unit
def test_no_template_syntax_returns_empty(self):
mock_db = MagicMock()
with patch("application.api.answer.services.stream_processor.MongoDB") as MockMongo, \
patch("application.api.answer.services.stream_processor.settings") as mock_settings:
mock_settings.MONGO_DB_NAME = "docsgpt"
MockMongo.get_client.return_value = {"docsgpt": mock_db}
from application.api.answer.services.stream_processor import StreamProcessor
sp = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
sp.agent_config = {"prompt_id": "default"}
sp._prompt_content = "No template syntax here"
result = sp._get_required_tool_actions()
assert result == {}

View File

@@ -0,0 +1,333 @@
"""Tests for application/api/connector/routes.py"""
import json
from unittest.mock import MagicMock, patch
import mongomock
import pytest
@pytest.fixture
def app():
with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
from application.app import app as flask_app
flask_app.config["TESTING"] = True
yield flask_app
@pytest.fixture
def client(app):
return app.test_client()
@pytest.fixture
def mock_sessions(monkeypatch):
mock_client = mongomock.MongoClient()
mock_db = mock_client["docsgpt"]
sessions = mock_db["connector_sessions"]
sources = mock_db["sources"]
monkeypatch.setattr("application.api.connector.routes.sessions_collection", sessions)
monkeypatch.setattr("application.api.connector.routes.sources_collection", sources)
return {"sessions": sessions, "sources": sources}
class TestConnectorAuth:
@pytest.mark.unit
def test_missing_provider(self, client):
resp = client.get("/api/connectors/auth")
assert resp.status_code == 400
@pytest.mark.unit
def test_unsupported_provider(self, client):
resp = client.get("/api/connectors/auth?provider=dropbox")
assert resp.status_code == 400
@pytest.mark.unit
def test_unauthorized(self, client, app):
with patch("application.app.handle_auth", return_value=None):
resp = client.get("/api/connectors/auth?provider=google_drive")
data = json.loads(resp.data)
# decoded_token is None -> 401
assert resp.status_code == 401 or data.get("error") == "Unauthorized"
@pytest.mark.unit
def test_success(self, client, mock_sessions):
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.is_supported.return_value = True
mock_auth = MagicMock()
mock_auth.get_authorization_url.return_value = "https://oauth.example.com/auth"
MockCC.create_auth.return_value = mock_auth
resp = client.get("/api/connectors/auth?provider=google_drive")
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
assert "authorization_url" in data
@pytest.mark.unit
def test_exception_returns_500(self, client, mock_sessions):
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.is_supported.return_value = True
MockCC.create_auth.side_effect = Exception("oauth fail")
resp = client.get("/api/connectors/auth?provider=google_drive")
assert resp.status_code == 500
class TestConnectorFiles:
@pytest.mark.unit
def test_missing_params(self, client):
resp = client.post("/api/connectors/files", json={"provider": "google_drive"})
assert resp.status_code == 400
@pytest.mark.unit
def test_invalid_session(self, client, mock_sessions):
resp = client.post("/api/connectors/files", json={
"provider": "google_drive",
"session_token": "bad_token",
})
assert resp.status_code == 401
@pytest.mark.unit
def test_success(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "valid_tok",
"user": "test_user",
"provider": "google_drive",
})
mock_doc = MagicMock()
mock_doc.doc_id = "f1"
mock_doc.extra_info = {
"file_name": "test.pdf",
"mime_type": "application/pdf",
"size": 1024,
"modified_time": "2025-01-01T12:00:00.000Z",
"is_folder": False,
}
mock_loader = MagicMock()
mock_loader.load_data.return_value = [mock_doc]
mock_loader.next_page_token = None
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_connector.return_value = mock_loader
resp = client.post("/api/connectors/files", json={
"provider": "google_drive",
"session_token": "valid_tok",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
assert len(data["files"]) == 1
@pytest.mark.unit
def test_no_modified_time(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "tok2",
"user": "test_user",
"provider": "google_drive",
})
mock_doc = MagicMock()
mock_doc.doc_id = "f1"
mock_doc.extra_info = {"file_name": "test.pdf", "mime_type": "application/pdf"}
mock_loader = MagicMock()
mock_loader.load_data.return_value = [mock_doc]
mock_loader.next_page_token = None
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
MockCC.create_connector.return_value = mock_loader
resp = client.post("/api/connectors/files", json={
"provider": "google_drive", "session_token": "tok2",
})
assert resp.status_code == 200
class TestConnectorValidateSession:
@pytest.mark.unit
def test_missing_params(self, client):
resp = client.post("/api/connectors/validate-session", json={"provider": "google_drive"})
assert resp.status_code == 400
@pytest.mark.unit
def test_invalid_session(self, client, mock_sessions):
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "bad",
})
assert resp.status_code == 401
@pytest.mark.unit
def test_valid_non_expired(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "valid",
"user": "test_user",
"provider": "google_drive",
"token_info": {"access_token": "at", "refresh_token": "rt", "expiry": None},
"user_email": "user@example.com",
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = False
MockCC.create_auth.return_value = mock_auth
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "valid",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
assert data["expired"] is False
@pytest.mark.unit
def test_expired_with_refresh(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "expired_tok",
"user": "test_user",
"provider": "google_drive",
"token_info": {"access_token": "old_at", "refresh_token": "rt", "expiry": 100},
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = True
mock_auth.refresh_access_token.return_value = {"access_token": "new_at", "refresh_token": "rt"}
mock_auth.sanitize_token_info.return_value = {"access_token": "new_at", "refresh_token": "rt"}
MockCC.create_auth.return_value = mock_auth
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "expired_tok",
})
assert resp.status_code == 200
@pytest.mark.unit
def test_expired_no_refresh(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({
"session_token": "exp_no_ref",
"user": "test_user",
"token_info": {"access_token": "at", "expiry": 100},
})
with patch("application.api.connector.routes.ConnectorCreator") as MockCC:
mock_auth = MagicMock()
mock_auth.is_token_expired.return_value = True
MockCC.create_auth.return_value = mock_auth
resp = client.post("/api/connectors/validate-session", json={
"provider": "google_drive", "session_token": "exp_no_ref",
})
assert resp.status_code == 401
class TestConnectorDisconnect:
@pytest.mark.unit
def test_missing_provider(self, client):
resp = client.post("/api/connectors/disconnect", json={})
assert resp.status_code == 400
@pytest.mark.unit
def test_success_with_session(self, client, mock_sessions):
mock_sessions["sessions"].insert_one({"session_token": "del_me", "provider": "google_drive"})
resp = client.post("/api/connectors/disconnect", json={
"provider": "google_drive", "session_token": "del_me",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["success"] is True
@pytest.mark.unit
def test_success_without_session(self, client, mock_sessions):
resp = client.post("/api/connectors/disconnect", json={"provider": "google_drive"})
assert resp.status_code == 200
class TestConnectorSync:
@pytest.mark.unit
def test_missing_params(self, client, mock_sessions):
resp = client.post("/api/connectors/sync", json={"source_id": "abc"})
assert resp.status_code == 400
@pytest.mark.unit
def test_source_not_found(self, client, mock_sessions):
from bson.objectid import ObjectId
resp = client.post("/api/connectors/sync", json={
"source_id": str(ObjectId()), "session_token": "tok",
})
assert resp.status_code == 404
@pytest.mark.unit
def test_unauthorized_source(self, client, mock_sessions):
sid = mock_sessions["sources"].insert_one({"user": "other_user", "name": "src"}).inserted_id
resp = client.post("/api/connectors/sync", json={
"source_id": str(sid), "session_token": "tok",
})
assert resp.status_code == 403
@pytest.mark.unit
def test_missing_provider_in_remote_data(self, client, mock_sessions):
sid = mock_sessions["sources"].insert_one({
"user": "test_user", "name": "src", "remote_data": json.dumps({}),
}).inserted_id
resp = client.post("/api/connectors/sync", json={
"source_id": str(sid), "session_token": "tok",
})
assert resp.status_code == 400
@pytest.mark.unit
def test_success(self, client, mock_sessions):
sid = mock_sessions["sources"].insert_one({
"user": "test_user",
"name": "src",
"remote_data": json.dumps({"provider": "google_drive", "file_ids": ["f1"]}),
}).inserted_id
mock_task = MagicMock()
mock_task.id = "task_123"
with patch("application.api.connector.routes.ingest_connector_task") as mock_ingest:
mock_ingest.delay.return_value = mock_task
resp = client.post("/api/connectors/sync", json={
"source_id": str(sid), "session_token": "tok",
})
assert resp.status_code == 200
data = json.loads(resp.data)
assert data["task_id"] == "task_123"
class TestConnectorCallbackStatus:
@pytest.mark.unit
def test_success_status(self, client):
resp = client.get("/api/connectors/callback-status?status=success&message=OK&provider=google_drive&session_token=tok&user_email=u@e.com")
assert resp.status_code == 200
assert b"success" in resp.data
@pytest.mark.unit
def test_error_status(self, client):
resp = client.get("/api/connectors/callback-status?status=error&message=Failed")
assert resp.status_code == 200
assert b"error" in resp.data
@pytest.mark.unit
def test_cancelled_status(self, client):
resp = client.get("/api/connectors/callback-status?status=cancelled&message=Cancelled&provider=google_drive")
assert resp.status_code == 200
assert b"cancelled" in resp.data
@pytest.mark.unit
def test_unknown_status_defaults_to_error(self, client):
resp = client.get("/api/connectors/callback-status?status=badvalue")
assert resp.status_code == 200
assert b"error" in resp.data
@pytest.mark.unit
def test_html_escaping(self, client):
resp = client.get('/api/connectors/callback-status?status=error&message=<script>alert(1)</script>')
assert resp.status_code == 200
# The raw <script> tag should be escaped (not executable)
assert b"<script>alert(1)</script>" not in resp.data
class TestBuildCallbackRedirect:
@pytest.mark.unit
def test_builds_url(self):
from application.api.connector.routes import build_callback_redirect
url = build_callback_redirect({"status": "success", "message": "OK"})
assert url.startswith("/api/connectors/callback-status?")
assert "status=success" in url

View File

@@ -0,0 +1,339 @@
"""Tests for application/core/model_settings.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.core.model_settings import (
AvailableModel,
ModelCapabilities,
ModelProvider,
ModelRegistry,
)
class TestModelProvider:
@pytest.mark.unit
def test_all_providers_exist(self):
assert ModelProvider.OPENAI == "openai"
assert ModelProvider.ANTHROPIC == "anthropic"
assert ModelProvider.GOOGLE == "google"
assert ModelProvider.GROQ == "groq"
assert ModelProvider.DOCSGPT == "docsgpt"
assert ModelProvider.HUGGINGFACE == "huggingface"
assert ModelProvider.NOVITA == "novita"
assert ModelProvider.OPENROUTER == "openrouter"
assert ModelProvider.SAGEMAKER == "sagemaker"
assert ModelProvider.PREMAI == "premai"
assert ModelProvider.LLAMA_CPP == "llama.cpp"
assert ModelProvider.AZURE_OPENAI == "azure_openai"
class TestModelCapabilities:
@pytest.mark.unit
def test_defaults(self):
caps = ModelCapabilities()
assert caps.supports_tools is False
assert caps.supports_structured_output is False
assert caps.supports_streaming is True
assert caps.supported_attachment_types == []
assert caps.context_window == 128000
assert caps.input_cost_per_token is None
assert caps.output_cost_per_token is None
@pytest.mark.unit
def test_custom_values(self):
caps = ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
context_window=32000,
input_cost_per_token=0.001,
)
assert caps.supports_tools is True
assert caps.context_window == 32000
class TestAvailableModel:
@pytest.mark.unit
def test_to_dict_basic(self):
model = AvailableModel(
id="gpt-4",
provider=ModelProvider.OPENAI,
display_name="GPT-4",
description="OpenAI GPT-4",
)
d = model.to_dict()
assert d["id"] == "gpt-4"
assert d["provider"] == "openai"
assert d["display_name"] == "GPT-4"
assert d["enabled"] is True
assert "base_url" not in d
@pytest.mark.unit
def test_to_dict_with_base_url(self):
model = AvailableModel(
id="local-model",
provider=ModelProvider.OPENAI,
display_name="Local",
base_url="http://localhost:11434",
)
d = model.to_dict()
assert d["base_url"] == "http://localhost:11434"
@pytest.mark.unit
def test_to_dict_includes_capabilities(self):
caps = ModelCapabilities(supports_tools=True, context_window=64000)
model = AvailableModel(
id="m1",
provider=ModelProvider.ANTHROPIC,
display_name="M1",
capabilities=caps,
)
d = model.to_dict()
assert d["supports_tools"] is True
assert d["context_window"] == 64000
class TestModelRegistry:
@pytest.fixture(autouse=True)
def _reset_singleton(self):
"""Reset singleton between tests."""
ModelRegistry._instance = None
ModelRegistry._initialized = False
yield
ModelRegistry._instance = None
ModelRegistry._initialized = False
@pytest.mark.unit
def test_singleton(self):
with patch.object(ModelRegistry, "_load_models"):
r1 = ModelRegistry()
r2 = ModelRegistry()
assert r1 is r2
@pytest.mark.unit
def test_get_instance(self):
with patch.object(ModelRegistry, "_load_models"):
r = ModelRegistry.get_instance()
assert isinstance(r, ModelRegistry)
@pytest.mark.unit
def test_get_model(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
reg.models["test"] = model
assert reg.get_model("test") is model
assert reg.get_model("nonexistent") is None
@pytest.mark.unit
def test_get_all_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
assert len(reg.get_all_models()) == 2
@pytest.mark.unit
def test_get_enabled_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
enabled = reg.get_enabled_models()
assert len(enabled) == 1
assert enabled[0].id == "m1"
@pytest.mark.unit
def test_model_exists(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
assert reg.model_exists("m1") is True
assert reg.model_exists("m2") is False
@pytest.mark.unit
def test_parse_model_names(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
assert reg._parse_model_names("single") == ["single"]
assert reg._parse_model_names("") == []
assert reg._parse_model_names(None) == []
@pytest.mark.unit
def test_add_docsgpt_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
reg._add_docsgpt_models(mock_settings)
assert "docsgpt-local" in reg.models
@pytest.mark.unit
def test_add_huggingface_models(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
reg._add_huggingface_models(mock_settings)
assert "huggingface-local" in reg.models
@pytest.mark.unit
def test_load_models_with_openai_key(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = ""
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert len(reg.models) > 0
@pytest.mark.unit
def test_load_models_custom_openai_base_url(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = "http://localhost:11434/v1"
mock_settings.OPENAI_API_KEY = "sk-test"
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = "openai"
mock_settings.LLM_NAME = "llama3,gemma"
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
assert "llama3" in reg.models
assert "gemma" in reg.models
@pytest.mark.unit
def test_default_model_selection_from_llm_name(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
reg.default_model_id = "gpt-4"
assert reg.default_model_id == "gpt-4"
@pytest.mark.unit
def test_add_anthropic_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_google_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GOOGLE_API_KEY = "google-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_google_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_groq_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.GROQ_API_KEY = "groq-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_groq_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_openrouter_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.OPEN_ROUTER_API_KEY = "or-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_openrouter_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_novita_models_with_key(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.NOVITA_API_KEY = "novita-test"
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
reg._add_novita_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_azure_openai_models_specific(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.LLM_PROVIDER = "azure_openai"
mock_settings.LLM_NAME = "nonexistent-model"
reg._add_azure_openai_models(mock_settings)
# Falls through to adding all azure models
assert len(reg.models) > 0
@pytest.mark.unit
def test_add_anthropic_models_no_key_with_provider(self):
with patch.object(ModelRegistry, "_load_models"):
reg = ModelRegistry()
reg.models = {}
mock_settings = MagicMock()
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.LLM_PROVIDER = "anthropic"
mock_settings.LLM_NAME = "nonexistent"
reg._add_anthropic_models(mock_settings)
assert len(reg.models) > 0
@pytest.mark.unit
def test_default_model_fallback_to_first(self):
mock_settings = MagicMock()
mock_settings.OPENAI_BASE_URL = None
mock_settings.OPENAI_API_KEY = None
mock_settings.OPENAI_API_BASE = None
mock_settings.ANTHROPIC_API_KEY = None
mock_settings.GOOGLE_API_KEY = None
mock_settings.GROQ_API_KEY = None
mock_settings.OPEN_ROUTER_API_KEY = None
mock_settings.NOVITA_API_KEY = None
mock_settings.HUGGINGFACE_API_KEY = None
mock_settings.LLM_PROVIDER = ""
mock_settings.LLM_NAME = ""
mock_settings.API_KEY = None
with patch("application.core.settings.settings", mock_settings):
reg = ModelRegistry()
# Should have at least docsgpt-local
assert reg.default_model_id is not None

115
tests/test_app_routes.py Normal file
View File

@@ -0,0 +1,115 @@
"""Tests for application/app.py route handlers."""
import json
from unittest.mock import patch
import pytest
@pytest.fixture
def app():
"""Import the Flask app with auth mocked to avoid JWT setup issues."""
with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
from application.app import app as flask_app
flask_app.config["TESTING"] = True
yield flask_app
@pytest.fixture
def client(app):
return app.test_client()
class TestHomeRoute:
@pytest.mark.unit
def test_root_returns_200(self, client):
"""Root serves Swagger UI via Flask-RESTX."""
response = client.get("/")
assert response.status_code == 200
class TestHealthRoute:
@pytest.mark.unit
def test_returns_ok(self, client):
response = client.get("/api/health")
assert response.status_code == 200
data = json.loads(response.data)
assert data["status"] == "ok"
class TestConfigRoute:
@pytest.mark.unit
def test_returns_auth_config(self, client):
response = client.get("/api/config")
assert response.status_code == 200
data = json.loads(response.data)
assert "auth_type" in data
assert "requires_auth" in data
class TestGenerateTokenRoute:
@pytest.mark.unit
def test_session_jwt_generates_token(self, client, app):
with patch("application.app.settings") as mock_settings:
mock_settings.AUTH_TYPE = "session_jwt"
mock_settings.JWT_SECRET_KEY = "test_secret"
response = client.get("/api/generate_token")
assert response.status_code == 200
data = json.loads(response.data)
assert "token" in data
@pytest.mark.unit
def test_non_session_jwt_returns_error(self, client, app):
with patch("application.app.settings") as mock_settings:
mock_settings.AUTH_TYPE = "none"
response = client.get("/api/generate_token")
assert response.status_code == 400
class TestSttRequestSizeLimits:
@pytest.mark.unit
def test_non_stt_request_passes(self, client):
response = client.get("/api/health")
assert response.status_code == 200
@pytest.mark.unit
def test_oversized_stt_request_rejected(self, client):
with patch("application.app.should_reject_stt_request", return_value=True), \
patch("application.app.build_stt_file_size_limit_message", return_value="Too large"):
response = client.post("/api/stt/upload", data=b"x" * 100)
assert response.status_code == 413
class TestAuthenticateRequest:
@pytest.mark.unit
def test_options_returns_200(self, client):
response = client.options("/api/health")
assert response.status_code == 200
@pytest.mark.unit
def test_auth_error_returns_401(self, client, app):
with patch("application.app.handle_auth", return_value={"error": "Invalid token"}):
response = client.get("/api/health")
assert response.status_code == 401
@pytest.mark.unit
def test_no_token_sets_none(self, client, app):
with patch("application.app.handle_auth", return_value=None):
response = client.get("/api/health")
assert response.status_code == 200
class TestAfterRequest:
@pytest.mark.unit
def test_cors_headers(self, client):
response = client.get("/api/health")
assert response.headers.get("Access-Control-Allow-Origin") == "*"
assert "Content-Type" in response.headers.get("Access-Control-Allow-Headers", "")
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")

533
tests/test_utils.py Normal file
View File

@@ -0,0 +1,533 @@
"""Tests for application/utils.py"""
from unittest.mock import MagicMock, patch
import pytest
from application.utils import (
calculate_compression_threshold,
calculate_doc_token_budget,
check_required_fields,
clean_text_for_tts,
convert_pdf_to_images,
get_encoding,
get_field_validation_errors,
get_gpt_model,
get_hash,
get_missing_fields,
generate_image_url,
limit_chat_history,
num_tokens_from_object_or_list,
num_tokens_from_string,
safe_filename,
validate_function_name,
validate_required_fields,
)
class TestGetEncoding:
@pytest.mark.unit
def test_returns_encoding(self):
enc = get_encoding()
assert enc is not None
@pytest.mark.unit
def test_returns_same_instance(self):
enc1 = get_encoding()
enc2 = get_encoding()
assert enc1 is enc2
class TestGetGptModel:
@pytest.mark.unit
def test_returns_llm_name_when_set(self):
with patch("application.utils.settings") as s:
s.LLM_NAME = "my-model"
s.LLM_PROVIDER = "openai"
assert get_gpt_model() == "my-model"
@pytest.mark.unit
def test_falls_back_to_provider_map(self):
with patch("application.utils.settings") as s:
s.LLM_NAME = ""
s.LLM_PROVIDER = "openai"
assert get_gpt_model() == "gpt-4o-mini"
@pytest.mark.unit
def test_unknown_provider_returns_empty(self):
with patch("application.utils.settings") as s:
s.LLM_NAME = ""
s.LLM_PROVIDER = "unknown"
assert get_gpt_model() == ""
class TestSafeFilename:
@pytest.mark.unit
def test_normal_filename(self):
assert safe_filename("test.pdf") == "test.pdf"
@pytest.mark.unit
def test_empty_filename_returns_uuid(self):
result = safe_filename("")
assert len(result) > 10 # UUID
@pytest.mark.unit
def test_none_filename_returns_uuid(self):
result = safe_filename(None)
assert len(result) > 10
@pytest.mark.unit
def test_non_latin_filename(self):
result = safe_filename("документ.pdf")
assert result.endswith(".pdf")
class TestNumTokens:
@pytest.mark.unit
def test_string_token_count(self):
count = num_tokens_from_string("hello world")
assert count > 0
@pytest.mark.unit
def test_non_string_returns_zero(self):
assert num_tokens_from_string(123) == 0
@pytest.mark.unit
def test_empty_string(self):
assert num_tokens_from_string("") == 0
class TestNumTokensFromObjectOrList:
@pytest.mark.unit
def test_list(self):
result = num_tokens_from_object_or_list(["hello", "world"])
assert result > 0
@pytest.mark.unit
def test_dict(self):
result = num_tokens_from_object_or_list({"key": "value"})
assert result > 0
@pytest.mark.unit
def test_string(self):
result = num_tokens_from_object_or_list("hello")
assert result > 0
@pytest.mark.unit
def test_number_returns_zero(self):
assert num_tokens_from_object_or_list(42) == 0
@pytest.mark.unit
def test_nested(self):
result = num_tokens_from_object_or_list({"a": ["b", "c"]})
assert result > 0
class TestCountTokensDocs:
@pytest.mark.unit
def test_counts_doc_tokens(self):
from application.utils import count_tokens_docs
doc1 = MagicMock()
doc1.page_content = "hello world"
doc2 = MagicMock()
doc2.page_content = " foo bar"
result = count_tokens_docs([doc1, doc2])
assert result > 0
class TestCalculateDocTokenBudget:
@pytest.mark.unit
def test_returns_budget(self):
with patch("application.utils.get_token_limit", return_value=128000), \
patch("application.utils.settings") as s:
s.RESERVED_TOKENS = {"system": 500, "history": 500}
result = calculate_doc_token_budget("gpt-4o")
assert result == 127000
@pytest.mark.unit
def test_minimum_budget(self):
with patch("application.utils.get_token_limit", return_value=1000), \
patch("application.utils.settings") as s:
s.RESERVED_TOKENS = {"system": 500, "history": 500}
result = calculate_doc_token_budget("small-model")
assert result == 1000
class TestFieldValidation:
@pytest.mark.unit
def test_get_missing_fields(self):
assert get_missing_fields({"a": 1}, ["a", "b"]) == ["b"]
assert get_missing_fields({"a": 1, "b": 2}, ["a", "b"]) == []
@pytest.mark.unit
def test_check_required_fields_pass(self):
from flask import Flask
app = Flask(__name__)
with app.app_context():
result = check_required_fields({"a": 1, "b": 2}, ["a", "b"])
assert result is None
@pytest.mark.unit
def test_check_required_fields_fail(self):
from flask import Flask
app = Flask(__name__)
with app.app_context():
result = check_required_fields({"a": 1}, ["a", "b"])
assert result is not None
assert result.status_code == 400
@pytest.mark.unit
def test_get_field_validation_errors_none_when_valid(self):
assert get_field_validation_errors({"a": 1}, ["a"]) is None
@pytest.mark.unit
def test_get_field_validation_errors_missing(self):
result = get_field_validation_errors({}, ["a"])
assert result["missing_fields"] == ["a"]
@pytest.mark.unit
def test_get_field_validation_errors_empty(self):
result = get_field_validation_errors({"a": ""}, ["a"])
assert result["empty_fields"] == ["a"]
@pytest.mark.unit
def test_validate_required_fields_pass(self):
from flask import Flask
app = Flask(__name__)
with app.app_context():
result = validate_required_fields({"a": "v"}, ["a"])
assert result is None
@pytest.mark.unit
def test_validate_required_fields_missing(self):
from flask import Flask
app = Flask(__name__)
with app.app_context():
result = validate_required_fields({}, ["a"])
assert result is not None
assert result.status_code == 400
@pytest.mark.unit
def test_validate_required_fields_empty(self):
from flask import Flask
app = Flask(__name__)
with app.app_context():
result = validate_required_fields({"a": ""}, ["a"])
assert result is not None
@pytest.mark.unit
def test_validate_required_fields_both_missing_and_empty(self):
from flask import Flask
app = Flask(__name__)
with app.app_context():
result = validate_required_fields({"a": ""}, ["a", "b"])
assert result is not None
class TestGetHash:
@pytest.mark.unit
def test_returns_hex_string(self):
h = get_hash("test")
assert len(h) == 32
assert all(c in "0123456789abcdef" for c in h)
@pytest.mark.unit
def test_deterministic(self):
assert get_hash("hello") == get_hash("hello")
@pytest.mark.unit
def test_different_inputs(self):
assert get_hash("a") != get_hash("b")
class TestLimitChatHistory:
@pytest.mark.unit
def test_empty_history(self):
assert limit_chat_history([]) == []
@pytest.mark.unit
def test_none_history(self):
assert limit_chat_history(None) == []
@pytest.mark.unit
def test_keeps_recent_messages(self):
history = [
{"prompt": "q1", "response": "a1"},
{"prompt": "q2", "response": "a2"},
]
result = limit_chat_history(history, max_token_limit=10000)
assert len(result) == 2
@pytest.mark.unit
def test_trims_old_messages(self):
history = [
{"prompt": "x" * 5000, "response": "y" * 5000},
{"prompt": "q", "response": "a"},
]
result = limit_chat_history(history, max_token_limit=100)
assert len(result) <= 2
@pytest.mark.unit
def test_handles_tool_calls(self):
history = [
{
"prompt": "q",
"response": "a",
"tool_calls": [
{"tool_name": "t", "action_name": "a", "arguments": "{}", "result": "r"}
],
}
]
result = limit_chat_history(history, max_token_limit=10000)
assert len(result) == 1
class TestValidateFunctionName:
@pytest.mark.unit
def test_valid_names(self):
assert validate_function_name("hello") is True
assert validate_function_name("hello_world") is True
assert validate_function_name("hello-world") is True
assert validate_function_name("test123") is True
@pytest.mark.unit
def test_invalid_names(self):
assert validate_function_name("hello world") is False
assert validate_function_name("hello!") is False
assert validate_function_name("") is False
class TestGenerateImageUrl:
@pytest.mark.unit
def test_http_url_passthrough(self):
assert generate_image_url("https://example.com/img.png") == "https://example.com/img.png"
assert generate_image_url("http://example.com/img.png") == "http://example.com/img.png"
@pytest.mark.unit
def test_s3_strategy(self):
with patch("application.utils.settings") as s:
s.URL_STRATEGY = "s3"
s.S3_BUCKET_NAME = "my-bucket"
s.SAGEMAKER_REGION = "us-west-2"
result = generate_image_url("path/to/img.png")
assert "my-bucket.s3.us-west-2" in result
@pytest.mark.unit
def test_backend_strategy(self):
with patch("application.utils.settings") as s:
s.URL_STRATEGY = "backend"
s.API_URL = "http://localhost:7091"
result = generate_image_url("path/to/img.png")
assert result == "http://localhost:7091/api/images/path/to/img.png"
class TestCalculateCompressionThreshold:
@pytest.mark.unit
def test_default_threshold(self):
with patch("application.utils.get_token_limit", return_value=100000):
result = calculate_compression_threshold("gpt-4o")
assert result == 80000
@pytest.mark.unit
def test_custom_percentage(self):
with patch("application.utils.get_token_limit", return_value=100000):
result = calculate_compression_threshold("gpt-4o", 0.5)
assert result == 50000
class TestConvertPdfToImages:
@pytest.mark.unit
def test_missing_pdf2image_raises(self):
with patch.dict("sys.modules", {"pdf2image": None}):
# Force re-import to trigger ImportError
# The function handles the import internally
with pytest.raises(ImportError, match="pdf2image"):
convert_pdf_to_images("test.pdf")
@pytest.mark.unit
def test_converts_from_path(self):
mock_image = MagicMock()
mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"PNG_DATA"))
mock_module = MagicMock()
mock_module.convert_from_path.return_value = [mock_image]
mock_module.convert_from_bytes.return_value = [mock_image]
original_import = __import__
def patched_import(name, *args, **kwargs):
if name == "pdf2image":
return mock_module
return original_import(name, *args, **kwargs)
with patch("builtins.__import__", side_effect=patched_import):
result = convert_pdf_to_images("/some/file.pdf")
assert len(result) == 1
assert result[0]["mime_type"] == "image/png"
assert result[0]["page"] == 1
@pytest.mark.unit
def test_with_storage(self):
mock_image = MagicMock()
mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"IMG"))
mock_storage = MagicMock()
mock_file = MagicMock()
mock_file.read.return_value = b"pdf_bytes"
mock_file.__enter__ = MagicMock(return_value=mock_file)
mock_file.__exit__ = MagicMock(return_value=False)
mock_storage.get_file.return_value = mock_file
mock_module = MagicMock()
mock_module.convert_from_bytes.return_value = [mock_image]
original_import = __import__
def patched_import(name, *args, **kwargs):
if name == "pdf2image":
return mock_module
return original_import(name, *args, **kwargs)
with patch("builtins.__import__", side_effect=patched_import):
result = convert_pdf_to_images("test.pdf", storage=mock_storage)
assert len(result) == 1
mock_module.convert_from_bytes.assert_called_once()
@pytest.mark.unit
def test_file_not_found_raises(self):
mock_module = MagicMock()
mock_module.convert_from_path.side_effect = FileNotFoundError("not found")
# Patch the import inside the function
original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
def patched_import(name, *args, **kwargs):
if name == "pdf2image":
return mock_module
return original_import(name, *args, **kwargs)
with patch("builtins.__import__", side_effect=patched_import):
with pytest.raises(FileNotFoundError):
convert_pdf_to_images("/nonexistent.pdf")
@pytest.mark.unit
def test_generic_error_raises(self):
mock_module = MagicMock()
mock_module.convert_from_path.side_effect = RuntimeError("conversion failed")
original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
def patched_import(name, *args, **kwargs):
if name == "pdf2image":
return mock_module
return original_import(name, *args, **kwargs)
with patch("builtins.__import__", side_effect=patched_import):
with pytest.raises(RuntimeError, match="conversion failed"):
convert_pdf_to_images("/some.pdf")
class TestCleanTextForTts:
@pytest.mark.unit
def test_removes_code_blocks(self):
result = clean_text_for_tts("before ```python\ncode\n``` after")
assert "code block" in result
assert "python" not in result
@pytest.mark.unit
def test_removes_mermaid_blocks(self):
result = clean_text_for_tts("```mermaid\ngraph TD\n```")
assert "flowchart" in result
@pytest.mark.unit
def test_removes_markdown_links(self):
result = clean_text_for_tts("[click here](https://example.com)")
assert "click here" in result
assert "https" not in result
@pytest.mark.unit
def test_removes_images(self):
result = clean_text_for_tts("![alt text](image.png)")
assert "image.png" not in result
@pytest.mark.unit
def test_removes_inline_code(self):
result = clean_text_for_tts("use `foo()` here")
assert "foo()" in result
assert "`" not in result
@pytest.mark.unit
def test_removes_bold_italic(self):
result = clean_text_for_tts("**bold** and *italic*")
assert "bold" in result
assert "italic" in result
assert "*" not in result
@pytest.mark.unit
def test_removes_headers(self):
result = clean_text_for_tts("# Header\ntext")
assert "Header" in result
assert "#" not in result
@pytest.mark.unit
def test_removes_blockquotes(self):
result = clean_text_for_tts("> quoted text")
assert "quoted text" in result
assert ">" not in result
@pytest.mark.unit
def test_removes_html_tags(self):
result = clean_text_for_tts("<div>content</div>")
assert "content" in result
assert "<" not in result
@pytest.mark.unit
def test_removes_arrows(self):
result = clean_text_for_tts("a --> b <-- c => d")
assert "-->" not in result
assert "<--" not in result
assert "=>" not in result
@pytest.mark.unit
def test_removes_horizontal_rules(self):
result = clean_text_for_tts("text\n---\nmore")
assert "---" not in result
@pytest.mark.unit
def test_removes_list_markers(self):
result = clean_text_for_tts("- item1\n* item2\n1. item3")
assert "item1" in result
assert "item2" in result
assert "item3" in result
@pytest.mark.unit
def test_normalizes_whitespace(self):
result = clean_text_for_tts(" lots of spaces ")
assert " " not in result
@pytest.mark.unit
def test_removes_braces(self):
result = clean_text_for_tts("{content} and [more]")
assert "content" in result
assert "more" in result
assert "{" not in result
@pytest.mark.unit
def test_removes_double_colons(self):
result = clean_text_for_tts("module::function")
assert "::" not in result