feat: agent workflow builder (#2264)

* feat: implement WorkflowAgent and GraphExecutor for workflow management and execution

* refactor: workflow schemas and introduce WorkflowEngine

- Updated schemas in `schemas.py` to include new agent types and configurations.
- Created `WorkflowEngine` class in `workflow_engine.py` to manage workflow execution.
- Enhanced `StreamProcessor` to handle workflow-related data.
- Added new routes and utilities for managing workflows in the user API.
- Implemented validation and serialization functions for workflows.
- Established MongoDB collections and indexes for workflows and related entities.

* refactor: improve WorkflowAgent documentation and update type hints in WorkflowEngine

* feat: workflow builder and managing in frontend

- Added new endpoints for workflows in `endpoints.ts`.
- Implemented `getWorkflow`, `createWorkflow`, and `updateWorkflow` methods in `userService.ts`.
- Introduced new UI components for alerts, buttons, commands, dialogs, multi-select, popovers, and selects.
- Enhanced styling in `index.css` with new theme variables and animations.
- Refactored modal components for better layout and styling.
- Configured TypeScript paths and Vite aliases for cleaner imports.

* feat: add workflow preview component and related state management

- Implemented WorkflowPreview component for displaying workflow execution.
- Created WorkflowPreviewSlice for managing workflow preview state, including queries and execution steps.
- Added WorkflowMiniMap for visual representation of workflow nodes and their statuses.
- Integrated conversation handling with the ability to fetch answers and manage query states.
- Introduced reusable Sheet component for UI overlays.
- Updated Redux store to include workflowPreview reducer.

* feat: enhance workflow execution details and state management in WorkflowEngine and WorkflowPreview

* feat: enhance workflow components with improved UI and functionality

- Updated WorkflowPreview to allow text truncation for better display of long names.
- Enhanced BaseNode with connectable handles and improved styling for better visibility.
- Added MobileBlocker component to inform users about desktop requirements for the Workflow Builder.
- Introduced PromptTextArea component for improved variable insertion and search functionality, including upstream variable extraction and context addition.

* feat(workflow): add owner validation and graph version support

* fix: ruff lint

---------

Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
Siddhant Rai
2026-02-11 19:45:24 +05:30
committed by GitHub
parent 8353f9c649
commit 8ef321d784
52 changed files with 8634 additions and 222 deletions

View File

@@ -19,6 +19,9 @@ from application.api.user.base import (
resolve_tool_details,
storage,
users_collection,
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.settings import settings
from application.utils import (
@@ -31,6 +34,189 @@ from application.utils import (
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
AGENT_TYPE_SCHEMAS = {
"classic": {
"required_published": [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
],
"required_draft": ["name"],
"validate_published": ["name", "description", "prompt_id"],
"validate_draft": [],
"require_source": True,
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
"tools",
"json_schema",
"models",
"default_model_id",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
"workflow": {
"required_published": ["name", "workflow"],
"required_draft": ["name"],
"validate_published": ["name", "workflow"],
"validate_draft": [],
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"workflow",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
def normalize_workflow_reference(workflow_value):
"""Normalize workflow references from form/json payloads."""
if workflow_value is None:
return None
if isinstance(workflow_value, dict):
return (
workflow_value.get("id")
or workflow_value.get("_id")
or workflow_value.get("workflow_id")
)
if isinstance(workflow_value, str):
value = workflow_value.strip()
if not value:
return ""
try:
parsed = json.loads(value)
if isinstance(parsed, str):
return parsed.strip()
if isinstance(parsed, dict):
return (
parsed.get("id") or parsed.get("_id") or parsed.get("workflow_id")
)
except json.JSONDecodeError:
pass
return value
return str(workflow_value)
def validate_workflow_access(workflow_value, user, required=False):
"""Validate workflow reference and ensure ownership."""
workflow_id = normalize_workflow_reference(workflow_value)
if not workflow_id:
if required:
return None, make_response(
jsonify({"success": False, "message": "Workflow is required"}), 400
)
return None, None
if not ObjectId.is_valid(workflow_id):
return None, make_response(
jsonify({"success": False, "message": "Invalid workflow ID format"}), 400
)
workflow = workflows_collection.find_one({"_id": ObjectId(workflow_id), "user": user})
if not workflow:
return None, make_response(
jsonify({"success": False, "message": "Workflow not found"}), 404
)
return workflow_id, None
def build_agent_document(
data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
):
"""Build agent document based on agent type schema."""
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
agent_type = "classic"
schema = AGENT_TYPE_SCHEMAS.get(agent_type, AGENT_TYPE_SCHEMAS["classic"])
allowed_fields = set(schema["fields"])
now = datetime.datetime.now(datetime.timezone.utc)
base_doc = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"agent_type": agent_type,
"status": data.get("status"),
"key": key,
"createdAt": now,
"updatedAt": now,
"lastUsedAt": None,
}
if agent_type == "workflow":
base_doc["workflow"] = data.get("workflow")
base_doc["folder_id"] = data.get("folder_id")
else:
base_doc.update(
{
"image": image_url or "",
"source": source_field or "",
"sources": sources_list or [],
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"json_schema": data.get("json_schema"),
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
)
if "limited_token_mode" in allowed_fields:
base_doc["limited_token_mode"] = (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
)
if "token_limit" in allowed_fields:
base_doc["token_limit"] = int(
data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
if "limited_request_mode" in allowed_fields:
base_doc["limited_request_mode"] = (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
)
if "request_limit" in allowed_fields:
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@agents_ns.route("/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
@@ -68,7 +254,7 @@ class GetAgent(Resource):
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -99,6 +285,7 @@ class GetAgent(Resource):
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -148,7 +335,7 @@ class GetAgents(Resource):
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -179,9 +366,12 @@ class GetAgents(Resource):
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
for agent in agents
if "source" in agent or "retriever" in agent
if "source" in agent
or "retriever" in agent
or agent.get("agent_type") == "workflow"
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
@@ -209,16 +399,22 @@ class CreateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -330,18 +526,34 @@ class CreateAgent(Resource):
),
400,
)
if data.get("status") == "published":
required_fields = [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
agent_type = data.get("agent_type", "")
# Default to classic schema for empty or unknown agent types
if not data.get("source") and not data.get("sources"):
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
schema = AGENT_TYPE_SCHEMAS["classic"]
# Set agent_type to classic if it was empty
if not agent_type:
agent_type = "classic"
else:
schema = AGENT_TYPE_SCHEMAS[agent_type]
is_published = data.get("status") == "published"
if agent_type == "workflow":
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=is_published
)
if workflow_error:
return workflow_error
data["workflow"] = workflow_id
if data.get("status") == "published":
required_fields = schema["required_published"]
validate_fields = schema["validate_published"]
if (
schema.get("require_source")
and not data.get("source")
and not data.get("sources")
):
return make_response(
jsonify(
{
@@ -351,10 +563,9 @@ class CreateAgent(Resource):
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
validate_fields = []
required_fields = schema["required_draft"]
validate_fields = schema["validate_draft"]
missing_fields = check_required_fields(data, required_fields)
invalid_fields = validate_required_fields(data, validate_fields)
if missing_fields:
@@ -366,7 +577,6 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
@@ -381,76 +591,36 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Folder not found"}), 404
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
source_field = ""
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
"source": source_field,
"sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"agent_type": data.get("agent_type", ""),
"status": data.get("status"),
"json_schema": data.get("json_schema"),
"limited_token_mode": (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
),
"token_limit": int(
data.get(
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
)
),
"limited_request_mode": (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
),
"request_limit": int(
data.get(
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
)
),
"createdAt": datetime.datetime.now(datetime.timezone.utc),
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
and not new_agent["sources"]
):
new_agent["retriever"] = "classic"
new_agent = build_agent_document(
data, user, key, agent_type, image_url, source_field, sources_list
)
if agent_type != "workflow":
if new_agent.get("chunks") == "":
new_agent["chunks"] = "2"
if (
new_agent.get("source") == ""
and new_agent.get("retriever") == ""
and not new_agent.get("sources")
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
except Exception as err:
@@ -479,16 +649,22 @@ class UpdateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -612,6 +788,7 @@ class UpdateAgent(Resource):
"models",
"default_model_id",
"folder_id",
"workflow",
]
for field in allowed_fields:
@@ -768,10 +945,10 @@ class UpdateAgent(Resource):
)
elif field == "token_limit":
token_limit = data.get("token_limit")
# Convert to int and store
update_fields[field] = int(token_limit) if token_limit else 0
# Validate consistency with mode
if update_fields[field] > 0 and not data.get("limited_token_mode"):
return make_response(
jsonify(
@@ -814,14 +991,24 @@ class UpdateAgent(Resource):
)
if not folder:
return make_response(
jsonify(
{"success": False, "message": "Folder not found"}
),
jsonify({"success": False, "message": "Folder not found"}),
404,
)
update_fields[field] = folder_id
else:
update_fields[field] = None
elif field == "workflow":
workflow_required = (
data.get("status", existing_agent.get("status")) == "published"
and data.get("agent_type", existing_agent.get("agent_type"))
== "workflow"
)
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=workflow_required
)
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -850,46 +1037,82 @@ class UpdateAgent(Resource):
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
if final_status == "published":
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
if agent_type == "workflow":
required_published_fields = {
"name": "Agent name",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
if not workflow_id:
missing_published_fields.append("Workflow")
elif not ObjectId.is_valid(workflow_id):
missing_published_fields.append("Valid workflow")
else:
workflow = workflows_collection.find_one(
{"_id": ObjectId(workflow_id), "user": user}
)
if not workflow:
missing_published_fields.append("Workflow access")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
else:
required_published_fields = {
"name": "Agent name",
"description": "Agent description",
"chunks": "Chunks count",
"prompt_id": "Prompt",
"agent_type": "Agent type",
}
missing_published_fields = []
for req_field, field_label in required_published_fields.items():
final_value = update_fields.get(
req_field, existing_agent.get(req_field)
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
if not final_value:
missing_published_fields.append(field_label)
source_val = update_fields.get("source", existing_agent.get("source"))
sources_val = update_fields.get(
"sources", existing_agent.get("sources", [])
)
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
has_valid_source = (
isinstance(source_val, DBRef)
or source_val == "default"
or (isinstance(sources_val, list) and len(sources_val) > 0)
)
if not has_valid_source:
missing_published_fields.append("Source")
if missing_published_fields:
return make_response(
jsonify(
{
"success": False,
"message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
}
),
400,
)
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -961,6 +1184,29 @@ class DeleteAgent(Resource):
jsonify({"success": False, "message": "Agent not found"}), 404
)
deleted_id = str(deleted_agent["_id"])
if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
"workflow"
):
workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
if workflow_id and ObjectId.is_valid(workflow_id):
workflow_oid = ObjectId(workflow_id)
owned_workflow = workflows_collection.find_one(
{"_id": workflow_oid, "user": user}, {"_id": 1}
)
if owned_workflow:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow_oid, "user": user})
else:
current_app.logger.warning(
f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
)
elif workflow_id:
current_app.logger.warning(
f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
)
except Exception as err:
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1069,19 +1315,16 @@ class AdoptAgent(Resource):
def post(self):
if not (decoded_token := request.decoded_token):
return make_response(jsonify({"success": False}), 401)
if not (agent_id := request.args.get("id")):
return make_response(
jsonify({"success": False, "message": "ID required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": "system"}
)
if not agent:
return make_response(jsonify({"status": "Not found"}), 404)
new_agent = agent.copy()
new_agent.pop("_id", None)
new_agent["user"] = decoded_token["sub"]

View File

@@ -38,6 +38,10 @@ users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
@@ -47,6 +51,25 @@ try:
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(

View File

@@ -6,7 +6,6 @@ from flask import Blueprint
from application.api import api
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
@@ -15,6 +14,7 @@ from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
from .tools import tools_mcp_ns, tools_ns
from .workflows import workflows_ns
user = Blueprint("user", __name__)
@@ -51,3 +51,6 @@ api.add_namespace(sources_upload_ns)
# Tools (main, MCP)
api.add_namespace(tools_ns)
api.add_namespace(tools_mcp_ns)
# Workflows
api.add_namespace(workflows_ns)

View File

@@ -0,0 +1,378 @@
"""Centralized utilities for API routes."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
def get_user_id() -> Optional[str]:
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return error_response("Unauthorized", 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
data = request.get_json()
if not data:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(f"Missing required fields: {', '.join(missing)}")
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order

View File

@@ -0,0 +1,3 @@
from .routes import workflows_ns
__all__ = ["workflows_ns"]

View File

@@ -0,0 +1,353 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Dict, List
from flask import current_app, request
from flask_restx import Namespace, Resource
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
"id": str(w["_id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node document to API response format."""
return {
"id": n["id"],
"type": n["type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}),
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge document to API response format."""
return {
"id": e["id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
"targetHandle": e.get("target_handle"),
}
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with legacy fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
return version if version > 0 else 1
except (ValueError, TypeError):
return 1
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
if not nodes:
errors.append("Workflow must have at least one node")
return errors
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) != 1:
errors.append("Workflow must have exactly one start node")
end_nodes = [n for n in nodes if n.get("type") == "end"]
if not end_nodes:
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
if source_id not in node_ids:
errors.append(f"Edge references non-existent source: {source_id}")
if target_id not in node_ids:
errors.append(f"Edge references non-existent target: {target_id}")
if start_nodes:
start_id = start_nodes[0].get("id")
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
if not node.get("type"):
errors.append(f"Node {node.get('id', 'unknown')} must have a type")
return errors
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@require_auth
@require_fields(["name"])
def post(self):
"""Create a new workflow with nodes and edges."""
user_id = get_user_id()
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return error_response(f"Failed to create workflow structure: {str(e)}")
return success_response({"id": workflow_id}, 201)
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
@require_auth
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
return success_response(
{
"workflow": serialize_workflow(workflow),
"nodes": [serialize_node(n) for n in nodes],
"edges": [serialize_edge(e) for e in edges],
}
)
@require_auth
@require_fields(["name"])
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error_response(f"Failed to update workflow structure: {str(e)}")
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
return success_response()
@require_auth
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
return success_response()