Files
DocsGPT/application/api/user/workflows/routes.py
Alex 81b6ee5daa Pg 4 (#2390)
* feat: postgres tests

* feat: mongo cutoff

* feat: mongo cutoff

* feat: adjust docs and compose files

* fix: mini code mongo removals

* fix: tests and k8s mongo stuff

* feat: test fixes

* fix: ruff

* fix: vale

* Potential fix for pull request finding 'CodeQL / Clear-text logging of sensitive information'

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fix: mini suggestions

* vale lint fix 2

* fix: codeql columns thing

* fix: test mongo

* fix: tests coverage

* feat: better tests 4

* feat: more tests

* feat: decent coverage

* fix: ruff fixes

* fix: remove mongo mock

* feat: enhance workflow engine and API routes; add document retrieval and source handling

* feat: e2e tests

* fix: mcp, mongo and more

* fix: mini codeql warning

* fix: agent chunk view

* fix: mini issues

* fix: more pg fixes

* feat: postgres prep on start

* feat: qa tests

* fix: mini improvements

* fix: tests

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: Siddhant Rai <siddhant.rai.5686@gmail.com>
2026-04-18 13:13:57 +01:00

510 lines
19 KiB
Python

"""Workflow management routes."""
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
error_response,
get_user_id,
require_auth,
require_fields,
success_response,
)
# Flask-RESTX namespace for the workflow routes below; they are mounted under the "/api" path.
workflows_ns = Namespace("workflows", path="/api")
def _workflow_error_response(message: str, err: Exception):
    """Log ``err`` (with traceback) and return a generic error response.

    Only ``message`` is passed to the client-facing response; the exception
    text is written to the application logger.
    """
    detail = f"{message}: {err}"
    current_app.logger.error(detail, exc_info=True)
    return error_response(message)
def _resolve_workflow(repo: WorkflowsRepository, workflow_id: str, user_id: str):
    """Find a workflow owned by ``user_id`` by UUID or legacy Mongo id.

    A UUID-shaped id is first looked up as a primary key. If that finds
    nothing, or the id is not UUID-shaped, the legacy-id lookup is used.
    An empty id yields ``None`` without touching the repository.
    """
    if not workflow_id:
        return None
    found = repo.get(workflow_id, user_id) if looks_like_uuid(workflow_id) else None
    if found is not None:
        return found
    return repo.get_by_legacy_id(workflow_id, user_id)
def _write_graph(
    conn,
    pg_workflow_id: str,
    graph_version: int,
    nodes_data: List[Dict],
    edges_data: List[Dict],
) -> List[Dict]:
    """Persist one graph version's nodes and edges and return the created node rows.

    Nodes are written first. The rows returned by ``bulk_create`` map each
    user-supplied ``node_id`` string to its database ``id``. Edges name their
    endpoints with those user-supplied strings in ``source``/``target``, so
    they are translated through that map before insertion.

    Edges whose endpoints cannot be translated are skipped with a warning.
    Both writes go through the repositories' ``bulk_create`` (documented
    upstream as an ON CONFLICT upsert).
    """
    node_repo = WorkflowNodesRepository(conn)
    edge_repo = WorkflowEdgesRepository(conn)

    created_nodes: List[Dict] = []
    uuid_for_node_id: Dict = {}
    if nodes_data:
        node_rows = [
            {
                "node_id": node["id"],
                "node_type": node["type"],
                "title": node.get("title", ""),
                "description": node.get("description", ""),
                "position": node.get("position", {"x": 0, "y": 0}),
                "config": node.get("data", {}),
            }
            for node in nodes_data
        ]
        created_nodes = node_repo.bulk_create(pg_workflow_id, graph_version, node_rows)
        uuid_for_node_id = {row["node_id"]: row["id"] for row in created_nodes}

    if not edges_data:
        return created_nodes

    edge_rows: List[Dict] = []
    for edge in edges_data:
        source_ref = edge.get("source")
        target_ref = edge.get("target")
        from_uuid = uuid_for_node_id.get(source_ref)
        to_uuid = uuid_for_node_id.get(target_ref)
        if from_uuid and to_uuid:
            edge_rows.append({
                "edge_id": edge["id"],
                "from_node_id": from_uuid,
                "to_node_id": to_uuid,
                "source_handle": edge.get("sourceHandle"),
                "target_handle": edge.get("targetHandle"),
            })
            continue
        current_app.logger.warning(
            "Workflow graph write: dropping edge %s; node refs unresolved "
            "(source=%s, target=%s)",
            edge.get("id"), source_ref, target_ref,
        )

    if edge_rows:
        edge_repo.bulk_create(pg_workflow_id, graph_version, edge_rows)
    return created_nodes
def serialize_workflow(w: Dict) -> Dict:
    """Convert a workflow row into the JSON shape returned by the API.

    Timestamp values exposing ``isoformat`` (such as ``datetime``) become
    ISO-8601 strings. Any other value, including ``None``, is passed through
    unchanged.
    """
    def _as_iso(value):
        return value.isoformat() if hasattr(value, "isoformat") else value

    created_at = _as_iso(w.get("created_at"))
    updated_at = _as_iso(w.get("updated_at"))
    payload = {"id": str(w["id"])}
    payload["name"] = w.get("name")
    payload["description"] = w.get("description")
    payload["created_at"] = created_at
    payload["updated_at"] = updated_at
    return payload
def serialize_node(n: Dict) -> Dict:
    """Convert a workflow node row into the JSON shape returned by the API.

    The stored ``node_id``/``node_type``/``config`` columns are exposed as
    ``id``/``type``/``data``. A missing or falsy ``config`` becomes ``{}``.
    """
    serialized = {"id": n["node_id"], "type": n["node_type"]}
    for field in ("title", "description", "position"):
        serialized[field] = n.get(field)
    serialized["data"] = n.get("config") or {}
    return serialized
def serialize_edge(e: Dict) -> Dict:
    """Convert a workflow edge row into the JSON shape returned by the API.

    Row columns are renamed to the API's keys. ``edge_id`` is required;
    every other column defaults to ``None`` when absent.
    """
    api_to_column = (
        ("source", "source_id"),
        ("target", "target_id"),
        ("sourceHandle", "source_handle"),
        ("targetHandle", "target_handle"),
    )
    serialized = {"id": e["edge_id"]}
    serialized.update((api_key, e.get(column)) for api_key, column in api_to_column)
    return serialized
def get_workflow_graph_version(workflow: Dict) -> int:
    """Return the workflow's current graph version as a positive int.

    A missing key defaults to 1. Values that ``int()`` cannot convert, and
    non-positive results, also fall back to 1.
    """
    raw_version = workflow.get("current_graph_version", 1)
    try:
        parsed = int(raw_version)
    except (ValueError, TypeError):
        return 1
    # ``parsed`` is an int, so "> 0" is the same as ">= 1".
    return max(parsed, 1)
def validate_json_schema_payload(
    json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Normalize an optional structured-output JSON schema.

    Returns a ``(normalized_schema, error_message)`` pair:

    * ``(None, None)`` when no schema was supplied;
    * ``(schema, None)`` when normalization succeeds;
    * ``(None, message)`` when it raises ``JsonSchemaValidationError``.
    """
    normalized: Optional[Dict[str, Any]] = None
    problem: Optional[str] = None
    if json_schema is not None:
        try:
            normalized = normalize_json_schema_payload(json_schema)
        except JsonSchemaValidationError as exc:
            problem = str(exc)
    return normalized, problem
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
    """Return a copy of ``nodes`` with agent-node JSON schemas normalized.

    Handling per item:

    * Non-dict items are passed through untouched.
    * Every dict node is shallow-copied.
    * For ``"agent"`` nodes whose ``data`` dict holds a ``json_schema`` key,
      ``data`` is also copied and the schema is replaced by its normalized
      form.
    * If normalization raises ``JsonSchemaValidationError``, the original
      schema value is kept (validation is expected to have run earlier).
    """
    def _normalize_one(node):
        if not isinstance(node, dict):
            return node
        node_copy = dict(node)
        config = node_copy.get("data")
        is_agent = node_copy.get("type") == "agent"
        if not (is_agent and isinstance(config, dict) and "json_schema" in config):
            return node_copy
        config_copy = dict(config)
        try:
            config_copy["json_schema"] = normalize_json_schema_payload(
                config.get("json_schema")
            )
        except JsonSchemaValidationError:
            # Validation runs before normalization; keep original on unexpected shape.
            config_copy["json_schema"] = config.get("json_schema")
        node_copy["data"] = config_copy
        return node_copy

    return [_normalize_one(node) for node in nodes]
def _source_handle(item: Dict) -> str:
    """Return ``item["sourceHandle"]`` stripped, or ``""`` when missing or not a string."""
    raw_handle = item.get("sourceHandle", "")
    return raw_handle.strip() if isinstance(raw_handle, str) else ""


def _has_expression(case: Dict) -> bool:
    """Return True when a condition case carries a non-blank *string* expression.

    Using one predicate everywhere keeps the "at least one expression" check
    consistent with the per-case check. Previously ``str(None)`` == "None"
    made a null expression count as present.
    """
    raw_expression = case.get("expression", "")
    return isinstance(raw_expression, str) and bool(raw_expression.strip())


def _validate_condition_node(
    cnode: Dict, edges: List[Dict], node_map: Dict, end_ids: Set
) -> List[str]:
    """Return the validation errors for one condition node's cases and branches."""
    errors: List[str] = []
    cnode_id = cnode.get("id")
    cnode_title = cnode.get("title", cnode_id)
    outgoing = [e for e in edges if e.get("source") == cnode_id]
    if len(outgoing) < 2:
        errors.append(
            f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
        )

    node_data = cnode.get("data", {}) or {}
    # A non-dict ``data`` would otherwise crash on ``.get`` below.
    cases = node_data.get("cases", []) if isinstance(node_data, dict) else []
    if not isinstance(cases, list):
        cases = []
    # Non-dict case entries carry no handle or expression; they are ignored.
    dict_cases = [c for c in cases if isinstance(c, dict)]
    if not any(_has_expression(c) for c in dict_cases):
        errors.append(
            f"Condition node '{cnode_title}' must have at least one case with an expression"
        )

    case_handles: Set[str] = set()
    duplicate_case_handles: Set[str] = set()
    for case in dict_cases:
        handle = _source_handle(case)
        if not handle:
            errors.append(
                f"Condition node '{cnode_title}' has a case without a branch handle"
            )
            continue
        if handle in case_handles:
            duplicate_case_handles.add(handle)
        case_handles.add(handle)
    # Sorted so the error order is deterministic (set iteration order is not).
    for handle in sorted(duplicate_case_handles):
        errors.append(
            f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
        )

    outgoing_by_handle: Dict[str, List[Dict]] = {}
    for out_edge in outgoing:
        outgoing_by_handle.setdefault(_source_handle(out_edge), []).append(out_edge)
    for handle, handle_edges in outgoing_by_handle.items():
        if not handle:
            errors.append(
                f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
            )
            continue
        if handle != "else" and handle not in case_handles:
            errors.append(
                f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
            )
        if len(handle_edges) > 1:
            errors.append(
                f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
            )
    if "else" not in outgoing_by_handle:
        errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")

    for case in dict_cases:
        handle = _source_handle(case)
        if not handle:
            continue
        has_expression = _has_expression(case)
        has_outgoing = bool(outgoing_by_handle.get(handle))
        if has_expression and not has_outgoing:
            errors.append(
                f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
            )
        if not has_expression and has_outgoing:
            errors.append(
                f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
            )

    for handle, handle_edges in outgoing_by_handle.items():
        if not handle:
            continue
        for out_edge in handle_edges:
            target = out_edge.get("target")
            if target and not _can_reach_end(target, edges, node_map, end_ids):
                errors.append(
                    f"Branch '{handle}' of condition '{cnode_title}' "
                    f"must eventually reach an end node"
                )
    return errors


def _validate_agent_node(agent_node: Dict) -> List[str]:
    """Return the validation errors for one agent node's structured-output config."""
    agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
    raw_config = agent_node.get("data", {}) or {}
    if not isinstance(raw_config, dict):
        return [f"Agent node '{agent_title}' has invalid configuration"]
    errors: List[str] = []
    normalized_schema, schema_error = validate_json_schema_payload(
        raw_config.get("json_schema")
    )
    model_id = raw_config.get("model_id")
    if normalized_schema is not None and isinstance(model_id, str) and model_id.strip():
        capabilities = get_model_capabilities(model_id.strip())
        if capabilities and not capabilities.get("supports_structured_output", False):
            errors.append(
                f"Agent node '{agent_title}' selected model does not support structured output"
            )
    if schema_error:
        errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
    return errors


def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
    """Validate a submitted workflow graph and return human-readable errors.

    An empty list means the graph is acceptable. The routes call this outside
    their ``try`` blocks, so malformed payloads are reported as errors here
    instead of raising. Malformed means nodes/edges that are not lists,
    ``edges: null``, or non-dict entries.

    Args:
        nodes: Node dicts with at least ``id`` and ``type``.
        edges: Edge dicts with ``id``, ``source`` and ``target``. ``None`` is
            treated as no edges.

    Returns:
        A list of error messages (empty when valid).
    """
    if not nodes:
        return ["Workflow must have at least one node"]
    if not isinstance(nodes, (list, tuple)):
        return ["Workflow nodes must be a list"]
    if edges is None:
        edges = []
    if not isinstance(edges, (list, tuple)):
        return ["Workflow edges must be a list"]

    errors: List[str] = []
    if not all(isinstance(n, dict) for n in nodes):
        errors.append("All nodes must be objects")
        nodes = [n for n in nodes if isinstance(n, dict)]
    if not all(isinstance(e, dict) for e in edges):
        errors.append("All edges must be objects")
        edges = [e for e in edges if isinstance(e, dict)]

    start_nodes = [n for n in nodes if n.get("type") == "start"]
    if len(start_nodes) != 1:
        errors.append("Workflow must have exactly one start node")
    end_nodes = [n for n in nodes if n.get("type") == "end"]
    if not end_nodes:
        errors.append("Workflow must have at least one end node")

    node_map = {n.get("id"): n for n in nodes}
    end_ids = {n.get("id") for n in end_nodes}

    for edge in edges:
        # Edge ids are dereferenced unconditionally when the graph is written.
        if not edge.get("id"):
            errors.append("All edges must have an id")
        source_id = edge.get("source")
        target_id = edge.get("target")
        if source_id not in node_map:
            errors.append(f"Edge references non-existent source: {source_id}")
        if target_id not in node_map:
            errors.append(f"Edge references non-existent target: {target_id}")

    if start_nodes:
        start_id = start_nodes[0].get("id")
        if not any(e.get("source") == start_id for e in edges):
            errors.append("Start node must have at least one outgoing edge")

    for cnode in (n for n in nodes if n.get("type") == "condition"):
        errors.extend(_validate_condition_node(cnode, edges, node_map, end_ids))

    for agent_node in (n for n in nodes if n.get("type") == "agent"):
        errors.extend(_validate_agent_node(agent_node))

    seen_ids: Set[Any] = set()
    reported_duplicates: Set[Any] = set()
    for node in nodes:
        node_id = node.get("id")
        if not node_id:
            errors.append("All nodes must have an id")
        elif node_id in seen_ids and node_id not in reported_duplicates:
            # ``node_map`` keeps only the last node per id, so duplicates make
            # edge endpoints ambiguous.
            reported_duplicates.add(node_id)
            errors.append(f"Duplicate node id '{node_id}'")
        seen_ids.add(node_id)
        if not node.get("type"):
            errors.append(f"Node {node.get('id', 'unknown')} must have a type")
    return errors


def _can_reach_end(
    node_id: str,
    edges: List[Dict],
    node_map: Dict,
    end_ids: set,
    visited: Optional[set] = None,
) -> bool:
    """Return True if some id in ``end_ids`` is reachable from ``node_id``.

    Edges are followed from ``source`` to ``target``, ignoring falsy targets.
    Only nodes present in ``node_map`` are expanded. ``visited`` is created
    when omitted, is mutated in place, and guarantees termination on cycles.
    The walk is iterative, so long chains cannot exceed Python's recursion
    limit.
    """
    if visited is None:
        visited = set()
    # Build the adjacency once instead of rescanning every edge per node.
    successors: Dict[Any, List[Any]] = {}
    for edge in edges:
        target = edge.get("target")
        if target:
            successors.setdefault(edge.get("source"), []).append(target)

    pending = [node_id]
    while pending:
        current = pending.pop()
        if current in end_ids:
            return True
        if current in visited or current not in node_map:
            continue
        visited.add(current)
        pending.extend(successors.get(current, ()))
    return False
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
    """Collection endpoint for creating workflows for the authenticated user."""

    @require_auth
    @require_fields(["name"])
    def post(self):
        """Create a new workflow with nodes and edges.

        The submitted graph is validated and agent-node JSON schemas are
        normalized. Within a single ``db_session()`` context, the workflow
        row is then created and its nodes/edges are written as graph
        version 1.

        Returns:
            ``{"id": <workflow id>}`` with status 201 on success. On failure,
            a validation error response or the generic failure response.
        """
        user_id = get_user_id()
        data = request.get_json()
        raw_name = data.get("name", "")
        # A non-string name used to crash ``.strip()`` outside the try block,
        # and a whitespace-only name was silently stored as "".
        if not isinstance(raw_name, str) or not raw_name.strip():
            return error_response("Workflow name must be a non-empty string")
        name = raw_name.strip()
        description = data.get("description", "")
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])
        validation_errors = validate_workflow_structure(nodes_data, edges_data)
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
            )
        # Normalization runs only on payloads that passed validation.
        nodes_data = normalize_agent_node_json_schemas(nodes_data)
        try:
            with db_session() as conn:
                repo = WorkflowsRepository(conn)
                workflow = repo.create(user_id, name, description=description)
                pg_workflow_id = str(workflow["id"])
                _write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
        except Exception as err:
            return _workflow_error_response("Failed to create workflow", err)
        return success_response({"id": pg_workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
    """Read, replace, or delete a single workflow owned by the current user."""

    @require_auth
    def get(self, workflow_id: str):
        """Return the workflow plus the nodes and edges of its current graph version.

        Responds 404 when no workflow with this UUID or legacy id belongs to
        the user.
        """
        user_id = get_user_id()
        try:
            with db_readonly() as conn:
                repo = WorkflowsRepository(conn)
                workflow = _resolve_workflow(repo, workflow_id, user_id)
                if workflow is None:
                    return error_response("Workflow not found", 404)
                pg_workflow_id = str(workflow["id"])
                graph_version = get_workflow_graph_version(workflow)
                nodes = WorkflowNodesRepository(conn).find_by_version(
                    pg_workflow_id, graph_version,
                )
                edges = WorkflowEdgesRepository(conn).find_by_version(
                    pg_workflow_id, graph_version,
                )
        except Exception as err:
            return _workflow_error_response("Failed to fetch workflow", err)
        return success_response(
            {
                "workflow": serialize_workflow(workflow),
                "nodes": [serialize_node(n) for n in nodes],
                "edges": [serialize_edge(e) for e in edges],
            }
        )

    @require_auth
    @require_fields(["name"])
    def put(self, workflow_id: str):
        """Update the workflow's name and description and replace its graph.

        The new graph is written under ``current_graph_version + 1``. The
        workflow row is then pointed at that version, and the nodes/edges of
        every other version are deleted. All of this happens within one
        ``db_session()`` context.
        """
        user_id = get_user_id()
        data = request.get_json()
        raw_name = data.get("name", "")
        # A non-string name used to crash ``.strip()`` outside the try block,
        # and a whitespace-only name was silently stored as "".
        if not isinstance(raw_name, str) or not raw_name.strip():
            return error_response("Workflow name must be a non-empty string")
        name = raw_name.strip()
        description = data.get("description", "")
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])
        validation_errors = validate_workflow_structure(nodes_data, edges_data)
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
            )
        nodes_data = normalize_agent_node_json_schemas(nodes_data)
        try:
            with db_session() as conn:
                repo = WorkflowsRepository(conn)
                workflow = _resolve_workflow(repo, workflow_id, user_id)
                if workflow is None:
                    return error_response("Workflow not found", 404)
                pg_workflow_id = str(workflow["id"])
                current_graph_version = get_workflow_graph_version(workflow)
                next_graph_version = current_graph_version + 1
                # Write the new version before switching the pointer, then
                # drop every other version's rows.
                _write_graph(
                    conn, pg_workflow_id, next_graph_version,
                    nodes_data, edges_data,
                )
                repo.update(
                    pg_workflow_id, user_id,
                    {
                        "name": name,
                        "description": description,
                        "current_graph_version": next_graph_version,
                    },
                )
                WorkflowNodesRepository(conn).delete_other_versions(
                    pg_workflow_id, next_graph_version,
                )
                WorkflowEdgesRepository(conn).delete_other_versions(
                    pg_workflow_id, next_graph_version,
                )
        except Exception as err:
            return _workflow_error_response("Failed to update workflow", err)
        return success_response()

    @require_auth
    def delete(self, workflow_id: str):
        """Delete the workflow; responds 404 when the user does not own it."""
        user_id = get_user_id()
        try:
            with db_session() as conn:
                repo = WorkflowsRepository(conn)
                workflow = _resolve_workflow(repo, workflow_id, user_id)
                if workflow is None:
                    return error_response("Workflow not found", 404)
                # ON DELETE CASCADE on workflow_nodes/edges cleans children.
                repo.delete(str(workflow["id"]), user_id)
        except Exception as err:
            return _workflow_error_response("Failed to delete workflow", err)
        return success_response()