mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-21 21:05:05 +00:00
574 lines
20 KiB
Python
574 lines
20 KiB
Python
"""Tests covering gaps in WorkflowEngine: execute loop, state/condition/end nodes,
template context, source data, structured output parsing, get_execution_summary."""
|
|
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from application.agents.workflows.schemas import (
|
|
ExecutionStatus,
|
|
NodeType,
|
|
WorkflowEdge,
|
|
WorkflowGraph,
|
|
WorkflowNode,
|
|
Workflow,
|
|
)
|
|
from application.agents.workflows.workflow_engine import WorkflowEngine
|
|
|
|
|
|
def _make_graph(nodes, edges):
    """Build a ``WorkflowGraph`` around a throwaway ``Workflow`` named "Test".

    Args:
        nodes: Workflow nodes to place in the graph.
        edges: Workflow edges connecting those nodes.

    Returns:
        A ``WorkflowGraph`` holding the given nodes and edges.
    """
    workflow = Workflow(name="Test", description="test workflow")
    return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
|
|
|
|
|
def _make_node(id, type, title="Node", config=None, position=None):
    """Build a ``WorkflowNode`` that belongs to the fixed workflow ``"wf1"``.

    The ``id`` and ``type`` parameter names shadow builtins; they are kept
    unchanged so existing positional and keyword callers keep working.

    Args:
        id: Node identifier.
        type: Node type (the tests pass ``NodeType`` members).
        title: Display title; defaults to ``"Node"``.
        config: Node configuration; a falsy value becomes ``{}``.
        position: Canvas position; a falsy value becomes ``{"x": 0, "y": 0}``.

    Returns:
        The constructed ``WorkflowNode``.
    """
    return WorkflowNode(
        id=id,
        workflow_id="wf1",
        type=type,
        title=title,
        # Fresh dicts per call so nodes never share mutable defaults.
        position=position or {"x": 0, "y": 0},
        config=config or {},
    )
|
|
|
|
|
|
def _make_edge(id, source, target, source_handle=None, target_handle=None):
    """Build a ``WorkflowEdge`` that belongs to the fixed workflow ``"wf1"``.

    Args:
        id: Edge identifier.
        source: Id of the node the edge leaves.
        target: Id of the node the edge enters.
        source_handle: Optional handle name, passed as ``sourceHandle``.
        target_handle: Optional handle name, passed as ``targetHandle``.

    Returns:
        The constructed ``WorkflowEdge``.
    """
    return WorkflowEdge(
        id=id,
        workflow_id="wf1",
        source=source,
        target=target,
        sourceHandle=source_handle,
        targetHandle=target_handle,
    )
|
|
|
|
|
|
def _make_agent():
|
|
agent = MagicMock()
|
|
agent.chat_history = []
|
|
agent.endpoint = "https://api.example.com"
|
|
agent.llm_name = "openai"
|
|
agent.model_id = "gpt-4"
|
|
agent.api_key = "key"
|
|
agent.decoded_token = {"sub": "user1"}
|
|
agent.retrieved_docs = None
|
|
return agent
|
|
|
|
|
|
class TestExecuteLoop:
    """Exercise ``WorkflowEngine.execute`` on small happy-path and failure graphs."""

    @pytest.mark.unit
    def test_no_start_node_yields_error(self):
        """An empty graph must yield an error event mentioning the start node."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        events = list(engine.execute({}, "query"))
        # ``or ""`` keeps the substring check safe if an event carries
        # an explicit ``"error": None``.
        assert any(
            e.get("type") == "error" and "start node" in (e.get("error") or "")
            for e in events
        )

    @pytest.mark.unit
    def test_start_to_end(self):
        """A start -> end graph emits at least two ``workflow_step`` events."""
        nodes = [
            _make_node("n1", NodeType.START, "Start"),
            _make_node("n2", NodeType.END, "End", config={"config": {}}),
        ]
        edges = [_make_edge("e1", "n1", "n2")]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        events = list(engine.execute({}, "hello"))
        step_events = [e for e in events if e.get("type") == "workflow_step"]
        assert len(step_events) >= 2  # At least start + end

    @pytest.mark.unit
    def test_node_not_found_yields_error(self):
        """An edge pointing at a missing node must surface a "not found" error."""
        nodes = [_make_node("n1", NodeType.START)]
        edges = [_make_edge("e1", "n1", "nonexistent")]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        events = list(engine.execute({}, "q"))
        # Tolerate events whose "error" value is None rather than absent.
        assert any("not found" in (e.get("error") or "") for e in events)

    @pytest.mark.unit
    def test_node_execution_error_yields_error(self):
        """A state node with an unparsable expression yields a failed event."""
        bad_operation = {"expression": "bad!!!", "target_variable": "x"}
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node(
                "n2",
                NodeType.STATE,
                "State",
                config={"config": {"operations": [bad_operation]}},
            ),
        ]
        edges = [_make_edge("e1", "n1", "n2")]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        events = list(engine.execute({}, "q"))
        failed_events = [e for e in events if e.get("status") == "failed"]
        assert len(failed_events) >= 1

    @pytest.mark.unit
    def test_max_steps_limit(self):
        """A self-looping graph must terminate once the step cap is lowered."""
        # Create a cycle: start -> note -> note (self-loop on n2).
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.NOTE, "Note"),
        ]
        edges = [_make_edge("e1", "n1", "n2"), _make_edge("e2", "n2", "n2")]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        engine.MAX_EXECUTION_STEPS = 5
        # Draining the generator is the real check: it should not run forever.
        events = list(engine.execute({}, "q"))
        assert len(events) > 0

    @pytest.mark.unit
    def test_branch_ends_without_end_node(self):
        """Execution still produces events when a branch simply runs out of edges."""
        nodes = [
            _make_node("n1", NodeType.START),
            _make_node("n2", NodeType.NOTE, "Note"),
        ]
        edges = [_make_edge("e1", "n1", "n2")]  # n2 has no outgoing edges
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        events = list(engine.execute({}, "q"))
        assert len(events) > 0
|
|
|
|
|
|
class TestInitializeState:
    """Check what ``WorkflowEngine._initialize_state`` stores in ``engine.state``."""

    @pytest.mark.unit
    def test_sets_query_and_history(self):
        """State holds the query, the caller's inputs, and a chat history value."""
        agent = _make_agent()
        agent.chat_history = [{"prompt": "hi", "response": "hey"}]
        engine = WorkflowEngine(_make_graph([], []), agent)

        engine._initialize_state({"custom": "value"}, "test query")

        assert engine.state["query"] == "test query"
        assert "custom" in engine.state
        assert engine.state["chat_history"] is not None
|
|
|
|
|
|
class TestGetNextNodeId:
    """Check edge selection in ``WorkflowEngine._get_next_node_id``."""

    @pytest.mark.unit
    def test_no_edges_returns_none(self):
        """A node with no outgoing edges has no successor."""
        nodes = [_make_node("n1", NodeType.START)]
        engine = WorkflowEngine(_make_graph(nodes, []), _make_agent())
        assert engine._get_next_node_id("n1") is None

    @pytest.mark.unit
    def test_returns_first_edge_target(self):
        """A single outgoing edge resolves to that edge's target."""
        nodes = [_make_node("n1", NodeType.START), _make_node("n2", NodeType.END)]
        edges = [_make_edge("e1", "n1", "n2")]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        assert engine._get_next_node_id("n1") == "n2"

    @pytest.mark.unit
    def test_condition_uses_matched_handle(self):
        """A condition node follows the edge whose source handle matches the result."""
        nodes = [
            _make_node("n1", NodeType.CONDITION),
            _make_node("n2", NodeType.END, "Yes End"),
            _make_node("n3", NodeType.END, "No End"),
        ]
        edges = [
            _make_edge("e1", "n1", "n2", source_handle="yes"),
            _make_edge("e2", "n1", "n3", source_handle="no"),
        ]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        engine._condition_result = "no"

        assert engine._get_next_node_id("n1") == "n3"
        assert engine._condition_result is None  # Cleared after use

    @pytest.mark.unit
    def test_condition_no_matching_handle_returns_none(self):
        """A condition result that matches no edge handle yields no successor."""
        nodes = [_make_node("n1", NodeType.CONDITION)]
        edges = [_make_edge("e1", "n1", "n2", source_handle="yes")]
        engine = WorkflowEngine(_make_graph(nodes, edges), _make_agent())
        engine._condition_result = "nonexistent"
        assert engine._get_next_node_id("n1") is None
|
|
|
|
|
|
class TestExecuteStateNode:
    """Check how ``WorkflowEngine._execute_state_node`` applies operations."""

    @pytest.mark.unit
    def test_evaluates_operations(self):
        """An operation's expression is evaluated and stored under its target."""
        operations = [{"expression": "x + 1", "target_variable": "result"}]
        node = _make_node(
            "n1", NodeType.STATE, config={"config": {"operations": operations}}
        )
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"x": 5}

        # The method is consumed as an iterable; drain it so it runs to completion.
        list(engine._execute_state_node(node))

        assert engine.state["result"] == 6

    @pytest.mark.unit
    def test_skips_empty_expression(self):
        """An operation with an empty expression leaves the state untouched."""
        operations = [{"expression": "", "target_variable": "result"}]
        node = _make_node(
            "n1", NodeType.STATE, config={"config": {"operations": operations}}
        )
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {}

        list(engine._execute_state_node(node))

        assert "result" not in engine.state
|
|
|
|
|
|
class TestExecuteConditionNode:
    """Check case selection in ``WorkflowEngine._execute_condition_node``."""

    @pytest.mark.unit
    def test_matches_first_true_case(self):
        """The first case whose expression is true decides the result handle."""
        cases = [
            {"expression": "x > 10", "source_handle": "high"},
            {"expression": "x > 5", "source_handle": "medium"},
        ]
        node = _make_node("n1", NodeType.CONDITION, config={"config": {"cases": cases}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"x": 7}

        list(engine._execute_condition_node(node))

        assert engine._condition_result == "medium"

    @pytest.mark.unit
    def test_falls_through_to_else(self):
        """When no case matches, the result handle is ``"else"``."""
        cases = [{"expression": "x > 100", "source_handle": "high"}]
        node = _make_node("n1", NodeType.CONDITION, config={"config": {"cases": cases}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"x": 1}

        list(engine._execute_condition_node(node))

        assert engine._condition_result == "else"

    @pytest.mark.unit
    def test_skips_empty_expression(self):
        """A whitespace-only expression is skipped in favour of the next case."""
        cases = [
            {"expression": " ", "source_handle": "a"},
            {"expression": "true", "source_handle": "b"},
        ]
        node = _make_node("n1", NodeType.CONDITION, config={"config": {"cases": cases}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {}

        list(engine._execute_condition_node(node))

        assert engine._condition_result == "b"

    @pytest.mark.unit
    def test_cel_error_continues(self):
        """A case whose expression fails to evaluate does not stop later cases."""
        cases = [
            {"expression": "bad!!!", "source_handle": "a"},
            {"expression": "true", "source_handle": "b"},
        ]
        node = _make_node("n1", NodeType.CONDITION, config={"config": {"cases": cases}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {}

        list(engine._execute_condition_node(node))

        assert engine._condition_result == "b"
|
|
|
|
|
|
class TestExecuteEndNode:
    """Check the events produced by ``WorkflowEngine._execute_end_node``."""

    @pytest.mark.unit
    def test_with_output_template(self):
        """An end node with an output template yields one event with the answer."""
        node = _make_node(
            "n1",
            NodeType.END,
            config={"config": {"output_template": "Result: {{ query }}"}},
        )
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())
        engine.state = {"query": "hello"}
        # Stub rendering so only the end node's own behaviour is under test.
        engine._format_template = MagicMock(return_value="Result: hello")

        events = list(engine._execute_end_node(node))

        assert len(events) == 1
        assert events[0]["answer"] == "Result: hello"

    @pytest.mark.unit
    def test_without_output_template(self):
        """An end node without an output template yields no events."""
        node = _make_node("n1", NodeType.END, config={"config": {}})
        engine = WorkflowEngine(_make_graph([node], []), _make_agent())

        events = list(engine._execute_end_node(node))

        assert len(events) == 0
|
|
|
|
|
|
class TestParseStructuredOutput:
    """Check the ``(success, data)`` pair from ``_parse_structured_output``."""

    @pytest.mark.unit
    def test_valid_json(self):
        """A JSON object string parses successfully into a dict."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        success, data = engine._parse_structured_output('{"key": "value"}')
        assert success is True
        assert data == {"key": "value"}

    @pytest.mark.unit
    def test_invalid_json(self):
        """Non-JSON text reports failure with no data."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        success, data = engine._parse_structured_output("not json")
        assert success is False
        assert data is None

    @pytest.mark.unit
    def test_empty_string(self):
        """An empty string reports failure with no data."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        success, data = engine._parse_structured_output("")
        assert success is False
        assert data is None

    @pytest.mark.unit
    def test_whitespace_only(self):
        """A whitespace-only string reports failure with no data."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        success, data = engine._parse_structured_output(" ")
        assert success is False
        assert data is None
|
|
|
|
|
|
class TestNormalizeNodeJsonSchema:
    """Check ``WorkflowEngine._normalize_node_json_schema`` results and errors."""

    @pytest.mark.unit
    def test_none_returns_none(self):
        """A missing schema normalizes to ``None``."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        assert engine._normalize_node_json_schema(None, "Node") is None

    @pytest.mark.unit
    def test_valid_schema(self):
        """A well-formed object schema normalizes to a non-``None`` value."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        result = engine._normalize_node_json_schema(schema, "Node")
        assert result is not None

    @pytest.mark.unit
    def test_invalid_schema_raises(self):
        """A ``JsonSchemaValidationError`` from the normalizer becomes ``ValueError``."""
        from application.core.json_schema_utils import JsonSchemaValidationError

        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        # Patch the name as imported by the engine module so the failure is forced.
        with patch(
            "application.agents.workflows.workflow_engine.normalize_json_schema_payload"
        ) as mock_norm:
            mock_norm.side_effect = JsonSchemaValidationError("bad schema")
            with pytest.raises(ValueError, match="Invalid JSON schema"):
                engine._normalize_node_json_schema({"bad": True}, "TestNode")
|
|
|
|
|
|
class TestValidateStructuredOutput:
    """Check schema enforcement in ``WorkflowEngine._validate_structured_output``."""

    @pytest.mark.unit
    def test_valid_output_passes(self):
        """Output that satisfies the schema is accepted without raising."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        engine._validate_structured_output(schema, {"name": "Alice"})  # Should not raise

    @pytest.mark.unit
    def test_invalid_output_raises(self):
        """Output missing a required property raises ``ValueError``."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        }
        with pytest.raises(ValueError, match="did not match schema"):
            engine._validate_structured_output(schema, {})

    @pytest.mark.unit
    def test_no_jsonschema_module(self):
        """With the engine module's ``jsonschema`` set to ``None``, nothing raises."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        with patch("application.agents.workflows.workflow_engine.jsonschema", None):
            engine._validate_structured_output({"type": "object"}, {})  # Should not raise
|
|
|
|
|
|
class TestFormatTemplate:
    """Check ``WorkflowEngine._format_template`` with a stubbed template engine."""

    @pytest.mark.unit
    def test_renders_template(self):
        """The rendered string from the template engine is returned as-is."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        engine.state = {"query": "hello"}
        # Replace both collaborators so only the method's own wiring is exercised.
        engine._build_template_context = MagicMock(return_value={"query": "hello"})
        engine._template_engine = MagicMock()
        engine._template_engine.render.return_value = "hello world"

        result = engine._format_template("{{ query }} world")

        assert result == "hello world"

    @pytest.mark.unit
    def test_render_error_returns_raw(self):
        """A ``TemplateRenderError`` falls back to the unrendered template text."""
        from application.templates.template_engine import TemplateRenderError

        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        engine._build_template_context = MagicMock(return_value={})
        engine._template_engine = MagicMock()
        engine._template_engine.render.side_effect = TemplateRenderError("fail")

        result = engine._format_template("{{ bad }}")

        assert result == "{{ bad }}"
|
|
|
|
|
|
class TestBuildTemplateContext:
    """Check the mapping returned by ``WorkflowEngine._build_template_context``."""

    @pytest.mark.unit
    def test_includes_state_variables(self):
        """State appears under ``context["agent"]`` and custom keys at top level."""
        agent = _make_agent()
        agent.retrieved_docs = None
        engine = WorkflowEngine(_make_graph([], []), agent)
        engine.state = {"query": "hello", "custom_var": "value"}

        context = engine._build_template_context()

        assert context["agent"]["query"] == "hello"
        assert "custom_var" in context

    @pytest.mark.unit
    def test_reserved_namespace_gets_prefixed(self):
        """A state key named ``source`` is exposed as ``agent_source``."""
        agent = _make_agent()
        agent.retrieved_docs = None
        engine = WorkflowEngine(_make_graph([], []), agent)
        engine.state = {"source": "my_source_val"}

        context = engine._build_template_context()

        assert context.get("agent_source") == "my_source_val"

    @pytest.mark.unit
    def test_passthrough_data(self):
        """A ``passthrough`` state key surfaces under its own or a prefixed name."""
        agent = _make_agent()
        agent.retrieved_docs = None
        engine = WorkflowEngine(_make_graph([], []), agent)
        engine.state = {"passthrough": {"key": "val"}}

        context = engine._build_template_context()

        assert "passthrough" in context or "agent_passthrough" in context

    @pytest.mark.unit
    def test_tools_data(self):
        """The ``agent`` namespace is present when state carries a ``tools`` key."""
        agent = _make_agent()
        agent.retrieved_docs = None
        engine = WorkflowEngine(_make_graph([], []), agent)
        engine.state = {"tools": {"tool1": "result"}}

        context = engine._build_template_context()

        assert "agent" in context
|
|
|
|
|
|
class TestGetSourceTemplateData:
    """Check the ``(docs, together)`` pair from ``_get_source_template_data``."""

    @pytest.mark.unit
    def test_no_docs_returns_none(self):
        """With ``retrieved_docs`` set to ``None`` both results are ``None``."""
        agent = _make_agent()
        agent.retrieved_docs = None
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert docs is None
        assert together is None

    @pytest.mark.unit
    def test_empty_docs_returns_none(self):
        """An empty ``retrieved_docs`` list yields ``None`` for the docs."""
        agent = _make_agent()
        agent.retrieved_docs = []
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert docs is None

    @pytest.mark.unit
    def test_docs_with_filename(self):
        """The combined text contains both the filename and the document text."""
        agent = _make_agent()
        agent.retrieved_docs = [{"text": "content", "filename": "doc.txt"}]
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert docs is not None
        assert "doc.txt" in together
        assert "content" in together

    @pytest.mark.unit
    def test_docs_without_filename(self):
        """A document with only text contributes exactly that text."""
        agent = _make_agent()
        agent.retrieved_docs = [{"text": "content only"}]
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert together == "content only"

    @pytest.mark.unit
    def test_skips_non_dict_docs(self):
        """Entries that are not dicts are ignored when combining text."""
        agent = _make_agent()
        agent.retrieved_docs = ["not a dict", {"text": "ok"}]
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert together == "ok"

    @pytest.mark.unit
    def test_skips_non_string_text(self):
        """A document whose ``text`` is not a string produces no combined text."""
        agent = _make_agent()
        agent.retrieved_docs = [{"text": 123}]
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert together is None

    @pytest.mark.unit
    def test_doc_with_title_fallback(self):
        """Without a filename, the document's ``title`` appears in the combined text."""
        agent = _make_agent()
        agent.retrieved_docs = [{"text": "content", "title": "doc_title"}]
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert "doc_title" in together

    @pytest.mark.unit
    def test_doc_with_source_fallback(self):
        """Without filename or title, the ``source`` appears in the combined text."""
        agent = _make_agent()
        agent.retrieved_docs = [{"text": "content", "source": "src"}]
        engine = WorkflowEngine(_make_graph([], []), agent)

        docs, together = engine._get_source_template_data()

        assert "src" in together
|
|
|
|
|
|
class TestGetExecutionSummary:
    """Check how ``get_execution_summary`` converts ``engine.execution_log``."""

    @pytest.mark.unit
    def test_returns_log_entries(self):
        """One log dict becomes one summary object with matching id and status."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        # Timezone-aware timestamp, reused for both start and completion.
        now = datetime.now(timezone.utc)
        engine.execution_log = [
            {
                "node_id": "n1",
                "node_type": "start",
                "status": "completed",
                "started_at": now,
                "completed_at": now,
                "error": None,
                "state_snapshot": {},
            }
        ]

        summary = engine.get_execution_summary()

        assert len(summary) == 1
        assert summary[0].node_id == "n1"
        assert summary[0].status == ExecutionStatus.COMPLETED

    @pytest.mark.unit
    def test_empty_log(self):
        """A fresh engine reports an empty summary."""
        engine = WorkflowEngine(_make_graph([], []), _make_agent())
        assert engine.get_execution_summary() == []
|