mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
chore: utils tests
This commit is contained in:
169
tests/agents/test_cel_evaluator.py
Normal file
169
tests/agents/test_cel_evaluator.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Tests for application/agents/workflows/cel_evaluator.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.cel_evaluator import (
|
||||
CelEvaluationError,
|
||||
_convert_value,
|
||||
build_activation,
|
||||
cel_to_python,
|
||||
evaluate_cel,
|
||||
)
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class TestConvertValue:
    """Unit tests for the private _convert_value helper (Python -> CEL types)."""

    @pytest.mark.unit
    def test_bool_true(self):
        converted = _convert_value(True)
        assert isinstance(converted, celpy.celtypes.BoolType)
        assert bool(converted) is True

    @pytest.mark.unit
    def test_bool_false(self):
        converted = _convert_value(False)
        assert isinstance(converted, celpy.celtypes.BoolType)
        assert bool(converted) is False

    @pytest.mark.unit
    def test_int(self):
        converted = _convert_value(42)
        assert isinstance(converted, celpy.celtypes.IntType)
        assert int(converted) == 42

    @pytest.mark.unit
    def test_float(self):
        converted = _convert_value(3.14)
        assert isinstance(converted, celpy.celtypes.DoubleType)
        assert float(converted) == pytest.approx(3.14)

    @pytest.mark.unit
    def test_string(self):
        converted = _convert_value("hello")
        assert isinstance(converted, celpy.celtypes.StringType)
        assert str(converted) == "hello"

    @pytest.mark.unit
    def test_list(self):
        converted = _convert_value([1, "two", 3.0])
        assert isinstance(converted, celpy.celtypes.ListType)

    @pytest.mark.unit
    def test_dict(self):
        converted = _convert_value({"key": "value"})
        assert isinstance(converted, celpy.celtypes.MapType)

    @pytest.mark.unit
    def test_none(self):
        # None is mapped to CEL false rather than a dedicated null type.
        converted = _convert_value(None)
        assert isinstance(converted, celpy.celtypes.BoolType)
        assert bool(converted) is False

    @pytest.mark.unit
    def test_other_type_converts_to_string(self):
        # Unsupported Python objects fall back to their string representation.
        converted = _convert_value(object())
        assert isinstance(converted, celpy.celtypes.StringType)
|
||||
|
||||
|
||||
class TestBuildActivation:
    """Unit tests for build_activation (state dict -> CEL activation)."""

    @pytest.mark.unit
    def test_converts_dict_values(self):
        state = {"name": "Alice", "age": 30, "active": True}
        activation = build_activation(state)
        # Every key of the source state must survive the conversion.
        for key in ("name", "age", "active"):
            assert key in activation

    @pytest.mark.unit
    def test_empty_state(self):
        assert build_activation({}) == {}
|
||||
|
||||
|
||||
class TestEvaluateCel:
    """Unit tests for evaluate_cel: happy paths and error handling."""

    @pytest.mark.unit
    def test_simple_comparison(self):
        assert evaluate_cel("x > 5", {"x": 10}) is True
        assert evaluate_cel("x > 5", {"x": 3}) is False

    @pytest.mark.unit
    def test_string_comparison(self):
        assert evaluate_cel('name == "Alice"', {"name": "Alice"}) is True
        assert evaluate_cel('name == "Alice"', {"name": "Bob"}) is False

    @pytest.mark.unit
    def test_arithmetic(self):
        assert evaluate_cel("x + y", {"x": 3, "y": 4}) == 7

    @pytest.mark.unit
    def test_boolean_logic(self):
        assert evaluate_cel("a && b", {"a": True, "b": True}) is True
        assert evaluate_cel("a && b", {"a": True, "b": False}) is False
        assert evaluate_cel("a || b", {"a": False, "b": True}) is True

    @pytest.mark.unit
    def test_empty_expression_raises(self):
        with pytest.raises(CelEvaluationError, match="Empty expression"):
            evaluate_cel("", {})

    @pytest.mark.unit
    def test_whitespace_expression_raises(self):
        # Whitespace-only input is treated the same as an empty expression.
        with pytest.raises(CelEvaluationError, match="Empty expression"):
            evaluate_cel("   ", {})

    @pytest.mark.unit
    def test_invalid_expression_raises(self):
        with pytest.raises(CelEvaluationError):
            evaluate_cel("invalid!!!", {})

    @pytest.mark.unit
    def test_missing_variable_raises(self):
        with pytest.raises(CelEvaluationError):
            evaluate_cel("undefined_var > 5", {})
|
||||
|
||||
|
||||
class TestCelToPython:
    """Unit tests for cel_to_python (CEL types -> plain Python values)."""

    @pytest.mark.unit
    def test_bool(self):
        assert cel_to_python(celpy.celtypes.BoolType(True)) is True

    @pytest.mark.unit
    def test_int(self):
        assert cel_to_python(celpy.celtypes.IntType(42)) == 42

    @pytest.mark.unit
    def test_double(self):
        assert cel_to_python(celpy.celtypes.DoubleType(3.14)) == pytest.approx(3.14)

    @pytest.mark.unit
    def test_string(self):
        assert cel_to_python(celpy.celtypes.StringType("hello")) == "hello"

    @pytest.mark.unit
    def test_list(self):
        source = celpy.celtypes.ListType([
            celpy.celtypes.IntType(1),
            celpy.celtypes.IntType(2),
        ])
        assert cel_to_python(source) == [1, 2]

    @pytest.mark.unit
    def test_map(self):
        source = celpy.celtypes.MapType({
            celpy.celtypes.StringType("key"): celpy.celtypes.StringType("value"),
        })
        assert cel_to_python(source) == {"key": "value"}

    @pytest.mark.unit
    def test_unknown_type_passthrough(self):
        # Values that are not CEL wrappers are returned unchanged.
        assert cel_to_python("raw_value") == "raw_value"
|
||||
433
tests/agents/test_workflow_agent_coverage.py
Normal file
433
tests/agents/test_workflow_agent_coverage.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Tests for WorkflowAgent - covering _parse_embedded_workflow, _load_from_database,
|
||||
_save_workflow_run, _determine_run_status, _serialize_state, and gen flow."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
WorkflowGraph,
|
||||
)
|
||||
|
||||
|
||||
def _make_agent(**overrides):
    """Create a WorkflowAgent with mocked base-class dependencies.

    Keyword overrides are merged on top of a complete set of constructor
    defaults, so tests only spell out what they care about.
    """
    kwargs = {
        "endpoint": "https://api.example.com",
        "llm_name": "openai",
        "model_id": "gpt-4",
        "api_key": "test_key",
        "user_api_key": None,
        "prompt": "You are helpful.",
        "chat_history": [],
        "decoded_token": {"sub": "user1"},
        "attachments": [],
        "json_schema": None,
        **overrides,
    }

    # Neutralize the log_activity decorator before the class is imported,
    # otherwise instantiation would trigger real activity logging.
    with patch("application.agents.workflow_agent.log_activity", lambda **kw: lambda f: f):
        from application.agents.workflow_agent import WorkflowAgent

        return WorkflowAgent(**kwargs)
|
||||
|
||||
|
||||
class TestWorkflowAgentInit:
    """Unit tests for WorkflowAgent construction."""

    @pytest.mark.unit
    def test_sets_attributes(self):
        agent = _make_agent(workflow_id="wf1", workflow_owner="owner1")
        assert agent.workflow_id == "wf1"
        assert agent.workflow_owner == "owner1"
        # The engine is created lazily, not in __init__.
        assert agent._engine is None

    @pytest.mark.unit
    def test_embedded_workflow(self):
        embedded = {"nodes": [], "edges": [], "name": "Test"}
        agent = _make_agent(workflow=embedded)
        assert agent._workflow_data == embedded
|
||||
|
||||
|
||||
class TestParseEmbeddedWorkflow:
    """Unit tests for WorkflowAgent._parse_embedded_workflow."""

    @pytest.mark.unit
    def test_parses_valid_workflow(self):
        embedded = {
            "name": "Test Workflow",
            "description": "A test",
            "nodes": [
                {"id": "n1", "type": "start", "title": "Start", "data": {}, "position": {"x": 0, "y": 0}},
                {"id": "n2", "type": "end", "title": "End", "data": {}, "position": {"x": 100, "y": 0}},
            ],
            "edges": [
                {"id": "e1", "source": "n1", "target": "n2", "sourceHandle": "out", "targetHandle": "in"},
            ],
        }
        agent = _make_agent(workflow=embedded, workflow_id="wf1")
        graph = agent._parse_embedded_workflow()
        assert graph is not None
        assert len(graph.nodes) == 2
        assert len(graph.edges) == 1
        assert graph.workflow.name == "Test Workflow"

    @pytest.mark.unit
    def test_edge_source_id_alias(self):
        # Edges may arrive with snake_case alias fields (source_id/target_id).
        embedded = {
            "nodes": [{"id": "n1", "type": "start", "data": {}}],
            "edges": [{"id": "e1", "source_id": "n1", "target_id": "n2", "source_handle": "out", "target_handle": "in"}],
        }
        agent = _make_agent(workflow=embedded)
        graph = agent._parse_embedded_workflow()
        assert graph is not None
        assert graph.edges[0].source_id == "n1"

    @pytest.mark.unit
    def test_invalid_data_returns_none(self):
        # Malformed node dicts should not raise; parsing reports failure as None.
        agent = _make_agent(workflow={"nodes": [{"bad": "data"}], "edges": []})
        assert agent._parse_embedded_workflow() is None
|
||||
|
||||
|
||||
class TestLoadWorkflowGraph:
    """Unit tests for WorkflowAgent._load_workflow_graph source selection."""

    @pytest.mark.unit
    def test_uses_embedded_when_available(self):
        # An embedded workflow takes priority over any database lookup.
        agent = _make_agent(workflow={"nodes": [], "edges": [], "name": "E"})
        agent._parse_embedded_workflow = MagicMock(return_value="parsed_graph")
        assert agent._load_workflow_graph() == "parsed_graph"

    @pytest.mark.unit
    def test_uses_database_when_workflow_id(self):
        agent = _make_agent(workflow_id="wf1")
        agent._load_from_database = MagicMock(return_value="db_graph")
        assert agent._load_workflow_graph() == "db_graph"

    @pytest.mark.unit
    def test_returns_none_when_nothing(self):
        # No embedded workflow and no workflow_id -> nothing to load.
        agent = _make_agent()
        assert agent._load_workflow_graph() is None
|
||||
|
||||
|
||||
class TestLoadFromDatabase:
    """Unit tests for WorkflowAgent._load_from_database."""

    @staticmethod
    def _db_for(collections):
        """Build a MagicMock database whose __getitem__ serves *collections*."""
        db = MagicMock()
        db.__getitem__ = MagicMock(side_effect=lambda name: collections[name])
        return db

    @pytest.mark.unit
    def test_invalid_workflow_id_returns_none(self):
        # A non-ObjectId workflow id is rejected before any database access.
        agent = _make_agent(workflow_id="invalid!")
        assert agent._load_from_database() is None

    @pytest.mark.unit
    def test_no_owner_returns_none(self):
        agent = _make_agent(workflow_id="507f1f77bcf86cd799439011", decoded_token={})
        agent.workflow_owner = None
        assert agent._load_from_database() is None

    @pytest.mark.unit
    def test_uses_decoded_token_sub(self):
        # With no explicit owner, the lookup falls back to decoded_token["sub"].
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            decoded_token={"sub": "user1"},
        )
        agent.workflow_owner = None

        empty_coll = MagicMock()
        empty_coll.find_one.return_value = None
        db = MagicMock()
        db.__getitem__ = MagicMock(return_value=empty_coll)

        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": db}
            # The workflow document is absent, so the load yields None.
            assert agent._load_from_database() is None

    @pytest.mark.unit
    def test_successful_load(self):
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )

        wf_coll = MagicMock()
        wf_coll.find_one.return_value = {
            "_id": "507f1f77bcf86cd799439011",
            "name": "Test WF",
            "user": "owner1",
            "current_graph_version": 1,
        }
        nodes_coll = MagicMock()
        nodes_coll.find.return_value = [
            {"id": "n1", "workflow_id": "507f1f77bcf86cd799439011", "type": "start",
             "title": "Start", "position": {"x": 0, "y": 0}, "config": {}},
        ]
        edges_coll = MagicMock()
        edges_coll.find.return_value = []

        db = self._db_for({
            "workflows": wf_coll,
            "workflow_nodes": nodes_coll,
            "workflow_edges": edges_coll,
        })

        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": db}
            graph = agent._load_from_database()

        assert graph is not None
        assert len(graph.nodes) == 1

    @pytest.mark.unit
    def test_invalid_graph_version(self):
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )

        wf_coll = MagicMock()
        wf_coll.find_one.return_value = {
            "_id": "507f1f77bcf86cd799439011",
            "name": "WF",
            "user": "owner1",
            "current_graph_version": "bad",  # non-numeric version
        }
        nodes_coll = MagicMock()
        nodes_coll.find.return_value = []
        edges_coll = MagicMock()
        edges_coll.find.return_value = []

        db = self._db_for({
            "workflows": wf_coll,
            "workflow_nodes": nodes_coll,
            "workflow_edges": edges_coll,
        })

        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": db}
            graph = agent._load_from_database()

        # An unparseable version defaults to 1 instead of failing the load.
        assert graph is not None

    @pytest.mark.unit
    def test_fallback_nodes_without_version(self):
        """When graph_version=1 finds no nodes, falls back to nodes without version field."""
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )

        wf_coll = MagicMock()
        wf_coll.find_one.return_value = {
            "_id": "507f1f77bcf86cd799439011",
            "name": "WF",
            "user": "owner1",
            "current_graph_version": 1,
        }

        # First find() (versioned query) is empty; the retry returns one node.
        nodes_coll = MagicMock()
        nodes_coll.find.side_effect = [
            [],
            [{"id": "n1", "workflow_id": "wf", "type": "start",
              "title": "S", "position": {"x": 0, "y": 0}, "config": {}}],
        ]
        edges_coll = MagicMock()
        edges_coll.find.side_effect = [[], []]

        db = self._db_for({
            "workflows": wf_coll,
            "workflow_nodes": nodes_coll,
            "workflow_edges": edges_coll,
        })

        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": db}
            graph = agent._load_from_database()

        assert graph is not None
        assert len(graph.nodes) == 1

    @pytest.mark.unit
    def test_exception_returns_none(self):
        # Database failures are swallowed and reported as "no graph".
        agent = _make_agent(
            workflow_id="507f1f77bcf86cd799439011",
            workflow_owner="owner1",
        )
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
            MockMongo.get_client.side_effect = Exception("db error")
            assert agent._load_from_database() is None
|
||||
|
||||
|
||||
class TestGenInner:
    """Unit tests for the WorkflowAgent._gen_inner generator."""

    @pytest.mark.unit
    def test_no_graph_yields_error(self):
        agent = _make_agent()
        agent._load_workflow_graph = MagicMock(return_value=None)
        emitted = list(agent._gen_inner("query", None))
        assert any(event.get("type") == "error" for event in emitted)

    @pytest.mark.unit
    def test_successful_execution(self):
        agent = _make_agent(workflow_id="wf1")
        graph = MagicMock(spec=WorkflowGraph)
        agent._load_workflow_graph = MagicMock(return_value=graph)
        agent._save_workflow_run = MagicMock()

        engine = MagicMock()
        engine.execute.return_value = iter([{"answer": "result"}])

        with patch("application.agents.workflow_agent.WorkflowEngine", return_value=engine):
            emitted = list(agent._gen_inner("query", None))

        # Engine events pass straight through, then the run is persisted.
        assert len(emitted) == 1
        agent._save_workflow_run.assert_called_once_with("query")
|
||||
|
||||
|
||||
class TestSaveWorkflowRun:
    """Unit tests for WorkflowAgent._save_workflow_run persistence."""

    @pytest.mark.unit
    def test_no_engine_returns_early(self):
        agent = _make_agent()
        agent._engine = None
        # Nothing to persist; the call must be a no-op rather than raise.
        agent._save_workflow_run("query")

    @pytest.mark.unit
    def test_saves_to_mongo(self):
        agent = _make_agent(workflow_id="wf1")
        engine = MagicMock()
        engine.state = {"query": "test"}
        engine.execution_log = []
        engine.get_execution_summary.return_value = []
        agent._engine = engine

        runs_coll = MagicMock()
        db = MagicMock()
        db.__getitem__ = MagicMock(return_value=runs_coll)

        with patch("application.agents.workflow_agent.MongoDB") as MockMongo, \
                patch("application.agents.workflow_agent.settings") as mock_settings:
            mock_settings.MONGO_DB_NAME = "test_db"
            MockMongo.get_client.return_value = {"test_db": db}
            agent._save_workflow_run("query")

        runs_coll.insert_one.assert_called_once()

    @pytest.mark.unit
    def test_exception_does_not_propagate(self):
        agent = _make_agent(workflow_id="wf1")
        engine = MagicMock()
        engine.state = {}
        engine.execution_log = []
        engine.get_execution_summary.return_value = []
        agent._engine = engine

        # Persistence is best-effort: a database failure must not bubble up.
        with patch("application.agents.workflow_agent.MongoDB") as MockMongo:
            MockMongo.get_client.side_effect = Exception("db fail")
            agent._save_workflow_run("query")
|
||||
|
||||
|
||||
class TestDetermineRunStatus:
    """Unit tests for WorkflowAgent._determine_run_status."""

    @pytest.mark.unit
    def test_no_engine_returns_completed(self):
        agent = _make_agent()
        agent._engine = None
        assert agent._determine_run_status() == ExecutionStatus.COMPLETED

    @pytest.mark.unit
    def test_empty_log_returns_completed(self):
        agent = _make_agent()
        agent._engine = MagicMock()
        agent._engine.execution_log = []
        assert agent._determine_run_status() == ExecutionStatus.COMPLETED

    @pytest.mark.unit
    def test_failed_log_returns_failed(self):
        # One failed step is enough to mark the whole run as failed.
        agent = _make_agent()
        agent._engine = MagicMock()
        agent._engine.execution_log = [
            {"status": "completed"},
            {"status": "failed"},
        ]
        assert agent._determine_run_status() == ExecutionStatus.FAILED

    @pytest.mark.unit
    def test_all_completed_returns_completed(self):
        agent = _make_agent()
        agent._engine = MagicMock()
        agent._engine.execution_log = [
            {"status": "completed"},
            {"status": "completed"},
        ]
        assert agent._determine_run_status() == ExecutionStatus.COMPLETED
|
||||
|
||||
|
||||
class TestSerializeState:
    """Unit tests for WorkflowAgent._serialize_state conversions."""

    @pytest.mark.unit
    def test_serializes_primitives(self):
        agent = _make_agent()
        state = {"str": "hello", "int": 42, "float": 3.14, "bool": True, "none": None}
        # JSON-native primitives round-trip unchanged.
        assert agent._serialize_state(state) == state

    @pytest.mark.unit
    def test_serializes_nested_dict(self):
        agent = _make_agent()
        serialized = agent._serialize_state({"nested": {"key": "value"}})
        assert serialized["nested"]["key"] == "value"

    @pytest.mark.unit
    def test_serializes_list(self):
        agent = _make_agent()
        serialized = agent._serialize_state({"items": [1, 2, "three"]})
        assert serialized["items"] == [1, 2, "three"]

    @pytest.mark.unit
    def test_serializes_tuple(self):
        agent = _make_agent()
        # Tuples become lists so the result is JSON-serializable.
        serialized = agent._serialize_state({"tup": (1, 2)})
        assert serialized["tup"] == [1, 2]

    @pytest.mark.unit
    def test_serializes_datetime(self):
        agent = _make_agent()
        moment = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        serialized = agent._serialize_state({"time": moment})
        # Datetimes serialize to an ISO-style string containing the date.
        assert "2025-01-01" in serialized["time"]

    @pytest.mark.unit
    def test_serializes_unknown_to_str(self):
        agent = _make_agent()
        serialized = agent._serialize_state({"obj": object()})
        assert isinstance(serialized["obj"], str)
|
||||
573
tests/agents/test_workflow_engine_coverage.py
Normal file
573
tests/agents/test_workflow_engine_coverage.py
Normal file
@@ -0,0 +1,573 @@
|
||||
"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes,
|
||||
template context, source data, structured output parsing, get_execution_summary."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
NodeType,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
Workflow,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
|
||||
|
||||
def _make_graph(nodes, edges):
    """Wrap *nodes* and *edges* in a WorkflowGraph around a stub Workflow."""
    workflow = Workflow(name="Test", description="test workflow")
    return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
def _make_node(id, type, title="Node", config=None, position=None):
    """Build a WorkflowNode with test-friendly defaults for position/config."""
    # Falsy position/config fall back to defaults (matches original `or` logic).
    return WorkflowNode(
        id=id,
        workflow_id="wf1",
        type=type,
        title=title,
        position=position or {"x": 0, "y": 0},
        config=config or {},
    )
|
||||
|
||||
|
||||
def _make_edge(id, source, target, source_handle=None, target_handle=None):
    """Build a WorkflowEdge between *source* and *target* node ids."""
    return WorkflowEdge(
        id=id,
        workflow_id="wf1",
        source=source,
        target=target,
        sourceHandle=source_handle,
        targetHandle=target_handle,
    )
|
||||
|
||||
|
||||
def _make_agent():
    """Return a MagicMock standing in for the agent driven by the engine."""
    agent = MagicMock()
    for attr, value in {
        "chat_history": [],
        "endpoint": "https://api.example.com",
        "llm_name": "openai",
        "model_id": "gpt-4",
        "api_key": "key",
        "decoded_token": {"sub": "user1"},
        "retrieved_docs": None,
    }.items():
        setattr(agent, attr, value)
    return agent
|
||||
|
||||
|
||||
class TestExecuteLoop:
    """Unit tests for the WorkflowEngine.execute main loop."""

    @pytest.mark.unit
    def test_no_start_node_yields_error(self):
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        emitted = list(engine.execute({}, "query"))
        assert any(
            e.get("type") == "error" and "start node" in e.get("error", "")
            for e in emitted
        )

    @pytest.mark.unit
    def test_start_to_end(self):
        nodes = [
            _make_node("n1", NodeType.START, "Start"),
            _make_node("n2", NodeType.END, "End", config={"config": {}}),
        ]
        graph = _make_graph(nodes, [_make_edge("e1", "n1", "n2")])
        engine = WorkflowEngine(graph, _make_agent())
        emitted = list(engine.execute({}, "hello"))
        steps = [e for e in emitted if e.get("type") == "workflow_step"]
        # At minimum one step event each for the start and end nodes.
        assert len(steps) >= 2

    @pytest.mark.unit
    def test_node_not_found_yields_error(self):
        # An edge pointing at a missing node surfaces as an error event.
        graph = _make_graph(
            [_make_node("n1", NodeType.START)],
            [_make_edge("e1", "n1", "nonexistent")],
        )
        engine = WorkflowEngine(graph, _make_agent())
        emitted = list(engine.execute({}, "q"))
        assert any("not found" in e.get("error", "") for e in emitted)

    @pytest.mark.unit
    def test_node_execution_error_yields_error(self):
        # A state node with an unparseable CEL expression fails its step.
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.STATE, "State", config={"config": {"operations": [{"expression": "bad!!!", "target_variable": "x"}]}}),
        ]
        graph = _make_graph(nodes, [_make_edge("e1", "n1", "n2")])
        engine = WorkflowEngine(graph, _make_agent())
        emitted = list(engine.execute({}, "q"))
        assert len([e for e in emitted if e.get("status") == "failed"]) >= 1

    @pytest.mark.unit
    def test_max_steps_limit(self):
        # n2 loops back to itself; the step cap must stop the traversal.
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.NOTE, "Note"),
        ]
        edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        engine.MAX_EXECUTION_STEPS = 5
        emitted = list(engine.execute({}, "q"))
        # The call returns (does not run forever) and emits something.
        assert len(emitted) > 0

    @pytest.mark.unit
    def test_branch_ends_without_end_node(self):
        # A branch may simply run out of outgoing edges without an END node.
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.NOTE, "Note"),
        ]
        engine = WorkflowEngine(
            _make_graph(nodes, [_make_edge("e1", "n1", "n2")]),
            _make_agent(),
        )
        emitted = list(engine.execute({}, "q"))
        assert len(emitted) > 0
|
||||
|
||||
|
||||
class TestInitializeState:
    """Unit tests for WorkflowEngine._initialize_state."""

    @pytest.mark.unit
    def test_sets_query_and_history(self):
        agent = _make_agent()
        agent.chat_history = [{"prompt": "hi", "response": "hey"}]
        engine = WorkflowEngine(_make_graph([], []), agent)
        engine._initialize_state({"custom": "value"}, "test query")
        assert engine.state["query"] == "test query"
        # Caller-provided initial state is merged in alongside the query.
        assert "custom" in engine.state
        assert engine.state["chat_history"] is not None
|
||||
|
||||
|
||||
class TestGetNextNodeId:
    """Unit tests for WorkflowEngine._get_next_node_id edge traversal."""

    @pytest.mark.unit
    def test_no_edges_returns_none(self):
        engine = WorkflowEngine(
            _make_graph([_make_node("n1", NodeType.START)], []),
            _make_agent(),
        )
        assert engine._get_next_node_id("n1") is None

    @pytest.mark.unit
    def test_returns_first_edge_target(self):
        nodes = [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)]
        engine = WorkflowEngine(
            _make_graph(nodes, [_make_edge("e1", "n1", "n2")]),
            _make_agent(),
        )
        assert engine._get_next_node_id("n1") == "n2"

    @pytest.mark.unit
    def test_condition_uses_matched_handle(self):
        nodes = [
            _make_node("n1", NodeType.CONDITION),
            _make_node("n2", NodeType.END, "Yes End"),
            _make_node("n3", NodeType.END, "No End"),
        ]
        edges = [
            _make_edge("e1", "n1", "n2", source_handle="yes"),
            _make_edge("e2", "n1", "n3", source_handle="no"),
        ]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        engine._condition_result = "no"
        assert engine._get_next_node_id("n1") == "n3"
        # The stored result is consumed once a branch has been chosen.
        assert engine._condition_result is None

    @pytest.mark.unit
    def test_condition_no_matching_handle_returns_none(self):
        engine = WorkflowEngine(
            _make_graph(
                [_make_node("n1", NodeType.CONDITION)],
                [_make_edge("e1", "n1", "n2", source_handle="yes")],
            ),
            _make_agent(),
        )
        engine._condition_result = "nonexistent"
        assert engine._get_next_node_id("n1") is None
|
||||
|
||||
|
||||
class TestExecuteStateNode:
    """Unit tests for WorkflowEngine._execute_state_node."""

    @pytest.mark.unit
    def test_evaluates_operations(self):
        node = _make_node("n1", NodeType.STATE, config={
            "config": {
                "operations": [
                    {"expression": "x + 1", "target_variable": "result"},
                ]
            }
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"x": 5}
        # The node is a generator; drain it so the operation runs.
        list(engine._execute_state_node(node))
        assert engine.state["result"] == 6

    @pytest.mark.unit
    def test_skips_empty_expression(self):
        node = _make_node("n1", NodeType.STATE, config={
            "config": {
                "operations": [
                    {"expression": "", "target_variable": "result"},
                ]
            }
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {}
        list(engine._execute_state_node(node))
        # Empty expressions write nothing into the state.
        assert "result" not in engine.state
|
||||
|
||||
|
||||
class TestExecuteConditionNode:
    """Unit tests for WorkflowEngine._execute_condition_node case selection."""

    @pytest.mark.unit
    def test_matches_first_true_case(self):
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": "x > 10", "source_handle": "high"},
                    {"expression": "x > 5", "source_handle": "medium"},
                ]
            }
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"x": 7}
        list(engine._execute_condition_node(node))
        assert engine._condition_result == "medium"

    @pytest.mark.unit
    def test_falls_through_to_else(self):
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": "x > 100", "source_handle": "high"},
                ]
            }
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"x": 1}
        list(engine._execute_condition_node(node))
        # No case matched, so the implicit "else" branch is selected.
        assert engine._condition_result == "else"

    @pytest.mark.unit
    def test_skips_empty_expression(self):
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": "   ", "source_handle": "a"},
                    {"expression": "true", "source_handle": "b"},
                ]
            }
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {}
        list(engine._execute_condition_node(node))
        # Blank expressions are skipped; evaluation moves to the next case.
        assert engine._condition_result == "b"

    @pytest.mark.unit
    def test_cel_error_continues(self):
        node = _make_node("n1", NodeType.CONDITION, config={
            "config": {
                "cases": [
                    {"expression": "bad!!!", "source_handle": "a"},
                    {"expression": "true", "source_handle": "b"},
                ]
            }
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {}
        list(engine._execute_condition_node(node))
        # A CEL evaluation error in one case does not abort the node.
        assert engine._condition_result == "b"
|
||||
|
||||
|
||||
class TestExecuteEndNode:
    """Unit tests for WorkflowEngine._execute_end_node output emission."""

    @pytest.mark.unit
    def test_with_output_template(self):
        node = _make_node("n1", NodeType.END, config={
            "config": {"output_template": "Result: {{ query }}"}
        })
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"query": "hello"}
        engine._format_template = MagicMock(return_value="Result: hello")
        emitted = list(engine._execute_end_node(node))
        assert len(emitted) == 1
        assert emitted[0]["answer"] == "Result: hello"

    @pytest.mark.unit
    def test_without_output_template(self):
        # No template configured -> the end node emits no answer event.
        node = _make_node("n1", NodeType.END, config={"config": {}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        assert list(engine._execute_end_node(node)) == []
|
||||
|
||||
|
||||
class TestParseStructuredOutput:
    """Tests for WorkflowEngine._parse_structured_output."""

    @staticmethod
    def _parse(raw):
        # Fresh engine per call; the parser holds no relevant state.
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        return engine._parse_structured_output(raw)

    @pytest.mark.unit
    def test_valid_json(self):
        ok, parsed = self._parse('{"key": "value"}')
        assert ok is True
        assert parsed == {"key": "value"}

    @pytest.mark.unit
    def test_invalid_json(self):
        ok, parsed = self._parse("not json")
        assert ok is False
        assert parsed is None

    @pytest.mark.unit
    def test_empty_string(self):
        ok, parsed = self._parse("")
        assert ok is False
        assert parsed is None

    @pytest.mark.unit
    def test_whitespace_only(self):
        ok, parsed = self._parse(" ")
        assert ok is False
        assert parsed is None
|
||||
|
||||
|
||||
class TestNormalizeNodeJsonSchema:
    """Tests for WorkflowEngine._normalize_node_json_schema."""

    @staticmethod
    def _engine():
        return WorkflowEngine(_make_graph([], []), _make_agent())

    @pytest.mark.unit
    def test_none_returns_none(self):
        assert self._engine()._normalize_node_json_schema(None, "Node") is None

    @pytest.mark.unit
    def test_valid_schema(self):
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        normalized = self._engine()._normalize_node_json_schema(schema, "Node")
        assert normalized is not None

    @pytest.mark.unit
    def test_invalid_schema_raises(self):
        """A JsonSchemaValidationError from normalization surfaces as ValueError."""
        from application.core.json_schema_utils import JsonSchemaValidationError

        engine = self._engine()
        with patch("application.agents.workflows.workflow_engine.normalize_json_schema_payload") as mock_norm:
            mock_norm.side_effect = JsonSchemaValidationError("bad schema")
            with pytest.raises(ValueError, match="Invalid JSON schema"):
                engine._normalize_node_json_schema({"bad": True}, "TestNode")
|
||||
|
||||
|
||||
class TestValidateStructuredOutput:
    """Tests for WorkflowEngine._validate_structured_output."""

    @staticmethod
    def _engine():
        return WorkflowEngine(_make_graph([], []), _make_agent())

    @pytest.mark.unit
    def test_valid_output_passes(self):
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        # Should not raise.
        self._engine()._validate_structured_output(schema, {"name": "Alice"})

    @pytest.mark.unit
    def test_invalid_output_raises(self):
        schema = {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
        with pytest.raises(ValueError, match="did not match schema"):
            self._engine()._validate_structured_output(schema, {})

    @pytest.mark.unit
    def test_no_jsonschema_module(self):
        """Validation is a silent no-op when the jsonschema module is absent."""
        engine = self._engine()
        with patch("application.agents.workflows.workflow_engine.jsonschema", None):
            # Should not raise.
            engine._validate_structured_output({"type": "object"}, {})
|
||||
|
||||
|
||||
class TestFormatTemplate:
    """Tests for WorkflowEngine._format_template."""

    @pytest.mark.unit
    def test_renders_template(self):
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        engine.state = {"query": "hello"}
        engine._build_template_context = MagicMock(return_value={"query": "hello"})
        engine._template_engine = MagicMock()
        engine._template_engine.render.return_value = "hello world"
        rendered = engine._format_template("{{ query }} world")
        assert rendered == "hello world"

    @pytest.mark.unit
    def test_render_error_returns_raw(self):
        """On TemplateRenderError the raw template string is returned unchanged."""
        from application.templates.template_engine import TemplateRenderError

        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        engine._build_template_context = MagicMock(return_value={})
        engine._template_engine = MagicMock()
        engine._template_engine.render.side_effect = TemplateRenderError("fail")
        rendered = engine._format_template("{{ bad }}")
        assert rendered == "{{ bad }}"
|
||||
|
||||
|
||||
class TestBuildTemplateContext:
    """Tests for WorkflowEngine._build_template_context."""

    @staticmethod
    def _context_for(state):
        # Build an engine with no retrieved docs and the given state,
        # then return the template context it produces.
        agent = _make_agent()
        agent.retrieved_docs = None
        engine = WorkflowEngine(_make_graph([], []), agent)
        engine.state = state
        return engine._build_template_context()

    @pytest.mark.unit
    def test_includes_state_variables(self):
        context = self._context_for({"query": "hello", "custom_var": "value"})
        assert context["agent"]["query"] == "hello"
        assert "custom_var" in context

    @pytest.mark.unit
    def test_reserved_namespace_gets_prefixed(self):
        """State keys that collide with reserved names get an agent_ prefix."""
        context = self._context_for({"source": "my_source_val"})
        assert context.get("agent_source") == "my_source_val"

    @pytest.mark.unit
    def test_passthrough_data(self):
        context = self._context_for({"passthrough": {"key": "val"}})
        assert "passthrough" in context or "agent_passthrough" in context

    @pytest.mark.unit
    def test_tools_data(self):
        context = self._context_for({"tools": {"tool1": "result"}})
        assert "agent" in context
|
||||
|
||||
|
||||
class TestGetSourceTemplateData:
    """Tests for WorkflowEngine._get_source_template_data."""

    @staticmethod
    def _data_for(retrieved_docs):
        # Return (docs, joined_text) for an engine whose agent holds
        # the given retrieved_docs.
        agent = _make_agent()
        agent.retrieved_docs = retrieved_docs
        engine = WorkflowEngine(_make_graph([], []), agent)
        return engine._get_source_template_data()

    @pytest.mark.unit
    def test_no_docs_returns_none(self):
        docs, together = self._data_for(None)
        assert docs is None
        assert together is None

    @pytest.mark.unit
    def test_empty_docs_returns_none(self):
        docs, together = self._data_for([])
        assert docs is None

    @pytest.mark.unit
    def test_docs_with_filename(self):
        docs, together = self._data_for([{"text": "content", "filename": "doc.txt"}])
        assert docs is not None
        assert "doc.txt" in together
        assert "content" in together

    @pytest.mark.unit
    def test_docs_without_filename(self):
        docs, together = self._data_for([{"text": "content only"}])
        assert together == "content only"

    @pytest.mark.unit
    def test_skips_non_dict_docs(self):
        docs, together = self._data_for(["not a dict", {"text": "ok"}])
        assert together == "ok"

    @pytest.mark.unit
    def test_skips_non_string_text(self):
        docs, together = self._data_for([{"text": 123}])
        assert together is None

    @pytest.mark.unit
    def test_doc_with_title_fallback(self):
        docs, together = self._data_for([{"text": "content", "title": "doc_title"}])
        assert "doc_title" in together

    @pytest.mark.unit
    def test_doc_with_source_fallback(self):
        docs, together = self._data_for([{"text": "content", "source": "src"}])
        assert "src" in together
|
||||
|
||||
|
||||
class TestGetExecutionSummary:
    """Tests for WorkflowEngine.get_execution_summary."""

    @pytest.mark.unit
    def test_returns_log_entries(self):
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        timestamp = datetime.now(timezone.utc)
        engine.execution_log = [{
            "node_id": "n1",
            "node_type": "start",
            "status": "completed",
            "started_at": timestamp,
            "completed_at": timestamp,
            "error": None,
            "state_snapshot": {},
        }]
        summary = engine.get_execution_summary()
        assert len(summary) == 1
        assert summary[0].node_id == "n1"
        assert summary[0].status == ExecutionStatus.COMPLETED

    @pytest.mark.unit
    def test_empty_log(self):
        """An untouched engine has no execution summary entries."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        assert engine.get_execution_summary() == []
|
||||
331
tests/api/answer/test_stream_processor.py
Normal file
331
tests/api/answer/test_stream_processor.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Tests for application/api/answer/services/stream_processor.py — get_prompt and helpers."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.answer.services.stream_processor import get_prompt
|
||||
|
||||
|
||||
class TestGetPrompt:
    """Tests for get_prompt: built-in presets and MongoDB-backed prompts."""

    @staticmethod
    def _assert_preset(name):
        # Every preset must resolve to a string; return it for extra checks.
        prompt = get_prompt(name)
        assert isinstance(prompt, str)
        return prompt

    @pytest.mark.unit
    def test_default_preset(self):
        assert len(self._assert_preset("default")) > 0

    @pytest.mark.unit
    def test_creative_preset(self):
        self._assert_preset("creative")

    @pytest.mark.unit
    def test_strict_preset(self):
        self._assert_preset("strict")

    @pytest.mark.unit
    def test_reduce_preset(self):
        self._assert_preset("reduce")

    @pytest.mark.unit
    def test_agentic_default_preset(self):
        self._assert_preset("agentic_default")

    @pytest.mark.unit
    def test_agentic_creative_preset(self):
        self._assert_preset("agentic_creative")

    @pytest.mark.unit
    def test_agentic_strict_preset(self):
        self._assert_preset("agentic_strict")

    @pytest.mark.unit
    def test_mongo_prompt_by_id(self):
        collection = MagicMock()
        collection.find_one.return_value = {"_id": "abc", "content": "Custom prompt"}
        prompt = get_prompt("507f1f77bcf86cd799439011", prompts_collection=collection)
        assert prompt == "Custom prompt"

    @pytest.mark.unit
    def test_mongo_prompt_not_found_raises(self):
        collection = MagicMock()
        collection.find_one.return_value = None
        with pytest.raises(ValueError, match="Invalid prompt ID"):
            get_prompt("507f1f77bcf86cd799439011", prompts_collection=collection)

    @pytest.mark.unit
    def test_invalid_id_raises(self):
        collection = MagicMock()
        collection.find_one.side_effect = Exception("bad id")
        with pytest.raises(ValueError, match="Invalid prompt ID"):
            get_prompt("not-an-objectid", prompts_collection=collection)

    @pytest.mark.unit
    def test_mongo_fallback_when_no_collection(self):
        """When no collection passed, it reads from MongoDB."""
        collection = MagicMock()
        collection.find_one.return_value = {"content": "From DB"}
        db = MagicMock()
        db.__getitem__ = MagicMock(return_value=collection)

        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "test_db"
            mongo_cls.get_client.return_value = {"test_db": db}
            assert get_prompt("507f1f77bcf86cd799439011") == "From DB"
|
||||
|
||||
|
||||
class TestStreamProcessorInit:
    """Construction of StreamProcessor from request data and a decoded token."""

    @pytest.mark.unit
    def test_init_sets_attributes(self):
        fake_client = {"docsgpt": MagicMock()}

        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = fake_client

            # Import under the patch so module-level handles use the fakes.
            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(
                request_data={"conversation_id": "conv1", "agent_id": "a1"},
                decoded_token={"sub": "user1"},
            )
            assert processor.conversation_id == "conv1"
            assert processor.initial_user_id == "user1"
            assert processor.agent_id == "a1"
            assert processor.history == []
            assert processor.attachments == []

    @pytest.mark.unit
    def test_init_no_token(self):
        fake_client = {"docsgpt": MagicMock()}

        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = fake_client

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token=None)
            assert processor.initial_user_id is None
|
||||
|
||||
|
||||
class TestGetAttachmentsContent:
    """Tests for StreamProcessor._get_attachments_content."""

    @pytest.mark.unit
    def test_empty_ids_returns_empty(self):
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            assert processor._get_attachments_content([], "u") == []

    @pytest.mark.unit
    def test_returns_matching_attachments(self):
        attachments = MagicMock()
        attachments.find_one.return_value = {"_id": "att1", "content": "data"}
        db = MagicMock()
        db.__getitem__ = MagicMock(return_value=attachments)

        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            found = processor._get_attachments_content(["507f1f77bcf86cd799439011"], "u")
            assert len(found) == 1

    @pytest.mark.unit
    def test_invalid_attachment_id_continues(self):
        """A lookup failure for one id is swallowed; the rest still process."""
        attachments = MagicMock()
        attachments.find_one.side_effect = Exception("bad id")
        db = MagicMock()
        db.__getitem__ = MagicMock(return_value=attachments)

        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            assert processor._get_attachments_content(["bad"], "u") == []
|
||||
|
||||
|
||||
class TestResolveAgentId:
    """Tests for StreamProcessor._resolve_agent_id."""

    @pytest.mark.unit
    def test_from_request_data(self):
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(
                request_data={"agent_id": "agent_123"},
                decoded_token={"sub": "u"},
            )
            assert processor._resolve_agent_id() == "agent_123"

    @pytest.mark.unit
    def test_no_agent_no_conversation(self):
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            assert processor._resolve_agent_id() is None

    @pytest.mark.unit
    def test_from_conversation(self):
        """With no explicit agent_id, the conversation record supplies it."""
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(
                request_data={"conversation_id": "conv1"},
                decoded_token={"sub": "u"},
            )
            processor.conversation_service = MagicMock()
            processor.conversation_service.get_conversation.return_value = {"agent_id": "from_conv"}
            assert processor._resolve_agent_id() == "from_conv"

    @pytest.mark.unit
    def test_conversation_not_found(self):
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(
                request_data={"conversation_id": "conv1"},
                decoded_token={"sub": "u"},
            )
            processor.conversation_service = MagicMock()
            processor.conversation_service.get_conversation.return_value = None
            assert processor._resolve_agent_id() is None

    @pytest.mark.unit
    def test_conversation_lookup_exception(self):
        """A failing conversation lookup resolves to None, not an error."""
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(
                request_data={"conversation_id": "conv1"},
                decoded_token={"sub": "u"},
            )
            processor.conversation_service = MagicMock()
            processor.conversation_service.get_conversation.side_effect = Exception("db error")
            assert processor._resolve_agent_id() is None
|
||||
|
||||
|
||||
class TestGetPromptContent:
    """Tests for StreamProcessor._get_prompt_content."""

    @pytest.mark.unit
    def test_caches_result(self):
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {"prompt_id": "default"}
            # Two calls must yield the same (cached) non-None content.
            first = processor._get_prompt_content()
            second = processor._get_prompt_content()
            assert first == second
            assert first is not None

    @pytest.mark.unit
    def test_no_prompt_id(self):
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {}
            assert processor._get_prompt_content() is None

    @pytest.mark.unit
    def test_invalid_prompt_id_returns_none(self):
        prompts = MagicMock()
        prompts.find_one.side_effect = Exception("bad")
        db = MagicMock()
        db.__getitem__ = MagicMock(return_value=prompts)

        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {"prompt_id": "bad_id"}
            assert processor._get_prompt_content() is None
|
||||
|
||||
|
||||
class TestGetRequiredToolActions:
    """Tests for StreamProcessor._get_required_tool_actions."""

    @pytest.mark.unit
    def test_no_prompt_returns_none(self):
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {}
            assert processor._get_required_tool_actions() is None

    @pytest.mark.unit
    def test_no_template_syntax_returns_empty(self):
        """A prompt without template placeholders requires no tool actions."""
        db = MagicMock()
        with patch("application.api.answer.services.stream_processor.MongoDB") as mongo_cls, \
                patch("application.api.answer.services.stream_processor.settings") as settings_mock:
            settings_mock.MONGO_DB_NAME = "docsgpt"
            mongo_cls.get_client.return_value = {"docsgpt": db}

            from application.api.answer.services.stream_processor import StreamProcessor
            processor = StreamProcessor(request_data={}, decoded_token={"sub": "u"})
            processor.agent_config = {"prompt_id": "default"}
            processor._prompt_content = "No template syntax here"
            assert processor._get_required_tool_actions() == {}
|
||||
333
tests/api/test_connector_routes.py
Normal file
333
tests/api/test_connector_routes.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""Tests for application/api/connector/routes.py"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import mongomock
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def app():
    """Flask app in TESTING mode, with auth stubbed to a fixed test user."""
    # Patch handle_auth before importing the app so every route sees an
    # authenticated "test_user" unless a test overrides it.
    with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
        from application.app import app as flask_app
        flask_app.config["TESTING"] = True
        yield flask_app
|
||||
|
||||
|
||||
@pytest.fixture
def client(app):
    """Flask test client bound to the `app` fixture."""
    return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_sessions(monkeypatch):
    """Replace the connector routes' Mongo collections with in-memory mongomock ones.

    Returns a dict exposing both collections so tests can seed/inspect them.
    """
    mock_client = mongomock.MongoClient()
    mock_db = mock_client["docsgpt"]
    sessions = mock_db["connector_sessions"]
    sources = mock_db["sources"]
    # Swap the module-level collection handles used by the routes.
    monkeypatch.setattr("application.api.connector.routes.sessions_collection", sessions)
    monkeypatch.setattr("application.api.connector.routes.sources_collection", sources)
    return {"sessions": sessions, "sources": sources}
|
||||
|
||||
|
||||
class TestConnectorAuth:
    """Tests for GET /api/connectors/auth."""

    @pytest.mark.unit
    def test_missing_provider(self, client):
        resp = client.get("/api/connectors/auth")
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_unsupported_provider(self, client):
        resp = client.get("/api/connectors/auth?provider=dropbox")
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_unauthorized(self, client, app):
        with patch("application.app.handle_auth", return_value=None):
            resp = client.get("/api/connectors/auth?provider=google_drive")
            payload = json.loads(resp.data)
            # decoded_token is None -> 401
            assert resp.status_code == 401 or payload.get("error") == "Unauthorized"

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        with patch("application.api.connector.routes.ConnectorCreator") as creator_cls:
            creator_cls.is_supported.return_value = True
            auth = MagicMock()
            auth.get_authorization_url.return_value = "https://oauth.example.com/auth"
            creator_cls.create_auth.return_value = auth

            resp = client.get("/api/connectors/auth?provider=google_drive")
            assert resp.status_code == 200
            payload = json.loads(resp.data)
            assert payload["success"] is True
            assert "authorization_url" in payload

    @pytest.mark.unit
    def test_exception_returns_500(self, client, mock_sessions):
        with patch("application.api.connector.routes.ConnectorCreator") as creator_cls:
            creator_cls.is_supported.return_value = True
            creator_cls.create_auth.side_effect = Exception("oauth fail")
            resp = client.get("/api/connectors/auth?provider=google_drive")
            assert resp.status_code == 500
|
||||
|
||||
|
||||
class TestConnectorFiles:
    """Tests for POST /api/connectors/files."""

    @staticmethod
    def _fake_doc(extra_info):
        # Minimal stand-in for a loader document with the given metadata.
        doc = MagicMock()
        doc.doc_id = "f1"
        doc.extra_info = extra_info
        return doc

    @staticmethod
    def _fake_loader(doc):
        # Loader that yields exactly one document and no further pages.
        loader = MagicMock()
        loader.load_data.return_value = [doc]
        loader.next_page_token = None
        return loader

    @pytest.mark.unit
    def test_missing_params(self, client):
        resp = client.post("/api/connectors/files", json={"provider": "google_drive"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_invalid_session(self, client, mock_sessions):
        resp = client.post("/api/connectors/files", json={
            "provider": "google_drive",
            "session_token": "bad_token",
        })
        assert resp.status_code == 401

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({
            "session_token": "valid_tok",
            "user": "test_user",
            "provider": "google_drive",
        })
        loader = self._fake_loader(self._fake_doc({
            "file_name": "test.pdf",
            "mime_type": "application/pdf",
            "size": 1024,
            "modified_time": "2025-01-01T12:00:00.000Z",
            "is_folder": False,
        }))

        with patch("application.api.connector.routes.ConnectorCreator") as creator_cls:
            creator_cls.create_connector.return_value = loader
            resp = client.post("/api/connectors/files", json={
                "provider": "google_drive",
                "session_token": "valid_tok",
            })
            assert resp.status_code == 200
            payload = json.loads(resp.data)
            assert payload["success"] is True
            assert len(payload["files"]) == 1

    @pytest.mark.unit
    def test_no_modified_time(self, client, mock_sessions):
        """Documents lacking modified_time metadata are still listed."""
        mock_sessions["sessions"].insert_one({
            "session_token": "tok2",
            "user": "test_user",
            "provider": "google_drive",
        })
        loader = self._fake_loader(self._fake_doc(
            {"file_name": "test.pdf", "mime_type": "application/pdf"}
        ))

        with patch("application.api.connector.routes.ConnectorCreator") as creator_cls:
            creator_cls.create_connector.return_value = loader
            resp = client.post("/api/connectors/files", json={
                "provider": "google_drive", "session_token": "tok2",
            })
            assert resp.status_code == 200
|
||||
|
||||
|
||||
class TestConnectorValidateSession:
    """Tests for POST /api/connectors/validate-session."""

    @pytest.mark.unit
    def test_missing_params(self, client):
        resp = client.post("/api/connectors/validate-session", json={"provider": "google_drive"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_invalid_session(self, client, mock_sessions):
        resp = client.post("/api/connectors/validate-session", json={
            "provider": "google_drive", "session_token": "bad",
        })
        assert resp.status_code == 401

    @pytest.mark.unit
    def test_valid_non_expired(self, client, mock_sessions):
        mock_sessions["sessions"].insert_one({
            "session_token": "valid",
            "user": "test_user",
            "provider": "google_drive",
            "token_info": {"access_token": "at", "refresh_token": "rt", "expiry": None},
            "user_email": "user@example.com",
        })
        with patch("application.api.connector.routes.ConnectorCreator") as creator_cls:
            auth = MagicMock()
            auth.is_token_expired.return_value = False
            creator_cls.create_auth.return_value = auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "valid",
            })
            assert resp.status_code == 200
            payload = json.loads(resp.data)
            assert payload["success"] is True
            assert payload["expired"] is False

    @pytest.mark.unit
    def test_expired_with_refresh(self, client, mock_sessions):
        """An expired token with a refresh token is refreshed transparently."""
        mock_sessions["sessions"].insert_one({
            "session_token": "expired_tok",
            "user": "test_user",
            "provider": "google_drive",
            "token_info": {"access_token": "old_at", "refresh_token": "rt", "expiry": 100},
        })
        with patch("application.api.connector.routes.ConnectorCreator") as creator_cls:
            auth = MagicMock()
            auth.is_token_expired.return_value = True
            auth.refresh_access_token.return_value = {"access_token": "new_at", "refresh_token": "rt"}
            auth.sanitize_token_info.return_value = {"access_token": "new_at", "refresh_token": "rt"}
            creator_cls.create_auth.return_value = auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "expired_tok",
            })
            assert resp.status_code == 200

    @pytest.mark.unit
    def test_expired_no_refresh(self, client, mock_sessions):
        """An expired token without a refresh token yields 401."""
        mock_sessions["sessions"].insert_one({
            "session_token": "exp_no_ref",
            "user": "test_user",
            "token_info": {"access_token": "at", "expiry": 100},
        })
        with patch("application.api.connector.routes.ConnectorCreator") as creator_cls:
            auth = MagicMock()
            auth.is_token_expired.return_value = True
            creator_cls.create_auth.return_value = auth
            resp = client.post("/api/connectors/validate-session", json={
                "provider": "google_drive", "session_token": "exp_no_ref",
            })
            assert resp.status_code == 401
|
||||
|
||||
|
||||
class TestConnectorDisconnect:
    """Tests for POST /api/connectors/disconnect."""

    @pytest.mark.unit
    def test_missing_provider(self, client):
        """A body without a provider is a 400."""
        response = client.post("/api/connectors/disconnect", json={})
        assert response.status_code == 400

    @pytest.mark.unit
    def test_success_with_session(self, client, mock_sessions):
        """Disconnecting an existing session succeeds."""
        mock_sessions["sessions"].insert_one(
            {"session_token": "del_me", "provider": "google_drive"}
        )
        response = client.post(
            "/api/connectors/disconnect",
            json={"provider": "google_drive", "session_token": "del_me"},
        )
        assert response.status_code == 200
        payload = json.loads(response.data)
        assert payload["success"] is True

    @pytest.mark.unit
    def test_success_without_session(self, client, mock_sessions):
        """Disconnect is idempotent: no stored session still returns 200."""
        response = client.post("/api/connectors/disconnect", json={"provider": "google_drive"})
        assert response.status_code == 200
||||
class TestConnectorSync:
    """Tests for POST /api/connectors/sync."""

    @pytest.mark.unit
    def test_missing_params(self, client, mock_sessions):
        """Missing session_token is rejected with 400."""
        resp = client.post("/api/connectors/sync", json={"source_id": "abc"})
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_source_not_found(self, client, mock_sessions):
        """A well-formed but unknown source id yields 404."""
        from bson.objectid import ObjectId
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(ObjectId()), "session_token": "tok",
        })
        assert resp.status_code == 404

    @pytest.mark.unit
    def test_unauthorized_source(self, client, mock_sessions):
        """A source owned by a different user yields 403."""
        sid = mock_sessions["sources"].insert_one({"user": "other_user", "name": "src"}).inserted_id
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(sid), "session_token": "tok",
        })
        assert resp.status_code == 403

    @pytest.mark.unit
    def test_missing_provider_in_remote_data(self, client, mock_sessions):
        """remote_data without a provider key is a 400."""
        sid = mock_sessions["sources"].insert_one({
            "user": "test_user", "name": "src", "remote_data": json.dumps({}),
        }).inserted_id
        resp = client.post("/api/connectors/sync", json={
            "source_id": str(sid), "session_token": "tok",
        })
        assert resp.status_code == 400

    @pytest.mark.unit
    def test_success(self, client, mock_sessions):
        """A valid sync request enqueues the ingest task and returns its id."""
        sid = mock_sessions["sources"].insert_one({
            "user": "test_user",
            "name": "src",
            "remote_data": json.dumps({"provider": "google_drive", "file_ids": ["f1"]}),
        }).inserted_id
        mock_task = MagicMock()
        mock_task.id = "task_123"
        with patch("application.api.connector.routes.ingest_connector_task") as mock_ingest:
            # Celery task is stubbed; only the returned task id is checked.
            mock_ingest.delay.return_value = mock_task
            resp = client.post("/api/connectors/sync", json={
                "source_id": str(sid), "session_token": "tok",
            })
            assert resp.status_code == 200
            data = json.loads(resp.data)
            assert data["task_id"] == "task_123"
||||
class TestConnectorCallbackStatus:
    """GET /api/connectors/callback-status renders the OAuth result page."""

    @pytest.mark.unit
    def test_success_status(self, client):
        """Success status renders a success page."""
        response = client.get("/api/connectors/callback-status?status=success&message=OK&provider=google_drive&session_token=tok&user_email=u@e.com")
        assert response.status_code == 200
        assert b"success" in response.data

    @pytest.mark.unit
    def test_error_status(self, client):
        """Error status renders an error page."""
        response = client.get("/api/connectors/callback-status?status=error&message=Failed")
        assert response.status_code == 200
        assert b"error" in response.data

    @pytest.mark.unit
    def test_cancelled_status(self, client):
        """Cancelled status renders a cancelled page."""
        response = client.get("/api/connectors/callback-status?status=cancelled&message=Cancelled&provider=google_drive")
        assert response.status_code == 200
        assert b"cancelled" in response.data

    @pytest.mark.unit
    def test_unknown_status_defaults_to_error(self, client):
        """Unrecognized status values fall back to the error page."""
        response = client.get("/api/connectors/callback-status?status=badvalue")
        assert response.status_code == 200
        assert b"error" in response.data

    @pytest.mark.unit
    def test_html_escaping(self, client):
        """Attacker-supplied markup in the message must not survive unescaped."""
        response = client.get('/api/connectors/callback-status?status=error&message=<script>alert(1)</script>')
        assert response.status_code == 200
        # The raw <script> tag should be escaped (not executable)
        assert b"<script>alert(1)</script>" not in response.data
||||
class TestBuildCallbackRedirect:
    """build_callback_redirect encodes params into a callback-status URL."""

    @pytest.mark.unit
    def test_builds_url(self):
        """The redirect targets the callback-status endpoint with the params encoded."""
        from application.api.connector.routes import build_callback_redirect

        redirect_url = build_callback_redirect({"status": "success", "message": "OK"})
        assert redirect_url.startswith("/api/connectors/callback-status?")
        assert "status=success" in redirect_url
339
tests/core/test_model_settings.py
Normal file
339
tests/core/test_model_settings.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""Tests for application/core/model_settings.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
ModelRegistry,
|
||||
)
|
||||
|
||||
|
||||
class TestModelProvider:
    """Each ModelProvider member compares equal to its expected string value."""

    @pytest.mark.unit
    def test_all_providers_exist(self):
        expected = {
            ModelProvider.OPENAI: "openai",
            ModelProvider.ANTHROPIC: "anthropic",
            ModelProvider.GOOGLE: "google",
            ModelProvider.GROQ: "groq",
            ModelProvider.DOCSGPT: "docsgpt",
            ModelProvider.HUGGINGFACE: "huggingface",
            ModelProvider.NOVITA: "novita",
            ModelProvider.OPENROUTER: "openrouter",
            ModelProvider.SAGEMAKER: "sagemaker",
            ModelProvider.PREMAI: "premai",
            ModelProvider.LLAMA_CPP: "llama.cpp",
            ModelProvider.AZURE_OPENAI: "azure_openai",
        }
        for member, wire_value in expected.items():
            assert member == wire_value
||||
class TestModelCapabilities:
    """Defaults and overrides of the ModelCapabilities value object."""

    @pytest.mark.unit
    def test_defaults(self):
        """A bare instance carries the documented default capability set."""
        default_caps = ModelCapabilities()
        assert default_caps.supports_tools is False
        assert default_caps.supports_structured_output is False
        assert default_caps.supports_streaming is True
        assert default_caps.supported_attachment_types == []
        assert default_caps.context_window == 128000
        assert default_caps.input_cost_per_token is None
        assert default_caps.output_cost_per_token is None

    @pytest.mark.unit
    def test_custom_values(self):
        """Constructor keyword overrides stick."""
        custom_caps = ModelCapabilities(
            supports_tools=True,
            supports_structured_output=True,
            context_window=32000,
            input_cost_per_token=0.001,
        )
        assert custom_caps.supports_tools is True
        assert custom_caps.context_window == 32000
||||
class TestAvailableModel:
    """Tests for AvailableModel.to_dict serialization."""

    @pytest.mark.unit
    def test_to_dict_basic(self):
        """Core fields serialize; base_url is omitted when unset."""
        model = AvailableModel(
            id="gpt-4",
            provider=ModelProvider.OPENAI,
            display_name="GPT-4",
            description="OpenAI GPT-4",
        )
        d = model.to_dict()
        assert d["id"] == "gpt-4"
        assert d["provider"] == "openai"
        assert d["display_name"] == "GPT-4"
        assert d["enabled"] is True
        assert "base_url" not in d

    @pytest.mark.unit
    def test_to_dict_with_base_url(self):
        """base_url is included when configured."""
        model = AvailableModel(
            id="local-model",
            provider=ModelProvider.OPENAI,
            display_name="Local",
            base_url="http://localhost:11434",
        )
        d = model.to_dict()
        assert d["base_url"] == "http://localhost:11434"

    @pytest.mark.unit
    def test_to_dict_includes_capabilities(self):
        """Capability flags are flattened into the serialized dict."""
        caps = ModelCapabilities(supports_tools=True, context_window=64000)
        model = AvailableModel(
            id="m1",
            provider=ModelProvider.ANTHROPIC,
            display_name="M1",
            capabilities=caps,
        )
        d = model.to_dict()
        assert d["supports_tools"] is True
        assert d["context_window"] == 64000
||||
class TestModelRegistry:
    """Tests for the ModelRegistry singleton and its provider-loading helpers."""

    @pytest.fixture(autouse=True)
    def _reset_singleton(self):
        """Reset singleton state between tests so each test builds a fresh registry."""
        ModelRegistry._instance = None
        ModelRegistry._initialized = False
        yield
        ModelRegistry._instance = None
        ModelRegistry._initialized = False

    @staticmethod
    def _registry(models=None):
        """Construct a registry with ``_load_models`` stubbed out during __init__.

        Pass ``models={}`` to start from an empty model table. The stub is only
        active during construction; the methods under test never call
        ``_load_models`` themselves.
        """
        with patch.object(ModelRegistry, "_load_models"):
            reg = ModelRegistry()
        if models is not None:
            reg.models = models
        return reg

    @staticmethod
    def _llm_settings(**overrides):
        """Mock settings object with every provider key explicitly unset.

        Keyword overrides set individual attributes, removing the 12-line
        boilerplate previously copied into each _load_models test.
        """
        s = MagicMock()
        s.OPENAI_BASE_URL = None
        s.OPENAI_API_KEY = None
        s.OPENAI_API_BASE = None
        s.ANTHROPIC_API_KEY = None
        s.GOOGLE_API_KEY = None
        s.GROQ_API_KEY = None
        s.OPEN_ROUTER_API_KEY = None
        s.NOVITA_API_KEY = None
        s.HUGGINGFACE_API_KEY = None
        s.LLM_PROVIDER = ""
        s.LLM_NAME = ""
        s.API_KEY = None
        for name, value in overrides.items():
            setattr(s, name, value)
        return s

    @pytest.mark.unit
    def test_singleton(self):
        """Two constructions yield the same instance."""
        with patch.object(ModelRegistry, "_load_models"):
            r1 = ModelRegistry()
            r2 = ModelRegistry()
            assert r1 is r2

    @pytest.mark.unit
    def test_get_instance(self):
        """get_instance returns a ModelRegistry."""
        with patch.object(ModelRegistry, "_load_models"):
            r = ModelRegistry.get_instance()
            assert isinstance(r, ModelRegistry)

    @pytest.mark.unit
    def test_get_model(self):
        """get_model returns the registered object or None."""
        reg = self._registry()
        model = AvailableModel(id="test", provider=ModelProvider.OPENAI, display_name="Test")
        reg.models["test"] = model
        assert reg.get_model("test") is model
        assert reg.get_model("nonexistent") is None

    @pytest.mark.unit
    def test_get_all_models(self):
        """get_all_models returns every registered model."""
        reg = self._registry()
        reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
        reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.ANTHROPIC, display_name="M2")
        assert len(reg.get_all_models()) == 2

    @pytest.mark.unit
    def test_get_enabled_models(self):
        """Disabled models are filtered out."""
        reg = self._registry()
        reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1", enabled=True)
        reg.models["m2"] = AvailableModel(id="m2", provider=ModelProvider.OPENAI, display_name="M2", enabled=False)
        enabled = reg.get_enabled_models()
        assert len(enabled) == 1
        assert enabled[0].id == "m1"

    @pytest.mark.unit
    def test_model_exists(self):
        reg = self._registry()
        reg.models["m1"] = AvailableModel(id="m1", provider=ModelProvider.OPENAI, display_name="M1")
        assert reg.model_exists("m1") is True
        assert reg.model_exists("m2") is False

    @pytest.mark.unit
    def test_parse_model_names(self):
        """Comma-separated names are split, trimmed, and empty/None yield []."""
        reg = self._registry()
        assert reg._parse_model_names("model1,model2") == ["model1", "model2"]
        assert reg._parse_model_names("model1 , model2 ") == ["model1", "model2"]
        assert reg._parse_model_names("single") == ["single"]
        assert reg._parse_model_names("") == []
        assert reg._parse_model_names(None) == []

    @pytest.mark.unit
    def test_add_docsgpt_models(self):
        reg = self._registry(models={})
        reg._add_docsgpt_models(MagicMock())
        assert "docsgpt-local" in reg.models

    @pytest.mark.unit
    def test_add_huggingface_models(self):
        reg = self._registry(models={})
        reg._add_huggingface_models(MagicMock())
        assert "huggingface-local" in reg.models

    @pytest.mark.unit
    def test_load_models_with_openai_key(self):
        """An OpenAI API key makes _load_models register OpenAI models."""
        mock_settings = self._llm_settings(OPENAI_API_KEY="sk-test", LLM_PROVIDER="openai")
        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert len(reg.models) > 0

    @pytest.mark.unit
    def test_load_models_custom_openai_base_url(self):
        """With a custom base URL, LLM_NAME entries become the model ids."""
        mock_settings = self._llm_settings(
            OPENAI_BASE_URL="http://localhost:11434/v1",
            OPENAI_API_KEY="sk-test",
            LLM_PROVIDER="openai",
            LLM_NAME="llama3,gemma",
        )
        with patch("application.core.settings.settings", mock_settings):
            reg = ModelRegistry()
            assert "llama3" in reg.models
            assert "gemma" in reg.models

    @pytest.mark.unit
    def test_default_model_selection_from_llm_name(self):
        reg = self._registry()
        reg.models = {"gpt-4": AvailableModel(id="gpt-4", provider=ModelProvider.OPENAI, display_name="GPT-4")}
        reg.default_model_id = "gpt-4"
        assert reg.default_model_id == "gpt-4"

    @pytest.mark.unit
    def test_add_anthropic_models_with_key(self):
        reg = self._registry(models={})
        mock_settings = MagicMock()
        mock_settings.ANTHROPIC_API_KEY = "sk-ant-test"
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        reg._add_anthropic_models(mock_settings)
        assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_google_models_with_key(self):
        reg = self._registry(models={})
        mock_settings = MagicMock()
        mock_settings.GOOGLE_API_KEY = "google-test"
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        reg._add_google_models(mock_settings)
        assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_groq_models_with_key(self):
        reg = self._registry(models={})
        mock_settings = MagicMock()
        mock_settings.GROQ_API_KEY = "groq-test"
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        reg._add_groq_models(mock_settings)
        assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_openrouter_models_with_key(self):
        reg = self._registry(models={})
        mock_settings = MagicMock()
        mock_settings.OPEN_ROUTER_API_KEY = "or-test"
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        reg._add_openrouter_models(mock_settings)
        assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_novita_models_with_key(self):
        reg = self._registry(models={})
        mock_settings = MagicMock()
        mock_settings.NOVITA_API_KEY = "novita-test"
        mock_settings.LLM_PROVIDER = ""
        mock_settings.LLM_NAME = ""
        reg._add_novita_models(mock_settings)
        assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_azure_openai_models_specific(self):
        reg = self._registry(models={})
        mock_settings = MagicMock()
        mock_settings.LLM_PROVIDER = "azure_openai"
        mock_settings.LLM_NAME = "nonexistent-model"
        reg._add_azure_openai_models(mock_settings)
        # Falls through to adding all azure models
        assert len(reg.models) > 0

    @pytest.mark.unit
    def test_add_anthropic_models_no_key_with_provider(self):
        reg = self._registry(models={})
        mock_settings = MagicMock()
        mock_settings.ANTHROPIC_API_KEY = None
        mock_settings.LLM_PROVIDER = "anthropic"
        mock_settings.LLM_NAME = "nonexistent"
        reg._add_anthropic_models(mock_settings)
        assert len(reg.models) > 0

    @pytest.mark.unit
    def test_default_model_fallback_to_first(self):
        """With no provider keys set, the registry still picks some default."""
        with patch("application.core.settings.settings", self._llm_settings()):
            reg = ModelRegistry()
            # Should have at least docsgpt-local
            assert reg.default_model_id is not None
||||
115
tests/test_app_routes.py
Normal file
115
tests/test_app_routes.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Tests for application/app.py route handlers."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def app():
    """Import the Flask app with auth mocked to avoid JWT setup issues.

    The patch wraps the yield, so the stubbed handle_auth stays active for
    the whole test.  NOTE(review): application.app is cached after the first
    import, so later tests share the same module object — confirm this is
    intended.
    """
    with patch("application.app.handle_auth", return_value={"sub": "test_user"}):
        from application.app import app as flask_app
        flask_app.config["TESTING"] = True
        yield flask_app
|
||||
|
||||
@pytest.fixture
def client(app):
    """Flask test client bound to the mocked-auth app fixture."""
    return app.test_client()
||||
|
||||
|
||||
class TestHomeRoute:
    """GET / behavior."""

    @pytest.mark.unit
    def test_root_returns_200(self, client):
        """Root serves Swagger UI via Flask-RESTX."""
        assert client.get("/").status_code == 200
|
||||
|
||||
class TestHealthRoute:
    """GET /api/health liveness check."""

    @pytest.mark.unit
    def test_returns_ok(self, client):
        """Health endpoint reports status ok."""
        response = client.get("/api/health")
        assert response.status_code == 200
        payload = json.loads(response.data)
        assert payload["status"] == "ok"
|
||||
|
||||
class TestConfigRoute:
    """GET /api/config exposes auth configuration."""

    @pytest.mark.unit
    def test_returns_auth_config(self, client):
        """The config payload advertises the auth type and requirement flag."""
        response = client.get("/api/config")
        assert response.status_code == 200
        payload = json.loads(response.data)
        assert "auth_type" in payload
        assert "requires_auth" in payload
|
||||
|
||||
class TestGenerateTokenRoute:
    """Tests for GET /api/generate_token."""

    @pytest.mark.unit
    def test_session_jwt_generates_token(self, client, app):
        """With session_jwt auth configured, a token is issued."""
        with patch("application.app.settings") as mock_settings:
            mock_settings.AUTH_TYPE = "session_jwt"
            mock_settings.JWT_SECRET_KEY = "test_secret"
            response = client.get("/api/generate_token")
            assert response.status_code == 200
            data = json.loads(response.data)
            assert "token" in data

    @pytest.mark.unit
    def test_non_session_jwt_returns_error(self, client, app):
        """Any other auth type makes token generation a 400."""
        with patch("application.app.settings") as mock_settings:
            mock_settings.AUTH_TYPE = "none"
            response = client.get("/api/generate_token")
            assert response.status_code == 400
|
||||
|
||||
class TestSttRequestSizeLimits:
    """Tests for the STT upload size gate applied before request handling."""

    @pytest.mark.unit
    def test_non_stt_request_passes(self, client):
        """Non-STT routes are not affected by the size limit."""
        response = client.get("/api/health")
        assert response.status_code == 200

    @pytest.mark.unit
    def test_oversized_stt_request_rejected(self, client):
        """When the size check trips, the upload is rejected with 413."""
        with patch("application.app.should_reject_stt_request", return_value=True), \
            patch("application.app.build_stt_file_size_limit_message", return_value="Too large"):
            response = client.post("/api/stt/upload", data=b"x" * 100)
            assert response.status_code == 413
|
||||
|
||||
class TestAuthenticateRequest:
    """Tests for the before-request authentication hook."""

    @pytest.mark.unit
    def test_options_returns_200(self, client):
        """CORS preflight (OPTIONS) bypasses authentication."""
        response = client.options("/api/health")
        assert response.status_code == 200

    @pytest.mark.unit
    def test_auth_error_returns_401(self, client, app):
        """An auth error dict from handle_auth becomes a 401 response."""
        with patch("application.app.handle_auth", return_value={"error": "Invalid token"}):
            response = client.get("/api/health")
            assert response.status_code == 401

    @pytest.mark.unit
    def test_no_token_sets_none(self, client, app):
        """handle_auth returning None (no token) still lets the request through."""
        with patch("application.app.handle_auth", return_value=None):
            response = client.get("/api/health")
            assert response.status_code == 200
|
||||
|
||||
class TestAfterRequest:
    """Every response carries permissive CORS headers."""

    @pytest.mark.unit
    def test_cors_headers(self, client):
        """Origin, allowed headers and allowed methods are all set."""
        headers = client.get("/api/health").headers
        assert headers.get("Access-Control-Allow-Origin") == "*"
        assert "Content-Type" in headers.get("Access-Control-Allow-Headers", "")
        assert "GET" in headers.get("Access-Control-Allow-Methods", "")
||||
533
tests/test_utils.py
Normal file
533
tests/test_utils.py
Normal file
@@ -0,0 +1,533 @@
|
||||
"""Tests for application/utils.py"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from application.utils import (
|
||||
calculate_compression_threshold,
|
||||
calculate_doc_token_budget,
|
||||
check_required_fields,
|
||||
clean_text_for_tts,
|
||||
convert_pdf_to_images,
|
||||
get_encoding,
|
||||
get_field_validation_errors,
|
||||
get_gpt_model,
|
||||
get_hash,
|
||||
get_missing_fields,
|
||||
generate_image_url,
|
||||
limit_chat_history,
|
||||
num_tokens_from_object_or_list,
|
||||
num_tokens_from_string,
|
||||
safe_filename,
|
||||
validate_function_name,
|
||||
validate_required_fields,
|
||||
)
|
||||
|
||||
|
||||
class TestGetEncoding:
    """get_encoding returns a cached tokenizer encoding."""

    @pytest.mark.unit
    def test_returns_encoding(self):
        assert get_encoding() is not None

    @pytest.mark.unit
    def test_returns_same_instance(self):
        """Repeated calls hand back the identical cached object."""
        first = get_encoding()
        second = get_encoding()
        assert first is second
|
||||
|
||||
class TestGetGptModel:
    """get_gpt_model resolves the model name from settings."""

    @pytest.mark.unit
    def test_returns_llm_name_when_set(self):
        """An explicit LLM_NAME wins over the provider default."""
        with patch("application.utils.settings") as mock_settings:
            mock_settings.LLM_NAME = "my-model"
            mock_settings.LLM_PROVIDER = "openai"
            assert get_gpt_model() == "my-model"

    @pytest.mark.unit
    def test_falls_back_to_provider_map(self):
        """With no LLM_NAME, the provider's default model is used."""
        with patch("application.utils.settings") as mock_settings:
            mock_settings.LLM_NAME = ""
            mock_settings.LLM_PROVIDER = "openai"
            assert get_gpt_model() == "gpt-4o-mini"

    @pytest.mark.unit
    def test_unknown_provider_returns_empty(self):
        """An unmapped provider yields the empty string."""
        with patch("application.utils.settings") as mock_settings:
            mock_settings.LLM_NAME = ""
            mock_settings.LLM_PROVIDER = "unknown"
            assert get_gpt_model() == ""
|
||||
|
||||
class TestSafeFilename:
    """safe_filename sanitizes names and falls back to a UUID."""

    @pytest.mark.unit
    def test_normal_filename(self):
        assert safe_filename("test.pdf") == "test.pdf"

    @pytest.mark.unit
    def test_empty_filename_returns_uuid(self):
        generated = safe_filename("")
        assert len(generated) > 10  # UUID-sized fallback

    @pytest.mark.unit
    def test_none_filename_returns_uuid(self):
        generated = safe_filename(None)
        assert len(generated) > 10

    @pytest.mark.unit
    def test_non_latin_filename(self):
        """Non-latin names keep their extension after sanitization."""
        sanitized = safe_filename("документ.pdf")
        assert sanitized.endswith(".pdf")
|
||||
|
||||
class TestNumTokens:
    """num_tokens_from_string token counting."""

    @pytest.mark.unit
    def test_string_token_count(self):
        assert num_tokens_from_string("hello world") > 0

    @pytest.mark.unit
    def test_non_string_returns_zero(self):
        assert num_tokens_from_string(123) == 0

    @pytest.mark.unit
    def test_empty_string(self):
        assert num_tokens_from_string("") == 0
|
||||
|
||||
class TestNumTokensFromObjectOrList:
    """num_tokens_from_object_or_list walks nested containers."""

    @pytest.mark.unit
    def test_list(self):
        assert num_tokens_from_object_or_list(["hello", "world"]) > 0

    @pytest.mark.unit
    def test_dict(self):
        assert num_tokens_from_object_or_list({"key": "value"}) > 0

    @pytest.mark.unit
    def test_string(self):
        assert num_tokens_from_object_or_list("hello") > 0

    @pytest.mark.unit
    def test_number_returns_zero(self):
        """Scalars that are not strings contribute no tokens."""
        assert num_tokens_from_object_or_list(42) == 0

    @pytest.mark.unit
    def test_nested(self):
        assert num_tokens_from_object_or_list({"a": ["b", "c"]}) > 0
|
||||
|
||||
class TestCountTokensDocs:
    """count_tokens_docs sums tokens across document page contents."""

    @pytest.mark.unit
    def test_counts_doc_tokens(self):
        from application.utils import count_tokens_docs

        first, second = MagicMock(), MagicMock()
        first.page_content = "hello world"
        second.page_content = " foo bar"
        assert count_tokens_docs([first, second]) > 0
|
||||
|
||||
class TestCalculateDocTokenBudget:
    """Tests for calculate_doc_token_budget."""

    @pytest.mark.unit
    def test_returns_budget(self):
        """Budget is the model limit minus reserved system/history tokens."""
        with patch("application.utils.get_token_limit", return_value=128000), \
            patch("application.utils.settings") as s:
            s.RESERVED_TOKENS = {"system": 500, "history": 500}
            result = calculate_doc_token_budget("gpt-4o")
            assert result == 127000

    @pytest.mark.unit
    def test_minimum_budget(self):
        """Small limits are clamped to a floor rather than going to zero."""
        with patch("application.utils.get_token_limit", return_value=1000), \
            patch("application.utils.settings") as s:
            s.RESERVED_TOKENS = {"system": 500, "history": 500}
            result = calculate_doc_token_budget("small-model")
            assert result == 1000
|
||||
|
||||
class TestFieldValidation:
    """Tests for the request-field validation helpers in application.utils."""

    @staticmethod
    def _app_context():
        """Return a fresh Flask app context.

        The check/validate helpers build a jsonify() error response, which
        needs an active application context.  Extracted here to remove the
        Flask-boilerplate previously copied into seven tests.
        """
        from flask import Flask

        return Flask(__name__).app_context()

    @pytest.mark.unit
    def test_get_missing_fields(self):
        assert get_missing_fields({"a": 1}, ["a", "b"]) == ["b"]
        assert get_missing_fields({"a": 1, "b": 2}, ["a", "b"]) == []

    @pytest.mark.unit
    def test_check_required_fields_pass(self):
        """All fields present -> no error response."""
        with self._app_context():
            assert check_required_fields({"a": 1, "b": 2}, ["a", "b"]) is None

    @pytest.mark.unit
    def test_check_required_fields_fail(self):
        """A missing field -> 400 error response."""
        with self._app_context():
            result = check_required_fields({"a": 1}, ["a", "b"])
            assert result is not None
            assert result.status_code == 400

    @pytest.mark.unit
    def test_get_field_validation_errors_none_when_valid(self):
        assert get_field_validation_errors({"a": 1}, ["a"]) is None

    @pytest.mark.unit
    def test_get_field_validation_errors_missing(self):
        result = get_field_validation_errors({}, ["a"])
        assert result["missing_fields"] == ["a"]

    @pytest.mark.unit
    def test_get_field_validation_errors_empty(self):
        """Present-but-empty values are reported separately from missing ones."""
        result = get_field_validation_errors({"a": ""}, ["a"])
        assert result["empty_fields"] == ["a"]

    @pytest.mark.unit
    def test_validate_required_fields_pass(self):
        with self._app_context():
            assert validate_required_fields({"a": "v"}, ["a"]) is None

    @pytest.mark.unit
    def test_validate_required_fields_missing(self):
        with self._app_context():
            result = validate_required_fields({}, ["a"])
            assert result is not None
            assert result.status_code == 400

    @pytest.mark.unit
    def test_validate_required_fields_empty(self):
        with self._app_context():
            assert validate_required_fields({"a": ""}, ["a"]) is not None

    @pytest.mark.unit
    def test_validate_required_fields_both_missing_and_empty(self):
        with self._app_context():
            assert validate_required_fields({"a": ""}, ["a", "b"]) is not None
|
||||
|
||||
class TestGetHash:
    """get_hash produces a deterministic 32-char hex digest."""

    @pytest.mark.unit
    def test_returns_hex_string(self):
        digest = get_hash("test")
        assert len(digest) == 32
        assert all(ch in "0123456789abcdef" for ch in digest)

    @pytest.mark.unit
    def test_deterministic(self):
        assert get_hash("hello") == get_hash("hello")

    @pytest.mark.unit
    def test_different_inputs(self):
        assert get_hash("a") != get_hash("b")
|
||||
|
||||
class TestLimitChatHistory:
    """Tests for limit_chat_history token-based trimming."""

    @pytest.mark.unit
    def test_empty_history(self):
        assert limit_chat_history([]) == []

    @pytest.mark.unit
    def test_none_history(self):
        """None is normalized to an empty list."""
        assert limit_chat_history(None) == []

    @pytest.mark.unit
    def test_keeps_recent_messages(self):
        """A generous budget keeps the whole history."""
        history = [
            {"prompt": "q1", "response": "a1"},
            {"prompt": "q2", "response": "a2"},
        ]
        result = limit_chat_history(history, max_token_limit=10000)
        assert len(result) == 2

    @pytest.mark.unit
    def test_trims_old_messages(self):
        """A tight budget drops the oversized older entry."""
        history = [
            {"prompt": "x" * 5000, "response": "y" * 5000},
            {"prompt": "q", "response": "a"},
        ]
        result = limit_chat_history(history, max_token_limit=100)
        assert len(result) <= 2

    @pytest.mark.unit
    def test_handles_tool_calls(self):
        """Entries carrying tool_calls survive trimming under a large budget."""
        history = [
            {
                "prompt": "q",
                "response": "a",
                "tool_calls": [
                    {"tool_name": "t", "action_name": "a", "arguments": "{}", "result": "r"}
                ],
            }
        ]
        result = limit_chat_history(history, max_token_limit=10000)
        assert len(result) == 1
|
||||
|
||||
class TestValidateFunctionName:
    """Tests for validate_function_name (alphanumerics, underscores, hyphens)."""

    @pytest.mark.unit
    def test_valid_names(self):
        """Plain words, underscores, hyphens, and digits are all accepted."""
        for candidate in ("hello", "hello_world", "hello-world", "test123"):
            assert validate_function_name(candidate) is True

    @pytest.mark.unit
    def test_invalid_names(self):
        """Spaces, punctuation, and the empty string are rejected."""
        for candidate in ("hello world", "hello!", ""):
            assert validate_function_name(candidate) is False
|
||||
|
||||
|
||||
class TestGenerateImageUrl:
    """Tests for generate_image_url under each URL_STRATEGY setting."""

    @pytest.mark.unit
    def test_http_url_passthrough(self):
        """Absolute http(s) URLs are returned unchanged."""
        for url in ("https://example.com/img.png", "http://example.com/img.png"):
            assert generate_image_url(url) == url

    @pytest.mark.unit
    def test_s3_strategy(self):
        """The s3 strategy builds a bucket/region-qualified S3 URL."""
        with patch("application.utils.settings") as mock_settings:
            mock_settings.URL_STRATEGY = "s3"
            mock_settings.S3_BUCKET_NAME = "my-bucket"
            mock_settings.SAGEMAKER_REGION = "us-west-2"
            url = generate_image_url("path/to/img.png")
            assert "my-bucket.s3.us-west-2" in url

    @pytest.mark.unit
    def test_backend_strategy(self):
        """The backend strategy serves the image through the API host."""
        with patch("application.utils.settings") as mock_settings:
            mock_settings.URL_STRATEGY = "backend"
            mock_settings.API_URL = "http://localhost:7091"
            url = generate_image_url("path/to/img.png")
            assert url == "http://localhost:7091/api/images/path/to/img.png"
|
||||
|
||||
|
||||
class TestCalculateCompressionThreshold:
    """Tests for calculate_compression_threshold (fraction of the model's token limit)."""

    @pytest.mark.unit
    def test_default_threshold(self):
        """With a 100k limit, the default percentage yields 80k."""
        with patch("application.utils.get_token_limit", return_value=100000):
            threshold = calculate_compression_threshold("gpt-4o")
            assert threshold == 80000

    @pytest.mark.unit
    def test_custom_percentage(self):
        """An explicit 0.5 percentage halves the limit."""
        with patch("application.utils.get_token_limit", return_value=100000):
            threshold = calculate_compression_threshold("gpt-4o", 0.5)
            assert threshold == 50000
|
||||
|
||||
|
||||
class TestConvertPdfToImages:
    """Tests for convert_pdf_to_images.

    pdf2image is imported lazily inside the function under test, so each test
    intercepts ``builtins.__import__`` to inject a mock module. The patching
    boilerplate is shared via ``_fake_pdf2image_import``; the previous
    per-test copies also used a fragile ``__builtins__.__import__`` attribute
    dance (``__builtins__`` is a plain dict in non-main modules, making that
    branch dead code) — plain ``__import__`` is the correct, portable form.
    """

    @staticmethod
    def _fake_pdf2image_import(mock_module):
        """Return an ``__import__`` side_effect serving *mock_module* for
        'pdf2image' and delegating every other import to the real builtin."""
        original_import = __import__  # captured before patch() is active

        def patched_import(name, *args, **kwargs):
            if name == "pdf2image":
                return mock_module
            return original_import(name, *args, **kwargs)

        return patched_import

    @pytest.mark.unit
    def test_missing_pdf2image_raises(self):
        """If pdf2image cannot be imported, the function surfaces ImportError."""
        with patch.dict("sys.modules", {"pdf2image": None}):
            with pytest.raises(ImportError, match="pdf2image"):
                convert_pdf_to_images("test.pdf")

    @pytest.mark.unit
    def test_converts_from_path(self):
        """A filesystem path is converted page-by-page to PNG payloads."""
        mock_image = MagicMock()
        mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"PNG_DATA"))

        mock_module = MagicMock()
        mock_module.convert_from_path.return_value = [mock_image]
        mock_module.convert_from_bytes.return_value = [mock_image]

        with patch("builtins.__import__", side_effect=self._fake_pdf2image_import(mock_module)):
            result = convert_pdf_to_images("/some/file.pdf")
            assert len(result) == 1
            assert result[0]["mime_type"] == "image/png"
            assert result[0]["page"] == 1

    @pytest.mark.unit
    def test_with_storage(self):
        """When a storage backend is given, bytes are read from it and
        converted with convert_from_bytes rather than convert_from_path."""
        mock_image = MagicMock()
        mock_image.save = MagicMock(side_effect=lambda buf, format: buf.write(b"IMG"))

        mock_storage = MagicMock()
        mock_file = MagicMock()
        mock_file.read.return_value = b"pdf_bytes"
        mock_file.__enter__ = MagicMock(return_value=mock_file)
        mock_file.__exit__ = MagicMock(return_value=False)
        mock_storage.get_file.return_value = mock_file

        mock_module = MagicMock()
        mock_module.convert_from_bytes.return_value = [mock_image]

        with patch("builtins.__import__", side_effect=self._fake_pdf2image_import(mock_module)):
            result = convert_pdf_to_images("test.pdf", storage=mock_storage)
            assert len(result) == 1
            mock_module.convert_from_bytes.assert_called_once()

    @pytest.mark.unit
    def test_file_not_found_raises(self):
        """A missing input file propagates FileNotFoundError."""
        mock_module = MagicMock()
        mock_module.convert_from_path.side_effect = FileNotFoundError("not found")

        with patch("builtins.__import__", side_effect=self._fake_pdf2image_import(mock_module)):
            with pytest.raises(FileNotFoundError):
                convert_pdf_to_images("/nonexistent.pdf")

    @pytest.mark.unit
    def test_generic_error_raises(self):
        """Unexpected conversion failures are re-raised, not swallowed."""
        mock_module = MagicMock()
        mock_module.convert_from_path.side_effect = RuntimeError("conversion failed")

        with patch("builtins.__import__", side_effect=self._fake_pdf2image_import(mock_module)):
            with pytest.raises(RuntimeError, match="conversion failed"):
                convert_pdf_to_images("/some.pdf")
|
||||
|
||||
|
||||
class TestCleanTextForTts:
    """Tests for clean_text_for_tts, which strips markdown/HTML artifacts so
    the remaining text is safe to feed to a text-to-speech engine."""

    @pytest.mark.unit
    def test_removes_code_blocks(self):
        """Fenced code blocks are replaced with a spoken placeholder."""
        result = clean_text_for_tts("before ```python\ncode\n``` after")
        assert "code block" in result
        assert "python" not in result

    @pytest.mark.unit
    def test_removes_mermaid_blocks(self):
        """Mermaid diagrams become a 'flowchart' placeholder."""
        result = clean_text_for_tts("```mermaid\ngraph TD\n```")
        assert "flowchart" in result

    @pytest.mark.unit
    def test_removes_markdown_links(self):
        """Link text is kept; the URL is dropped."""
        result = clean_text_for_tts("[click here](https://example.com)")
        assert "click here" in result
        assert "https" not in result

    @pytest.mark.unit
    def test_removes_images(self):
        """Markdown image syntax is stripped entirely.

        Fix: the previous input was an empty string (the image literal was
        lost in transit), which made the assertion vacuous — an actual
        ``![alt](...)`` image must be passed to exercise the stripping."""
        result = clean_text_for_tts("![alt text](image.png)")
        assert "image.png" not in result

    @pytest.mark.unit
    def test_removes_inline_code(self):
        """Backticks are removed but the inline code text is kept."""
        result = clean_text_for_tts("use `foo()` here")
        assert "foo()" in result
        assert "`" not in result

    @pytest.mark.unit
    def test_removes_bold_italic(self):
        """Emphasis markers are stripped; the words remain."""
        result = clean_text_for_tts("**bold** and *italic*")
        assert "bold" in result
        assert "italic" in result
        assert "*" not in result

    @pytest.mark.unit
    def test_removes_headers(self):
        """Heading hashes are removed; the heading text remains."""
        result = clean_text_for_tts("# Header\ntext")
        assert "Header" in result
        assert "#" not in result

    @pytest.mark.unit
    def test_removes_blockquotes(self):
        """Blockquote markers are removed; the quoted text remains."""
        result = clean_text_for_tts("> quoted text")
        assert "quoted text" in result
        assert ">" not in result

    @pytest.mark.unit
    def test_removes_html_tags(self):
        """HTML tags are stripped; their inner text remains."""
        result = clean_text_for_tts("<div>content</div>")
        assert "content" in result
        assert "<" not in result

    @pytest.mark.unit
    def test_removes_arrows(self):
        """ASCII arrow tokens are unreadable aloud and are dropped."""
        result = clean_text_for_tts("a --> b <-- c => d")
        assert "-->" not in result
        assert "<--" not in result
        assert "=>" not in result

    @pytest.mark.unit
    def test_removes_horizontal_rules(self):
        """Horizontal rules are removed."""
        result = clean_text_for_tts("text\n---\nmore")
        assert "---" not in result

    @pytest.mark.unit
    def test_removes_list_markers(self):
        """Bullet and numbered list markers are removed; items remain."""
        result = clean_text_for_tts("- item1\n* item2\n1. item3")
        assert "item1" in result
        assert "item2" in result
        assert "item3" in result

    @pytest.mark.unit
    def test_normalizes_whitespace(self):
        """Runs of whitespace collapse to single spaces."""
        result = clean_text_for_tts("  lots   of   spaces  ")
        assert "  " not in result

    @pytest.mark.unit
    def test_removes_braces(self):
        """Curly and square brackets are removed; content remains."""
        result = clean_text_for_tts("{content} and [more]")
        assert "content" in result
        assert "more" in result
        assert "{" not in result

    @pytest.mark.unit
    def test_removes_double_colons(self):
        """Scope-resolution '::' tokens are removed."""
        result = clean_text_for_tts("module::function")
        assert "::" not in result
|
||||
Reference in New Issue
Block a user