feat: agent workflow builder (#2264)
* feat: implement WorkflowAgent and GraphExecutor for workflow management and execution

* refactor: workflow schemas and introduce WorkflowEngine
  - Updated schemas in `schemas.py` to include new agent types and configurations.
  - Created `WorkflowEngine` class in `workflow_engine.py` to manage workflow execution.
  - Enhanced `StreamProcessor` to handle workflow-related data.
  - Added new routes and utilities for managing workflows in the user API.
  - Implemented validation and serialization functions for workflows.
  - Established MongoDB collections and indexes for workflows and related entities.

* refactor: improve WorkflowAgent documentation and update type hints in WorkflowEngine

* feat: workflow builder and management in frontend
  - Added new endpoints for workflows in `endpoints.ts`.
  - Implemented `getWorkflow`, `createWorkflow`, and `updateWorkflow` methods in `userService.ts`.
  - Introduced new UI components for alerts, buttons, commands, dialogs, multi-selects, popovers, and selects.
  - Enhanced styling in `index.css` with new theme variables and animations.
  - Refactored modal components for better layout and styling.
  - Configured TypeScript paths and Vite aliases for cleaner imports.

* feat: add workflow preview component and related state management
  - Implemented WorkflowPreview component for displaying workflow execution.
  - Created WorkflowPreviewSlice for managing workflow preview state, including queries and execution steps.
  - Added WorkflowMiniMap for a visual representation of workflow nodes and their statuses.
  - Integrated conversation handling with the ability to fetch answers and manage query states.
  - Introduced reusable Sheet component for UI overlays.
  - Updated Redux store to include the workflowPreview reducer.

* feat: enhance workflow execution details and state management in WorkflowEngine and WorkflowPreview

* feat: enhance workflow components with improved UI and functionality
  - Updated WorkflowPreview to allow text truncation for better display of long names.
  - Enhanced BaseNode with connectable handles and improved styling for better visibility.
  - Added MobileBlocker component to inform users about desktop requirements for the Workflow Builder.
  - Introduced PromptTextArea component for improved variable insertion and search functionality, including upstream variable extraction and context addition.

* feat(workflow): add owner validation and graph version support

* fix: ruff lint

---------

Co-authored-by: Alex <a@tushynski.me>
@@ -19,6 +19,9 @@ from application.api.user.base import (
     resolve_tool_details,
     storage,
     users_collection,
+    workflow_edges_collection,
+    workflow_nodes_collection,
+    workflows_collection,
 )
 from application.core.settings import settings
 from application.utils import (
@@ -31,6 +34,189 @@ from application.utils import (
 agents_ns = Namespace("agents", description="Agent management operations", path="/api")
 
 
+AGENT_TYPE_SCHEMAS = {
+    "classic": {
+        "required_published": [
+            "name",
+            "description",
+            "chunks",
+            "retriever",
+            "prompt_id",
+        ],
+        "required_draft": ["name"],
+        "validate_published": ["name", "description", "prompt_id"],
+        "validate_draft": [],
+        "require_source": True,
+        "fields": [
+            "user",
+            "name",
+            "description",
+            "agent_type",
+            "status",
+            "key",
+            "image",
+            "source",
+            "sources",
+            "chunks",
+            "retriever",
+            "prompt_id",
+            "tools",
+            "json_schema",
+            "models",
+            "default_model_id",
+            "folder_id",
+            "limited_token_mode",
+            "token_limit",
+            "limited_request_mode",
+            "request_limit",
+            "createdAt",
+            "updatedAt",
+            "lastUsedAt",
+        ],
+    },
+    "workflow": {
+        "required_published": ["name", "workflow"],
+        "required_draft": ["name"],
+        "validate_published": ["name", "workflow"],
+        "validate_draft": [],
+        "fields": [
+            "user",
+            "name",
+            "description",
+            "agent_type",
+            "status",
+            "key",
+            "workflow",
+            "folder_id",
+            "limited_token_mode",
+            "token_limit",
+            "limited_request_mode",
+            "request_limit",
+            "createdAt",
+            "updatedAt",
+            "lastUsedAt",
+        ],
+    },
+}
+
+AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
+AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
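The `AGENT_TYPE_SCHEMAS` table above is the single source of truth for what each agent type must carry: `build_agent_document` filters the stored document on `fields`, and the create/update handlers pick `required_published`/`required_draft` from it. A minimal sketch of that lookup, assuming the dict as defined above (the `payload` values are made up):

```python
payload = {"name": "Support bot", "status": "published", "agent_type": "workflow"}

# Unknown or empty agent types fall back to the classic schema,
# mirroring the defaulting logic in build_agent_document below.
schema = AGENT_TYPE_SCHEMAS.get(payload.get("agent_type") or "classic",
                                AGENT_TYPE_SCHEMAS["classic"])
required = (schema["required_published"]
            if payload.get("status") == "published"
            else schema["required_draft"])
print([f for f in required if not payload.get(f)])
# ['workflow'] -- a published workflow agent must reference a workflow
```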
+def normalize_workflow_reference(workflow_value):
+    """Normalize workflow references from form/json payloads."""
+    if workflow_value is None:
+        return None
+    if isinstance(workflow_value, dict):
+        return (
+            workflow_value.get("id")
+            or workflow_value.get("_id")
+            or workflow_value.get("workflow_id")
+        )
+    if isinstance(workflow_value, str):
+        value = workflow_value.strip()
+        if not value:
+            return ""
+        try:
+            parsed = json.loads(value)
+            if isinstance(parsed, str):
+                return parsed.strip()
+            if isinstance(parsed, dict):
+                return (
+                    parsed.get("id") or parsed.get("_id") or parsed.get("workflow_id")
+                )
+        except json.JSONDecodeError:
+            pass
+        return value
+    return str(workflow_value)
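The helper tolerates the different shapes the frontend may send: a bare id string, a JSON-encoded string, or an embedded object. Illustrative calls (the id is made up):

```python
wid = "665f1c2ab7e9d53a9c0f1234"  # made-up ObjectId string

normalize_workflow_reference(wid)                    # -> "665f1c2ab7e9d53a9c0f1234"
normalize_workflow_reference({"id": wid})            # -> "665f1c2ab7e9d53a9c0f1234"
normalize_workflow_reference(f'{{"_id": "{wid}"}}')  # JSON string -> same id
normalize_workflow_reference("  ")                   # -> "" (treated as missing)
normalize_workflow_reference(None)                   # -> None
```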
+def validate_workflow_access(workflow_value, user, required=False):
+    """Validate workflow reference and ensure ownership."""
+    workflow_id = normalize_workflow_reference(workflow_value)
+    if not workflow_id:
+        if required:
+            return None, make_response(
+                jsonify({"success": False, "message": "Workflow is required"}), 400
+            )
+        return None, None
+    if not ObjectId.is_valid(workflow_id):
+        return None, make_response(
+            jsonify({"success": False, "message": "Invalid workflow ID format"}), 400
+        )
+    workflow = workflows_collection.find_one({"_id": ObjectId(workflow_id), "user": user})
+    if not workflow:
+        return None, make_response(
+            jsonify({"success": False, "message": "Workflow not found"}), 404
+        )
+    return workflow_id, None
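Like the helpers in the new `utils.py`, this returns a `(value, error_response)` pair so call sites stay flat. A sketch of the intended call pattern, matching how the CreateAgent handler below uses it (the surrounding handler context is assumed):

```python
# Inside a Flask-RESTX handler, assuming `data` is the parsed payload
# and `user` is the authenticated user id.
workflow_id, workflow_error = validate_workflow_access(
    data.get("workflow"), user, required=(data.get("status") == "published")
)
if workflow_error:
    return workflow_error  # 400/404 response already built
data["workflow"] = workflow_id  # store the normalized id (or None for drafts)
```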
+def build_agent_document(
+    data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
+):
+    """Build agent document based on agent type schema."""
+    if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
+        agent_type = "classic"
+    schema = AGENT_TYPE_SCHEMAS.get(agent_type, AGENT_TYPE_SCHEMAS["classic"])
+    allowed_fields = set(schema["fields"])
+
+    now = datetime.datetime.now(datetime.timezone.utc)
+    base_doc = {
+        "user": user,
+        "name": data.get("name"),
+        "description": data.get("description", ""),
+        "agent_type": agent_type,
+        "status": data.get("status"),
+        "key": key,
+        "createdAt": now,
+        "updatedAt": now,
+        "lastUsedAt": None,
+    }
+
+    if agent_type == "workflow":
+        base_doc["workflow"] = data.get("workflow")
+        base_doc["folder_id"] = data.get("folder_id")
+    else:
+        base_doc.update(
+            {
+                "image": image_url or "",
+                "source": source_field or "",
+                "sources": sources_list or [],
+                "chunks": data.get("chunks", ""),
+                "retriever": data.get("retriever", ""),
+                "prompt_id": data.get("prompt_id", ""),
+                "tools": data.get("tools", []),
+                "json_schema": data.get("json_schema"),
+                "models": data.get("models", []),
+                "default_model_id": data.get("default_model_id", ""),
+                "folder_id": data.get("folder_id"),
+            }
+        )
+    if "limited_token_mode" in allowed_fields:
+        base_doc["limited_token_mode"] = (
+            data.get("limited_token_mode") == "True"
+            if isinstance(data.get("limited_token_mode"), str)
+            else bool(data.get("limited_token_mode", False))
+        )
+    if "token_limit" in allowed_fields:
+        base_doc["token_limit"] = int(
+            data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
+        )
+    if "limited_request_mode" in allowed_fields:
+        base_doc["limited_request_mode"] = (
+            data.get("limited_request_mode") == "True"
+            if isinstance(data.get("limited_request_mode"), str)
+            else bool(data.get("limited_request_mode", False))
+        )
+    if "request_limit" in allowed_fields:
+        base_doc["request_limit"] = int(
+            data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
+        )
+    return {k: v for k, v in base_doc.items() if k in allowed_fields}
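Because the final comprehension filters on `schema["fields"]`, a workflow agent never picks up retriever- or source-related keys, and vice versa. An illustrative call, assuming the module context above (all values are made up):

```python
doc = build_agent_document(
    data={"name": "Docs triage", "status": "draft", "workflow": "665f1c2ab7e9d53a9c0f1234"},
    user="user-123",
    key="",
    agent_type="workflow",
)
# doc carries name/status/workflow/folder_id, the rate-limit fields, and the
# timestamps -- but no "retriever", "source", or "tools" keys, because those
# are absent from the workflow schema's "fields" list.
```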
 @agents_ns.route("/get_agent")
 class GetAgent(Resource):
     @api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
@@ -68,7 +254,7 @@ class GetAgent(Resource):
                     if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
                     or source_ref == "default"
                 ],
-                "chunks": agent["chunks"],
+                "chunks": agent.get("chunks", "2"),
                 "retriever": agent.get("retriever", ""),
                 "prompt_id": agent.get("prompt_id", ""),
                 "tools": agent.get("tools", []),
@@ -99,6 +285,7 @@ class GetAgent(Resource):
                 "models": agent.get("models", []),
                 "default_model_id": agent.get("default_model_id", ""),
                 "folder_id": agent.get("folder_id"),
+                "workflow": agent.get("workflow"),
             }
             return make_response(jsonify(data), 200)
         except Exception as e:
@@ -148,7 +335,7 @@ class GetAgents(Resource):
                         isinstance(source_ref, DBRef) and db.dereference(source_ref)
                     )
                 ],
-                "chunks": agent["chunks"],
+                "chunks": agent.get("chunks", "2"),
                 "retriever": agent.get("retriever", ""),
                 "prompt_id": agent.get("prompt_id", ""),
                 "tools": agent.get("tools", []),
@@ -179,9 +366,12 @@ class GetAgents(Resource):
                 "models": agent.get("models", []),
                 "default_model_id": agent.get("default_model_id", ""),
                 "folder_id": agent.get("folder_id"),
+                "workflow": agent.get("workflow"),
             }
             for agent in agents
-            if "source" in agent or "retriever" in agent
+            if "source" in agent
+            or "retriever" in agent
+            or agent.get("agent_type") == "workflow"
         ]
     except Exception as err:
         current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
@@ -209,16 +399,22 @@ class CreateAgent(Resource):
             required=False,
             description="List of source identifiers for multiple sources",
         ),
-        "chunks": fields.Integer(required=True, description="Chunks count"),
-        "retriever": fields.String(required=True, description="Retriever ID"),
-        "prompt_id": fields.String(required=True, description="Prompt ID"),
+        "chunks": fields.Integer(required=False, description="Chunks count"),
+        "retriever": fields.String(required=False, description="Retriever ID"),
+        "prompt_id": fields.String(required=False, description="Prompt ID"),
         "tools": fields.List(
             fields.String, required=False, description="List of tool identifiers"
         ),
-        "agent_type": fields.String(required=True, description="Type of the agent"),
+        "agent_type": fields.String(
+            required=False,
+            description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
+        ),
         "status": fields.String(
             required=True, description="Status of the agent (draft or published)"
         ),
+        "workflow": fields.String(
+            required=False, description="Workflow ID for workflow-type agents"
+        ),
         "json_schema": fields.Raw(
             required=False,
             description="JSON schema for enforcing structured output format",
@@ -330,18 +526,34 @@ class CreateAgent(Resource):
                 ),
                 400,
             )
-            if data.get("status") == "published":
-                required_fields = [
-                    "name",
-                    "description",
-                    "chunks",
-                    "retriever",
-                    "prompt_id",
-                    "agent_type",
-                ]
-                # Require either source or sources (but not both)
-                if not data.get("source") and not data.get("sources"):
+            agent_type = data.get("agent_type", "")
+            # Default to classic schema for empty or unknown agent types
+            if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
+                schema = AGENT_TYPE_SCHEMAS["classic"]
+                # Set agent_type to classic if it was empty
+                if not agent_type:
+                    agent_type = "classic"
+            else:
+                schema = AGENT_TYPE_SCHEMAS[agent_type]
+            is_published = data.get("status") == "published"
+            if agent_type == "workflow":
+                workflow_id, workflow_error = validate_workflow_access(
+                    data.get("workflow"), user, required=is_published
+                )
+                if workflow_error:
+                    return workflow_error
+                data["workflow"] = workflow_id
+            if data.get("status") == "published":
+                required_fields = schema["required_published"]
+                validate_fields = schema["validate_published"]
+                if (
+                    schema.get("require_source")
+                    and not data.get("source")
+                    and not data.get("sources")
+                ):
                     return make_response(
                         jsonify(
                             {
@@ -351,10 +563,9 @@ class CreateAgent(Resource):
                         ),
                         400,
                     )
-                validate_fields = ["name", "description", "prompt_id", "agent_type"]
             else:
-                required_fields = ["name"]
-                validate_fields = []
+                required_fields = schema["required_draft"]
+                validate_fields = schema["validate_draft"]
             missing_fields = check_required_fields(data, required_fields)
             invalid_fields = validate_required_fields(data, validate_fields)
             if missing_fields:
@@ -366,7 +577,6 @@ class CreateAgent(Resource):
                 return make_response(
                     jsonify({"success": False, "message": "Image upload failed"}), 400
                 )
-
             folder_id = data.get("folder_id")
             if folder_id:
                 if not ObjectId.is_valid(folder_id):
@@ -381,76 +591,36 @@ class CreateAgent(Resource):
                     return make_response(
                         jsonify({"success": False, "message": "Folder not found"}), 404
                     )
 
             try:
                 key = str(uuid.uuid4()) if data.get("status") == "published" else ""
 
                 sources_list = []
                 source_field = ""
                 if data.get("sources") and len(data.get("sources", [])) > 0:
                     for source_id in data.get("sources", []):
                         if source_id == "default":
                             sources_list.append("default")
                         elif ObjectId.is_valid(source_id):
                             sources_list.append(DBRef("sources", ObjectId(source_id)))
                     source_field = ""
                 else:
                     source_value = data.get("source", "")
                     if source_value == "default":
                         source_field = "default"
                     elif ObjectId.is_valid(source_value):
                         source_field = DBRef("sources", ObjectId(source_value))
                     else:
                         source_field = ""
-                new_agent = {
-                    "user": user,
-                    "name": data.get("name"),
-                    "description": data.get("description", ""),
-                    "image": image_url,
-                    "source": source_field,
-                    "sources": sources_list,
-                    "chunks": data.get("chunks", ""),
-                    "retriever": data.get("retriever", ""),
-                    "prompt_id": data.get("prompt_id", ""),
-                    "tools": data.get("tools", []),
-                    "agent_type": data.get("agent_type", ""),
-                    "status": data.get("status"),
-                    "json_schema": data.get("json_schema"),
-                    "limited_token_mode": (
-                        data.get("limited_token_mode") == "True"
-                        if isinstance(data.get("limited_token_mode"), str)
-                        else bool(data.get("limited_token_mode", False))
-                    ),
-                    "token_limit": int(
-                        data.get(
-                            "token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
-                        )
-                    ),
-                    "limited_request_mode": (
-                        data.get("limited_request_mode") == "True"
-                        if isinstance(data.get("limited_request_mode"), str)
-                        else bool(data.get("limited_request_mode", False))
-                    ),
-                    "request_limit": int(
-                        data.get(
-                            "request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
-                        )
-                    ),
-                    "createdAt": datetime.datetime.now(datetime.timezone.utc),
-                    "updatedAt": datetime.datetime.now(datetime.timezone.utc),
-                    "lastUsedAt": None,
-                    "key": key,
-                    "models": data.get("models", []),
-                    "default_model_id": data.get("default_model_id", ""),
-                    "folder_id": data.get("folder_id"),
-                }
-                if new_agent["chunks"] == "":
-                    new_agent["chunks"] = "2"
-                if (
-                    new_agent["source"] == ""
-                    and new_agent["retriever"] == ""
-                    and not new_agent["sources"]
-                ):
-                    new_agent["retriever"] = "classic"
+                new_agent = build_agent_document(
+                    data, user, key, agent_type, image_url, source_field, sources_list
+                )
+
+                if agent_type != "workflow":
+                    if new_agent.get("chunks") == "":
+                        new_agent["chunks"] = "2"
+                    if (
+                        new_agent.get("source") == ""
+                        and new_agent.get("retriever") == ""
+                        and not new_agent.get("sources")
+                    ):
+                        new_agent["retriever"] = "classic"
                 resp = agents_collection.insert_one(new_agent)
                 new_id = str(resp.inserted_id)
             except Exception as err:
@@ -479,16 +649,22 @@ class UpdateAgent(Resource):
             required=False,
             description="List of source identifiers for multiple sources",
         ),
-        "chunks": fields.Integer(required=True, description="Chunks count"),
-        "retriever": fields.String(required=True, description="Retriever ID"),
-        "prompt_id": fields.String(required=True, description="Prompt ID"),
+        "chunks": fields.Integer(required=False, description="Chunks count"),
+        "retriever": fields.String(required=False, description="Retriever ID"),
+        "prompt_id": fields.String(required=False, description="Prompt ID"),
        "tools": fields.List(
             fields.String, required=False, description="List of tool identifiers"
         ),
-        "agent_type": fields.String(required=True, description="Type of the agent"),
+        "agent_type": fields.String(
+            required=False,
+            description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
+        ),
         "status": fields.String(
             required=True, description="Status of the agent (draft or published)"
         ),
+        "workflow": fields.String(
+            required=False, description="Workflow ID for workflow-type agents"
+        ),
         "json_schema": fields.Raw(
             required=False,
             description="JSON schema for enforcing structured output format",
@@ -612,6 +788,7 @@ class UpdateAgent(Resource):
             "models",
             "default_model_id",
             "folder_id",
+            "workflow",
         ]
 
         for field in allowed_fields:
@@ -768,10 +945,10 @@ class UpdateAgent(Resource):
                     )
                 elif field == "token_limit":
                     token_limit = data.get("token_limit")
                     # Convert to int and store
                     update_fields[field] = int(token_limit) if token_limit else 0
 
                     # Validate consistency with mode
                     if update_fields[field] > 0 and not data.get("limited_token_mode"):
                         return make_response(
                             jsonify(
@@ -814,14 +991,24 @@ class UpdateAgent(Resource):
                             )
                         if not folder:
                             return make_response(
-                                jsonify(
-                                    {"success": False, "message": "Folder not found"}
-                                ),
+                                jsonify({"success": False, "message": "Folder not found"}),
                                 404,
                             )
                         update_fields[field] = folder_id
                     else:
                         update_fields[field] = None
+                elif field == "workflow":
+                    workflow_required = (
+                        data.get("status", existing_agent.get("status")) == "published"
+                        and data.get("agent_type", existing_agent.get("agent_type"))
+                        == "workflow"
+                    )
+                    workflow_id, workflow_error = validate_workflow_access(
+                        data.get("workflow"), user, required=workflow_required
+                    )
+                    if workflow_error:
+                        return workflow_error
+                    update_fields[field] = workflow_id
                 else:
                     value = data[field]
                     if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -850,46 +1037,82 @@ class UpdateAgent(Resource):
             )
         newly_generated_key = None
         final_status = update_fields.get("status", existing_agent.get("status"))
+        agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
 
         if final_status == "published":
-            required_published_fields = {
-                "name": "Agent name",
-                "description": "Agent description",
-                "chunks": "Chunks count",
-                "prompt_id": "Prompt",
-                "agent_type": "Agent type",
-            }
-            missing_published_fields = []
-            for req_field, field_label in required_published_fields.items():
-                final_value = update_fields.get(
-                    req_field, existing_agent.get(req_field)
-                )
-                if not final_value:
-                    missing_published_fields.append(field_label)
-            source_val = update_fields.get("source", existing_agent.get("source"))
-            sources_val = update_fields.get(
-                "sources", existing_agent.get("sources", [])
-            )
-
-            has_valid_source = (
-                isinstance(source_val, DBRef)
-                or source_val == "default"
-                or (isinstance(sources_val, list) and len(sources_val) > 0)
-            )
-
-            if not has_valid_source:
-                missing_published_fields.append("Source")
-            if missing_published_fields:
-                return make_response(
-                    jsonify(
-                        {
-                            "success": False,
-                            "message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
-                        }
-                    ),
-                    400,
-                )
+            if agent_type == "workflow":
+                required_published_fields = {
+                    "name": "Agent name",
+                }
+                missing_published_fields = []
+                for req_field, field_label in required_published_fields.items():
+                    final_value = update_fields.get(
+                        req_field, existing_agent.get(req_field)
+                    )
+                    if not final_value:
+                        missing_published_fields.append(field_label)
+
+                workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
+                if not workflow_id:
+                    missing_published_fields.append("Workflow")
+                elif not ObjectId.is_valid(workflow_id):
+                    missing_published_fields.append("Valid workflow")
+                else:
+                    workflow = workflows_collection.find_one(
+                        {"_id": ObjectId(workflow_id), "user": user}
+                    )
+                    if not workflow:
+                        missing_published_fields.append("Workflow access")
+
+                if missing_published_fields:
+                    return make_response(
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
+                            }
+                        ),
+                        400,
+                    )
+            else:
+                required_published_fields = {
+                    "name": "Agent name",
+                    "description": "Agent description",
+                    "chunks": "Chunks count",
+                    "prompt_id": "Prompt",
+                    "agent_type": "Agent type",
+                }
+
+                missing_published_fields = []
+                for req_field, field_label in required_published_fields.items():
+                    final_value = update_fields.get(
+                        req_field, existing_agent.get(req_field)
+                    )
+                    if not final_value:
+                        missing_published_fields.append(field_label)
+                source_val = update_fields.get("source", existing_agent.get("source"))
+                sources_val = update_fields.get(
+                    "sources", existing_agent.get("sources", [])
+                )
+
+                has_valid_source = (
+                    isinstance(source_val, DBRef)
+                    or source_val == "default"
+                    or (isinstance(sources_val, list) and len(sources_val) > 0)
+                )
+
+                if not has_valid_source:
+                    missing_published_fields.append("Source")
+                if missing_published_fields:
+                    return make_response(
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
+                            }
+                        ),
+                        400,
+                    )
             if not existing_agent.get("key"):
                 newly_generated_key = str(uuid.uuid4())
                 update_fields["key"] = newly_generated_key
@@ -961,6 +1184,29 @@ class DeleteAgent(Resource):
                     jsonify({"success": False, "message": "Agent not found"}), 404
                 )
             deleted_id = str(deleted_agent["_id"])
+
+            if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
+                "workflow"
+            ):
+                workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
+                if workflow_id and ObjectId.is_valid(workflow_id):
+                    workflow_oid = ObjectId(workflow_id)
+                    owned_workflow = workflows_collection.find_one(
+                        {"_id": workflow_oid, "user": user}, {"_id": 1}
+                    )
+                    if owned_workflow:
+                        workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
+                        workflow_edges_collection.delete_many({"workflow_id": workflow_id})
+                        workflows_collection.delete_one({"_id": workflow_oid, "user": user})
+                    else:
+                        current_app.logger.warning(
+                            f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
+                        )
+                elif workflow_id:
+                    current_app.logger.warning(
+                        f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
+                    )
+
         except Exception as err:
             current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
             return make_response(jsonify({"success": False}), 400)
@@ -1069,19 +1315,16 @@ class AdoptAgent(Resource):
     def post(self):
         if not (decoded_token := request.decoded_token):
             return make_response(jsonify({"success": False}), 401)
-
         if not (agent_id := request.args.get("id")):
             return make_response(
                 jsonify({"success": False, "message": "ID required"}), 400
             )
-
        try:
             agent = agents_collection.find_one(
                 {"_id": ObjectId(agent_id), "user": "system"}
             )
             if not agent:
                 return make_response(jsonify({"status": "Not found"}), 404)
-
             new_agent = agent.copy()
             new_agent.pop("_id", None)
             new_agent["user"] = decoded_token["sub"]
@@ -38,6 +38,10 @@ users_collection = db["users"]
 user_logs_collection = db["user_logs"]
 user_tools_collection = db["user_tools"]
 attachments_collection = db["attachments"]
+workflow_runs_collection = db["workflow_runs"]
+workflows_collection = db["workflows"]
+workflow_nodes_collection = db["workflow_nodes"]
+workflow_edges_collection = db["workflow_edges"]
 
 
 try:
@@ -47,6 +51,25 @@ try:
         background=True,
     )
     users_collection.create_index("user_id", unique=True)
+    workflows_collection.create_index(
+        [("user", 1)], name="workflow_user_index", background=True
+    )
+    workflow_nodes_collection.create_index(
+        [("workflow_id", 1)], name="node_workflow_index", background=True
+    )
+    workflow_nodes_collection.create_index(
+        [("workflow_id", 1), ("graph_version", 1)],
+        name="node_workflow_graph_version_index",
+        background=True,
+    )
+    workflow_edges_collection.create_index(
+        [("workflow_id", 1)], name="edge_workflow_index", background=True
+    )
+    workflow_edges_collection.create_index(
+        [("workflow_id", 1), ("graph_version", 1)],
+        name="edge_workflow_graph_version_index",
+        background=True,
+    )
 except Exception as e:
     print("Error creating indexes:", e)
 current_dir = os.path.dirname(
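The compound `(workflow_id, graph_version)` indexes back the versioned graph reads added in the workflow routes later in this commit, where nodes and edges are always fetched for one workflow at a specific graph version. A sketch of the query shape these indexes appear to be built for (the id is illustrative):

```python
# Both filter keys form a prefix of the compound index
# (workflow_id, graph_version), so this find() is index-covered.
nodes = workflow_nodes_collection.find(
    {"workflow_id": "665f1c2ab7e9d53a9c0f1234", "graph_version": 3}
)
```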
@@ -6,7 +6,6 @@ from flask import Blueprint
 
 from application.api import api
 from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
-
 from .analytics import analytics_ns
 from .attachments import attachments_ns
 from .conversations import conversations_ns
@@ -15,6 +14,7 @@ from .prompts import prompts_ns
 from .sharing import sharing_ns
 from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
 from .tools import tools_mcp_ns, tools_ns
+from .workflows import workflows_ns
 
 
 user = Blueprint("user", __name__)
@@ -51,3 +51,6 @@ api.add_namespace(sources_upload_ns)
 # Tools (main, MCP)
 api.add_namespace(tools_ns)
 api.add_namespace(tools_mcp_ns)
+
+# Workflows
+api.add_namespace(workflows_ns)
application/api/user/utils.py (new file, 378 lines)
@@ -0,0 +1,378 @@
"""Centralized utilities for API routes."""

from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple

from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection


def get_user_id() -> Optional[str]:
    """
    Extract user ID from decoded JWT token.

    Returns:
        User ID string or None if not authenticated
    """
    decoded_token = getattr(request, "decoded_token", None)
    return decoded_token.get("sub") if decoded_token else None


def require_auth(func: Callable) -> Callable:
    """
    Decorator to require authentication for route handlers.

    Usage:
        @require_auth
        def get(self):
            user_id = get_user_id()
            ...
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        user_id = get_user_id()
        if not user_id:
            return error_response("Unauthorized", 401)
        return func(*args, **kwargs)

    return wrapper


def success_response(
    data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
    """
    Create a standardized success response.

    Args:
        data: Optional data dictionary to include in response
        status: HTTP status code (default: 200)

    Returns:
        Flask Response object

    Example:
        return success_response({"users": [...], "total": 10})
    """
    response = {"success": True}
    if data:
        response.update(data)
    return make_response(jsonify(response), status)


def error_response(message: str, status: int = 400, **kwargs) -> Response:
    """
    Create a standardized error response.

    Args:
        message: Error message string
        status: HTTP status code (default: 400)
        **kwargs: Additional fields to include in response

    Returns:
        Flask Response object

    Example:
        return error_response("Resource not found", 404)
        return error_response("Invalid input", 400, errors=["field1", "field2"])
    """
    response = {"success": False, "message": message}
    response.update(kwargs)
    return make_response(jsonify(response), status)


def validate_object_id(
    id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
    """
    Validate and convert string to ObjectId.

    Args:
        id_string: String to convert
        resource_name: Name of resource for error message

    Returns:
        Tuple of (ObjectId or None, error_response or None)

    Example:
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error
    """
    try:
        return ObjectId(id_string), None
    except (InvalidId, TypeError):
        return None, error_response(f"Invalid {resource_name} ID format")


def validate_pagination(
    default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
    """
    Extract and validate pagination parameters from request.

    Args:
        default_limit: Default items per page
        max_limit: Maximum allowed items per page

    Returns:
        Tuple of (limit, skip, error_response or None)

    Example:
        limit, skip, error = validate_pagination()
        if error:
            return error
    """
    try:
        limit = min(int(request.args.get("limit", default_limit)), max_limit)
        skip = int(request.args.get("skip", 0))
        if limit < 1 or skip < 0:
            return 0, 0, error_response("Invalid pagination parameters")
        return limit, skip, None
    except ValueError:
        return 0, 0, error_response("Invalid pagination parameters")


def check_resource_ownership(
    collection: Collection,
    resource_id: ObjectId,
    user_id: str,
    resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
    """
    Check if resource exists and belongs to user.

    Args:
        collection: MongoDB collection
        resource_id: Resource ObjectId
        user_id: User ID string
        resource_name: Name of resource for error messages

    Returns:
        Tuple of (resource_dict or None, error_response or None)

    Example:
        workflow, error = check_resource_ownership(
            workflows_collection,
            workflow_id,
            user_id,
            "Workflow"
        )
        if error:
            return error
    """
    resource = collection.find_one({"_id": resource_id, "user": user_id})
    if not resource:
        return None, error_response(f"{resource_name} not found", 404)
    return resource, None


def serialize_object_id(
    obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
    """
    Convert ObjectId to string in a dictionary.

    Args:
        obj: Dictionary containing ObjectId
        id_field: Field name containing ObjectId
        new_field: New field name for string ID

    Returns:
        Modified dictionary

    Example:
        user = serialize_object_id(user_doc)
        # user["id"] = "507f1f77bcf86cd799439011"
    """
    if id_field in obj:
        obj[new_field] = str(obj[id_field])
        if id_field != new_field:
            obj.pop(id_field, None)
    return obj


def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
    """
    Apply serializer function to list of items.

    Args:
        items: List of dictionaries
        serializer: Function to apply to each item

    Returns:
        List of serialized items

    Example:
        workflows = serialize_list(workflow_docs, serialize_workflow)
    """
    return [serializer(item) for item in items]


def paginated_response(
    collection: Collection,
    query: Dict[str, Any],
    serializer: Callable[[Dict], Dict],
    limit: int,
    skip: int,
    sort_field: str = "created_at",
    sort_order: int = -1,
    response_key: str = "items",
) -> Response:
    """
    Create paginated response for collection query.

    Args:
        collection: MongoDB collection
        query: Query dictionary
        serializer: Function to serialize each item
        limit: Items per page
        skip: Number of items to skip
        sort_field: Field to sort by
        sort_order: Sort order (1=asc, -1=desc)
        response_key: Key name for items in response

    Returns:
        Flask Response with paginated data

    Example:
        return paginated_response(
            workflows_collection,
            {"user": user_id},
            serialize_workflow,
            limit, skip,
            response_key="workflows"
        )
    """
    items = list(
        collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
    )
    total = collection.count_documents(query)

    return success_response(
        {
            response_key: serialize_list(items, serializer),
            "total": total,
            "limit": limit,
            "skip": skip,
        }
    )


def require_fields(required: List[str]) -> Callable:
    """
    Decorator to validate required fields in request JSON.

    Args:
        required: List of required field names

    Returns:
        Decorator function

    Example:
        @require_fields(["name", "description"])
        def post(self):
            data = request.get_json()
            ...
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            data = request.get_json()
            if not data:
                return error_response("Request body required")
            missing = [field for field in required if not data.get(field)]
            if missing:
                return error_response(f"Missing required fields: {', '.join(missing)}")
            return func(*args, **kwargs)

        return wrapper

    return decorator


def safe_db_operation(
    operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
    """
    Safely execute database operation with error handling.

    Args:
        operation: Function to execute
        error_message: Error message if operation fails

    Returns:
        Tuple of (result or None, error_response or None)

    Example:
        result, error = safe_db_operation(
            lambda: collection.insert_one(doc),
            "Failed to create resource"
        )
        if error:
            return error
    """
    try:
        result = operation()
        return result, None
    except Exception as e:
        return None, error_response(f"{error_message}: {str(e)}")


def validate_enum(
    value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
    """
    Validate that value is in allowed list.

    Args:
        value: Value to validate
        allowed: List of allowed values
        field_name: Field name for error message

    Returns:
        error_response if invalid, None if valid

    Example:
        error = validate_enum(status, ["draft", "published"], "status")
        if error:
            return error
    """
    if value not in allowed:
        allowed_str = ", ".join(f"'{v}'" for v in allowed)
        return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
    return None


def extract_sort_params(
    default_field: str = "created_at",
    default_order: str = "desc",
    allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
    """
    Extract and validate sort parameters from request.

    Args:
        default_field: Default sort field
        default_order: Default sort order ("asc" or "desc")
        allowed_fields: List of allowed sort fields (None = no validation)

    Returns:
        Tuple of (sort_field, sort_order)

    Example:
        sort_field, sort_order = extract_sort_params(
            allowed_fields=["name", "date", "status"]
        )
    """
    sort_field = request.args.get("sort", default_field)
    sort_order_str = request.args.get("order", default_order).lower()

    if allowed_fields and sort_field not in allowed_fields:
        sort_field = default_field
    sort_order = -1 if sort_order_str == "desc" else 1
    return sort_field, sort_order
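Taken together, these helpers keep route handlers short: auth, id validation, and the response envelope each collapse to one line. A sketch of a handler wired from them (the "things" resource is hypothetical, used only for illustration; the workflow routes below show the real usage):

```python
from flask_restx import Namespace, Resource

from application.api.user.utils import (
    get_user_id,
    require_auth,
    success_response,
    validate_object_id,
)

things_ns = Namespace("things", path="/api")  # hypothetical namespace


@things_ns.route("/things/<string:thing_id>")
class ThingDetail(Resource):
    @require_auth
    def get(self, thing_id: str):
        user_id = get_user_id()
        obj_id, error = validate_object_id(thing_id, "Thing")
        if error:
            return error  # consistent {"success": False, ...} envelope
        # A real handler would fetch and serialize the document here,
        # guarding ownership with check_resource_ownership(...).
        return success_response({"id": str(obj_id), "owner": user_id})
```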
application/api/user/workflows/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .routes import workflows_ns

__all__ = ["workflows_ns"]
application/api/user/workflows/routes.py (new file, 353 lines)
@@ -0,0 +1,353 @@
"""Workflow management routes."""

from datetime import datetime, timezone
from typing import Dict, List

from flask import current_app, request
from flask_restx import Namespace, Resource

from application.api.user.base import (
    workflow_edges_collection,
    workflow_nodes_collection,
    workflows_collection,
)
from application.api.user.utils import (
    check_resource_ownership,
    error_response,
    get_user_id,
    require_auth,
    require_fields,
    safe_db_operation,
    success_response,
    validate_object_id,
)

workflows_ns = Namespace("workflows", path="/api")


def serialize_workflow(w: Dict) -> Dict:
    """Serialize workflow document to API response format."""
    return {
        "id": str(w["_id"]),
        "name": w.get("name"),
        "description": w.get("description"),
        "created_at": w["created_at"].isoformat() if w.get("created_at") else None,
        "updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
    }


def serialize_node(n: Dict) -> Dict:
    """Serialize workflow node document to API response format."""
    return {
        "id": n["id"],
        "type": n["type"],
        "title": n.get("title"),
        "description": n.get("description"),
        "position": n.get("position"),
        "data": n.get("config", {}),
    }


def serialize_edge(e: Dict) -> Dict:
    """Serialize workflow edge document to API response format."""
    return {
        "id": e["id"],
        "source": e.get("source_id"),
        "target": e.get("target_id"),
        "sourceHandle": e.get("source_handle"),
        "targetHandle": e.get("target_handle"),
    }


def get_workflow_graph_version(workflow: Dict) -> int:
    """Get current graph version with legacy fallback."""
    raw_version = workflow.get("current_graph_version", 1)
    try:
        version = int(raw_version)
        return version if version > 0 else 1
    except (ValueError, TypeError):
        return 1


def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
    """Fetch graph docs for active version, with fallback for legacy unversioned data."""
    docs = list(
        collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
    )
    if docs:
        return docs
    if graph_version == 1:
        return list(
            collection.find(
                {"workflow_id": workflow_id, "graph_version": {"$exists": False}}
            )
        )
    return docs
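The fallback branch only fires for version 1, so graphs written before versioning existed (documents with no `graph_version` field) keep resolving until the next PUT rewrites them. A quick illustration of the two paths (ids are made up):

```python
# Versioned data present: returns the version-3 nodes directly.
nodes = fetch_graph_documents(workflow_nodes_collection, "665f1c2ab7e9d53a9c0f1234", 3)

# Legacy workflow still on the default version 1 with unversioned nodes:
# the first query returns [], so the helper retries with
# {"graph_version": {"$exists": False}} and returns the old documents.
nodes = fetch_graph_documents(workflow_nodes_collection, "5e9a0b1c2d3e4f5a6b7c8d9e", 1)
```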
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
    """Validate workflow graph structure."""
    errors = []

    if not nodes:
        errors.append("Workflow must have at least one node")
        return errors

    start_nodes = [n for n in nodes if n.get("type") == "start"]
    if len(start_nodes) != 1:
        errors.append("Workflow must have exactly one start node")

    end_nodes = [n for n in nodes if n.get("type") == "end"]
    if not end_nodes:
        errors.append("Workflow must have at least one end node")

    node_ids = {n.get("id") for n in nodes}
    for edge in edges:
        source_id = edge.get("source")
        target_id = edge.get("target")
        if source_id not in node_ids:
            errors.append(f"Edge references non-existent source: {source_id}")
        if target_id not in node_ids:
            errors.append(f"Edge references non-existent target: {target_id}")

    if start_nodes:
        start_id = start_nodes[0].get("id")
        if not any(e.get("source") == start_id for e in edges):
            errors.append("Start node must have at least one outgoing edge")

    for node in nodes:
        if not node.get("id"):
            errors.append("All nodes must have an id")
        if not node.get("type"):
            errors.append(f"Node {node.get('id', 'unknown')} must have a type")

    return errors
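The smallest graph the validator accepts is a single start node wired to an end node. A sketch:

```python
nodes = [
    {"id": "n1", "type": "start", "title": "Start"},
    {"id": "n2", "type": "end", "title": "End"},
]
edges = [{"id": "e1", "source": "n1", "target": "n2"}]

assert validate_workflow_structure(nodes, edges) == []

# Dropping the edge leaves the start node with no outgoing edge:
print(validate_workflow_structure(nodes, []))
# ['Start node must have at least one outgoing edge']
```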
def create_workflow_nodes(
    workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
    """Insert workflow nodes into database."""
    if nodes_data:
        workflow_nodes_collection.insert_many(
            [
                {
                    "id": n["id"],
                    "workflow_id": workflow_id,
                    "graph_version": graph_version,
                    "type": n["type"],
                    "title": n.get("title", ""),
                    "description": n.get("description", ""),
                    "position": n.get("position", {"x": 0, "y": 0}),
                    "config": n.get("data", {}),
                }
                for n in nodes_data
            ]
        )


def create_workflow_edges(
    workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
    """Insert workflow edges into database."""
    if edges_data:
        workflow_edges_collection.insert_many(
            [
                {
                    "id": e["id"],
                    "workflow_id": workflow_id,
                    "graph_version": graph_version,
                    "source_id": e.get("source"),
                    "target_id": e.get("target"),
                    "source_handle": e.get("sourceHandle"),
                    "target_handle": e.get("targetHandle"),
                }
                for e in edges_data
            ]
        )


@workflows_ns.route("/workflows")
class WorkflowList(Resource):

    @require_auth
    @require_fields(["name"])
    def post(self):
        """Create a new workflow with nodes and edges."""
        user_id = get_user_id()
        data = request.get_json()

        name = data.get("name", "").strip()
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])

        validation_errors = validate_workflow_structure(nodes_data, edges_data)
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
            )

        now = datetime.now(timezone.utc)
        workflow_doc = {
            "name": name,
            "description": data.get("description", ""),
            "user": user_id,
            "created_at": now,
            "updated_at": now,
            "current_graph_version": 1,
        }

        result, error = safe_db_operation(
            lambda: workflows_collection.insert_one(workflow_doc),
            "Failed to create workflow",
        )
        if error:
            return error

        workflow_id = str(result.inserted_id)

        try:
            create_workflow_nodes(workflow_id, nodes_data, 1)
            create_workflow_edges(workflow_id, edges_data, 1)
        except Exception as e:
            workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
            workflow_edges_collection.delete_many({"workflow_id": workflow_id})
            workflows_collection.delete_one({"_id": result.inserted_id})
            return error_response(f"Failed to create workflow structure: {str(e)}")

        return success_response({"id": workflow_id}, 201)
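End to end, creating a workflow is a single authenticated POST. A hedged example call against this endpoint (host, port, and token are placeholders, not confirmed defaults):

```python
import requests

resp = requests.post(
    "http://localhost:7091/api/workflows",  # host/port illustrative
    headers={"Authorization": "Bearer <JWT>"},
    json={
        "name": "Docs triage",
        "nodes": [
            {"id": "n1", "type": "start"},
            {"id": "n2", "type": "end"},
        ],
        "edges": [{"id": "e1", "source": "n1", "target": "n2"}],
    },
)
print(resp.status_code, resp.json())  # 201 {"success": true, "id": "..."}
```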
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):

    @require_auth
    def get(self, workflow_id: str):
        """Get workflow details with nodes and edges."""
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        graph_version = get_workflow_graph_version(workflow)
        nodes = fetch_graph_documents(
            workflow_nodes_collection, workflow_id, graph_version
        )
        edges = fetch_graph_documents(
            workflow_edges_collection, workflow_id, graph_version
        )

        return success_response(
            {
                "workflow": serialize_workflow(workflow),
                "nodes": [serialize_node(n) for n in nodes],
                "edges": [serialize_edge(e) for e in edges],
            }
        )

    @require_auth
    @require_fields(["name"])
    def put(self, workflow_id: str):
        """Update workflow and replace nodes/edges."""
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        data = request.get_json()
        name = data.get("name", "").strip()
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])

        validation_errors = validate_workflow_structure(nodes_data, edges_data)
        if validation_errors:
            return error_response(
                "Workflow validation failed", errors=validation_errors
            )

        current_graph_version = get_workflow_graph_version(workflow)
        next_graph_version = current_graph_version + 1
        try:
            create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
            create_workflow_edges(workflow_id, edges_data, next_graph_version)
        except Exception as e:
            workflow_nodes_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": next_graph_version}
            )
            workflow_edges_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": next_graph_version}
            )
            return error_response(f"Failed to update workflow structure: {str(e)}")

        now = datetime.now(timezone.utc)
        _, error = safe_db_operation(
            lambda: workflows_collection.update_one(
                {"_id": obj_id},
                {
                    "$set": {
                        "name": name,
                        "description": data.get("description", ""),
                        "updated_at": now,
                        "current_graph_version": next_graph_version,
                    }
                },
            ),
            "Failed to update workflow",
        )
        if error:
            workflow_nodes_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": next_graph_version}
            )
            workflow_edges_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": next_graph_version}
            )
            return error

        try:
            workflow_nodes_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
            )
            workflow_edges_collection.delete_many(
                {"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
            )
        except Exception as cleanup_err:
            current_app.logger.warning(
                f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
            )

        return success_response()

    @require_auth
    def delete(self, workflow_id: str):
        """Delete workflow and its graph."""
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        try:
            workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
            workflow_edges_collection.delete_many({"workflow_id": workflow_id})
            workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
        except Exception as e:
            return error_response(f"Failed to delete workflow: {str(e)}")

        return success_response()