feat: agent workflow builder (#2264)

* feat: implement WorkflowAgent and GraphExecutor for workflow management and execution

* refactor: rework workflow schemas and introduce WorkflowEngine

- Updated schemas in `schemas.py` to include new agent types and configurations.
- Created `WorkflowEngine` class in `workflow_engine.py` to manage workflow execution.
- Enhanced `StreamProcessor` to handle workflow-related data.
- Added new routes and utilities for managing workflows in the user API.
- Implemented validation and serialization functions for workflows.
- Established MongoDB collections and indexes for workflows and related entities.

* refactor: improve WorkflowAgent documentation and update type hints in WorkflowEngine

* feat: workflow building and management in the frontend

- Added new endpoints for workflows in `endpoints.ts`.
- Implemented `getWorkflow`, `createWorkflow`, and `updateWorkflow` methods in `userService.ts`.
- Introduced new UI components for alerts, buttons, commands, dialogs, multi-select, popovers, and selects.
- Enhanced styling in `index.css` with new theme variables and animations.
- Refactored modal components for better layout and styling.
- Configured TypeScript paths and Vite aliases for cleaner imports.

* feat: add workflow preview component and related state management

- Implemented WorkflowPreview component for displaying workflow execution.
- Created WorkflowPreviewSlice for managing workflow preview state, including queries and execution steps.
- Added WorkflowMiniMap for visual representation of workflow nodes and their statuses.
- Integrated conversation handling with the ability to fetch answers and manage query states.
- Introduced reusable Sheet component for UI overlays.
- Updated Redux store to include workflowPreview reducer.

* feat: enhance workflow execution details and state management in WorkflowEngine and WorkflowPreview

* feat: enhance workflow components with improved UI and functionality

- Updated WorkflowPreview to allow text truncation for better display of long names.
- Enhanced BaseNode with connectable handles and improved styling for better visibility.
- Added MobileBlocker component to inform users about desktop requirements for the Workflow Builder.
- Introduced PromptTextArea component for improved variable insertion and search functionality, including upstream variable extraction and context addition.

* feat(workflow): add owner validation and graph version support

* fix: ruff lint

---------

Co-authored-by: Alex <a@tushynski.me>
Authored by Siddhant Rai on 2026-02-11 19:45:24 +05:30, committed by GitHub
parent 8353f9c649
commit 8ef321d784
52 changed files with 8634 additions and 222 deletions

View File

@@ -1,6 +1,8 @@
import logging
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
@@ -9,6 +11,7 @@ class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ReActAgent,
"workflow": WorkflowAgent,
}
@classmethod
@@ -16,5 +19,4 @@ class AgentCreator:
agent_class = cls.agents.get(type.lower())
if not agent_class:
raise ValueError(f"No agent class found for type {type}")
return agent_class(*args, **kwargs)
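
For context, a minimal sketch of how the extended registry might be invoked. The "workflow" key and the workflow_id kwarg are taken from the diffs in this commit; the module path, provider, model, and ID values are assumptions:

from application.agents.agent_creator import AgentCreator  # assumed module path

agent = AgentCreator.create_agent(
    "workflow",                    # resolved case-insensitively via type.lower()
    endpoint="stream",
    llm_name="openai",             # placeholder provider
    model_id="gpt-4o",             # placeholder model
    api_key="...",                 # placeholder key
    workflow_id="665f000000000000000000aa",  # hypothetical workflow ObjectId
)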

View File

@@ -367,7 +367,9 @@ class BaseAgent(ABC):
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%). Model: {self.model_id}"
)
elif current_tokens >= int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE):
elif current_tokens >= int(
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
):
logger.info(
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%)"

View File

@@ -0,0 +1,218 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.workflows.schemas import (
ExecutionStatus,
Workflow,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.logging import log_activity, LogContext
logger = logging.getLogger(__name__)
class WorkflowAgent(BaseAgent):
"""A specialized agent that executes predefined workflows."""
def __init__(
self,
*args,
workflow_id: Optional[str] = None,
workflow: Optional[Dict[str, Any]] = None,
workflow_owner: Optional[str] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.workflow_id = workflow_id
self.workflow_owner = workflow_owner
self._workflow_data = workflow
self._engine: Optional[WorkflowEngine] = None
@log_activity()
def gen(
self, query: str, log_context: LogContext = None
) -> Generator[Dict[str, str], None, None]:
yield from self._gen_inner(query, log_context)
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict[str, str], None, None]:
graph = self._load_workflow_graph()
if not graph:
yield {"type": "error", "error": "Failed to load workflow configuration."}
return
self._engine = WorkflowEngine(graph, self)
yield from self._engine.execute({}, query)
self._save_workflow_run(query)
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
if self._workflow_data:
return self._parse_embedded_workflow()
if self.workflow_id:
return self._load_from_database()
return None
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
try:
nodes_data = self._workflow_data.get("nodes", [])
edges_data = self._workflow_data.get("edges", [])
workflow = Workflow(
name=self._workflow_data.get("name", "Embedded Workflow"),
description=self._workflow_data.get("description"),
)
nodes = []
for n in nodes_data:
node_config = n.get("data", {})
nodes.append(
WorkflowNode(
id=n["id"],
workflow_id=self.workflow_id or "embedded",
type=n["type"],
title=n.get("title", "Node"),
description=n.get("description"),
position=n.get("position", {"x": 0, "y": 0}),
config=node_config,
)
)
edges = []
for e in edges_data:
edges.append(
WorkflowEdge(
id=e["id"],
workflow_id=self.workflow_id or "embedded",
source=e.get("source") or e.get("source_id"),
target=e.get("target") or e.get("target_id"),
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
targetHandle=e.get("targetHandle") or e.get("target_handle"),
)
)
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Invalid embedded workflow: {e}")
return None
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
from bson.objectid import ObjectId
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
logger.error(f"Invalid workflow ID: {self.workflow_id}")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
if not owner_id:
logger.error(
f"Workflow owner not available for workflow load: {self.workflow_id}"
)
return None
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflows_coll = db["workflows"]
workflow_nodes_coll = db["workflow_nodes"]
workflow_edges_coll = db["workflow_edges"]
workflow_doc = workflows_coll.find_one(
{"_id": ObjectId(self.workflow_id), "user": owner_id}
)
if not workflow_doc:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
)
return None
workflow = Workflow(**workflow_doc)
graph_version = workflow_doc.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
nodes_docs = list(
workflow_nodes_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not nodes_docs and graph_version == 1:
nodes_docs = list(
workflow_nodes_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
edges_docs = list(
workflow_edges_coll.find(
{"workflow_id": self.workflow_id, "graph_version": graph_version}
)
)
if not edges_docs and graph_version == 1:
edges_docs = list(
workflow_edges_coll.find(
{
"workflow_id": self.workflow_id,
"graph_version": {"$exists": False},
}
)
)
edges = [WorkflowEdge(**doc) for doc in edges_docs]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Failed to load workflow from database: {e}")
return None
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
try:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
workflow_runs_coll = db["workflow_runs"]
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
steps=self._engine.get_execution_summary(),
created_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
workflow_runs_coll.insert_one(run.to_mongo_doc())
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")
def _determine_run_status(self) -> ExecutionStatus:
if not self._engine or not self._engine.execution_log:
return ExecutionStatus.COMPLETED
for log in self._engine.execution_log:
if log.get("status") == ExecutionStatus.FAILED.value:
return ExecutionStatus.FAILED
return ExecutionStatus.COMPLETED
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
if isinstance(value, (str, int, float, bool, type(None))):
serialized[key] = value
else:
serialized[key] = str(value)
return serialized
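
A hedged sketch of exercising WorkflowAgent with an embedded workflow dict, using the node/edge shapes accepted by _parse_embedded_workflow above. The BaseAgent constructor kwargs are placeholders, and persisting the run assumes a reachable MongoDB:

embedded = {
    "name": "Echo workflow",
    "nodes": [
        {"id": "n1", "type": "start", "title": "Start"},
        {
            "id": "n2",
            "type": "end",
            "title": "End",
            "data": {"output_template": "You said: {{query}}"},
        },
    ],
    "edges": [{"id": "e1", "source": "n1", "target": "n2"}],
}
agent = WorkflowAgent(
    endpoint="stream",
    llm_name="openai",   # placeholder provider
    model_id="gpt-4o",   # placeholder model
    api_key="...",       # placeholder key
    chat_history=[],
    workflow=embedded,
)
for event in agent.gen("hello"):
    print(event)  # workflow_step events, then {"answer": "You said: hello"}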

View File

@@ -0,0 +1,109 @@
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
from typing import Any, Dict, List, Optional, Type
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflows.schemas import AgentType
class ToolFilterMixin:
"""Mixin that filters fetched tools to only those specified in tool_ids."""
_allowed_tool_ids: List[str]
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_user_tools(user)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_tools(api_key)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeReActAgent,
}
@classmethod
def create(
cls,
agent_type: AgentType,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
) -> BaseAgent:
agent_class = cls._agents.get(agent_type)
if not agent_class:
raise ValueError(f"Unsupported agent type: {agent_type}")
return agent_class(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
tool_ids=tool_ids,
**kwargs,
)
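
A short sketch of the factory plus tool filtering; the tool ID and model values are hypothetical. Only tools whose _id appears in tool_ids survive the mixin's filtering:

node_agent = WorkflowNodeAgentFactory.create(
    agent_type=AgentType.CLASSIC,
    endpoint="stream",
    llm_name="openai",                      # placeholder provider
    model_id="gpt-4o",                      # placeholder model
    api_key="...",                          # placeholder key
    tool_ids=["665f000000000000000000aa"],  # hypothetical tool _id
    prompt="You are a helpful assistant.",
)
# node_agent._get_user_tools(...) now returns only the whitelisted tool,
# and an empty dict when tool_ids is empty.

Because ToolFilterMixin precedes the concrete agent class in the MRO, its _get_user_tools/_get_tools overrides wrap the inherited implementations via super().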

View File

@@ -0,0 +1,215 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
class NodeType(str, Enum):
START = "start"
END = "end"
AGENT = "agent"
NOTE = "note"
STATE = "state"
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
class ExecutionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class Position(BaseModel):
model_config = ConfigDict(extra="forbid")
x: float = 0.0
y: float = 0.0
class AgentNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
agent_type: AgentType = AgentType.CLASSIC
llm_name: Optional[str] = None
system_prompt: str = "You are a helpful assistant."
prompt_template: str = ""
output_variable: Optional[str] = None
stream_to_user: bool = True
tools: List[str] = Field(default_factory=list)
sources: List[str] = Field(default_factory=list)
chunks: str = "2"
retriever: str = ""
model_id: Optional[str] = None
json_schema: Optional[Dict[str, Any]] = None
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str
workflow_id: str
source_id: str = Field(..., alias="source")
target_id: str = Field(..., alias="target")
source_handle: Optional[str] = Field(None, alias="sourceHandle")
target_handle: Optional[str] = Field(None, alias="targetHandle")
class WorkflowEdge(WorkflowEdgeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"source_id": self.source_id,
"target_id": self.target_id,
"source_handle": self.source_handle,
"target_handle": self.target_handle,
}
class WorkflowNodeCreate(BaseModel):
model_config = ConfigDict(extra="allow")
id: str
workflow_id: str
type: NodeType
title: str = "Node"
description: Optional[str] = None
position: Position = Field(default_factory=Position)
config: Dict[str, Any] = Field(default_factory=dict)
@field_validator("position", mode="before")
@classmethod
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
if isinstance(v, dict):
return Position(**v)
return v
class WorkflowNode(WorkflowNodeCreate):
mongo_id: Optional[str] = Field(None, alias="_id")
@field_validator("mongo_id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"id": self.id,
"workflow_id": self.workflow_id,
"type": self.type.value,
"title": self.title,
"description": self.description,
"position": self.position.model_dump(),
"config": self.config,
}
class WorkflowCreate(BaseModel):
model_config = ConfigDict(extra="allow")
name: str = "New Workflow"
description: Optional[str] = None
user: Optional[str] = None
class Workflow(WorkflowCreate):
id: Optional[str] = Field(None, alias="_id")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"user": self.user,
"created_at": self.created_at,
"updated_at": self.updated_at,
}
class WorkflowGraph(BaseModel):
workflow: Workflow
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_start_node(self) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.type == NodeType.START:
return node
return None
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
return [edge for edge in self.edges if edge.source_id == node_id]
class NodeExecutionLog(BaseModel):
model_config = ConfigDict(extra="forbid")
node_id: str
node_type: str
status: ExecutionStatus
started_at: datetime
completed_at: Optional[datetime] = None
error: Optional[str] = None
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
class WorkflowRunCreate(BaseModel):
workflow_id: str
inputs: Dict[str, str] = Field(default_factory=dict)
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = Field(None, alias="_id")
workflow_id: str
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None
@field_validator("id", mode="before")
@classmethod
def convert_objectid(cls, v: Any) -> Optional[str]:
if isinstance(v, ObjectId):
return str(v)
return v
def to_mongo_doc(self) -> Dict[str, Any]:
return {
"workflow_id": self.workflow_id,
"status": self.status.value,
"inputs": self.inputs,
"outputs": self.outputs,
"steps": [step.model_dump() for step in self.steps],
"created_at": self.created_at,
"completed_at": self.completed_at,
}
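
To make the schema relationships concrete, a small sketch assembling a WorkflowGraph by hand; the IDs, templates, and user value are illustrative:

wf = Workflow(name="Summarize", user="user-123")  # hypothetical user id
nodes = [
    WorkflowNode(id="start", workflow_id="demo", type=NodeType.START),
    WorkflowNode(
        id="agent1",
        workflow_id="demo",
        type=NodeType.AGENT,
        title="Summarizer",
        config={
            "prompt_template": "Summarize: {{query}}",
            "output_variable": "summary",
        },
    ),
    WorkflowNode(
        id="end",
        workflow_id="demo",
        type=NodeType.END,
        config={"output_template": "{{summary}}"},
    ),
]
edges = [
    WorkflowEdge(id="e1", workflow_id="demo", source="start", target="agent1"),
    WorkflowEdge(id="e2", workflow_id="demo", source="agent1", target="end"),
]
graph = WorkflowGraph(workflow=wf, nodes=nodes, edges=edges)
assert graph.get_start_node().id == "start"
assert graph.get_outgoing_edges("start")[0].target_id == "agent1"

Note that WorkflowEdge accepts source/target via the aliases declared on WorkflowEdgeCreate, since populate_by_name is enabled.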

View File

@@ -0,0 +1,276 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
if TYPE_CHECKING:
from application.agents.base import BaseAgent
logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
class WorkflowEngine:
MAX_EXECUTION_STEPS = 50
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
self.graph = graph
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
def execute(
self, initial_inputs: WorkflowState, query: str
) -> Generator[Dict[str, str], None, None]:
self._initialize_state(initial_inputs, query)
start_node = self.graph.get_start_node()
if not start_node:
yield {"type": "error", "error": "No start node found in workflow."}
return
current_node_id: Optional[str] = start_node.id
steps = 0
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
node = self.graph.get_node_by_id(current_node_id)
if not node:
yield {"type": "error", "error": f"Node {current_node_id} not found."}
break
log_entry = self._create_log_entry(node)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "running",
}
try:
yield from self._execute_node(node)
log_entry["status"] = ExecutionStatus.COMPLETED.value
log_entry["completed_at"] = datetime.now(timezone.utc)
output_key = f"node_{node.id}_output"
node_output = self.state.get(output_key)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "completed",
"state_snapshot": dict(self.state),
"output": node_output,
}
except Exception as e:
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
log_entry["status"] = ExecutionStatus.FAILED.value
log_entry["error"] = str(e)
log_entry["completed_at"] = datetime.now(timezone.utc)
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": str(e),
}
yield {"type": "error", "error": str(e)}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
)
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
self.state.update(initial_inputs)
self.state["query"] = query
self.state["chat_history"] = str(self.agent.chat_history)
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
return {
"node_id": node.id,
"node_type": node.type.value,
"started_at": datetime.now(timezone.utc),
"completed_at": None,
"status": ExecutionStatus.RUNNING.value,
"error": None,
"state_snapshot": {},
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
edges = self.graph.get_outgoing_edges(current_node_id)
if edges:
return edges[0].target_id
return None
def _execute_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
logger.info(f"Executing node {node.id} ({node.type.value})")
node_handlers = {
NodeType.START: self._execute_start_node,
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.END: self._execute_end_node,
}
handler = node_handlers.get(node.type)
if handler:
yield from handler(node)
def _execute_start_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_note_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import get_api_key_for_provider
node_config = AgentNodeConfig(**node.config)
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_llm_name = node_config.llm_name or self.agent.llm_name
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_config.model_id or self.agent.model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_config.json_schema,
)
full_response = ""
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
full_response += event["answer"]
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
first_chunk = False
yield event
if node_config.stream_to_user:
self._has_streamed = True
output_key = node_config.output_variable or f"node_{node.id}_output"
self.state[output_key] = full_response
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
operations = config.get("operations", [])
if operations:
for op in operations:
key = op.get("key")
operation = op.get("operation", "set")
value = op.get("value")
if not key:
continue
if operation == "set":
formatted_value = (
self._format_template(str(value))
if isinstance(value, str)
else value
)
self.state[key] = formatted_value
elif operation == "increment":
current = self.state.get(key, 0)
try:
self.state[key] = int(current) + int(value or 1)
except (ValueError, TypeError):
self.state[key] = 1
elif operation == "append":
if key not in self.state:
self.state[key] = []
if isinstance(self.state[key], list):
self.state[key].append(value)
else:
updates = config.get("updates", {})
if not updates:
var_name = config.get("variable")
var_value = config.get("value")
if var_name and isinstance(var_name, str):
updates = {var_name: var_value or ""}
if isinstance(updates, dict):
for key, value in updates.items():
if isinstance(value, str):
self.state[key] = self._format_template(value)
else:
self.state[key] = value
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _format_template(self, template: str) -> str:
formatted = template
for key, value in self.state.items():
placeholder = f"{{{{{key}}}}}"
if placeholder in formatted and value is not None:
formatted = formatted.replace(placeholder, str(value))
return formatted
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(
node_id=log["node_id"],
node_type=log["node_type"],
status=ExecutionStatus(log["status"]),
started_at=log["started_at"],
completed_at=log.get("completed_at"),
error=log.get("error"),
state_snapshot=log.get("state_snapshot", {}),
)
for log in self.execution_log
]
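
The engine's {{variable}} substitution is plain string replacement; a sketch of the behavior, assuming graph and agent objects built as in WorkflowAgent above:

engine = WorkflowEngine(graph, agent)  # graph/agent as constructed elsewhere
engine.state = {"query": "ping", "summary": "pong"}
formatted = engine._format_template("Q: {{query}} / A: {{summary}} / {{missing}}")
# -> "Q: ping / A: pong / {{missing}}"  (unknown placeholders are left intact)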

View File

@@ -150,9 +150,7 @@ class StreamProcessor:
)
if not result.success:
logger.error(
f"Compression failed: {result.error}, using full history"
)
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
@@ -225,7 +223,11 @@ class StreamProcessor:
raise ValueError(
f"Invalid model_id '{requested_model}'. "
f"Available models: {', '.join(available_models[:5])}"
+ (f" and {len(available_models) - 5} more" if len(available_models) > 5 else "")
+ (
f" and {len(available_models) - 5} more"
if len(available_models) > 5
else ""
)
)
self.model_id = requested_model
else:
@@ -370,6 +372,9 @@ class StreamProcessor:
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -398,6 +403,9 @@ class StreamProcessor:
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
@@ -409,10 +417,19 @@ class StreamProcessor:
)
self.retriever_config["chunks"] = 2
else:
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
):
agent_type = "workflow"
self.agent_config["workflow"] = self.data["workflow"]
if isinstance(self.decoded_token, dict):
self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
self.agent_config.update(
{
"prompt_id": self.data.get("prompt_id", "default"),
"agent_type": settings.AGENT_NAME,
"agent_type": agent_type,
"user_api_key": None,
"json_schema": None,
"default_model_id": "",
@@ -420,9 +437,7 @@ class StreamProcessor:
)
def _configure_retriever(self):
doc_token_limit = calculate_doc_token_budget(
model_id=self.model_id
)
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
self.retriever_config = {
"retriever_name": self.data.get("retriever", "classic"),
@@ -731,21 +746,36 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
agent = AgentCreator.create_agent(
self.agent_config["agent_type"],
endpoint="stream",
llm_name=provider or settings.LLM_PROVIDER,
model_id=self.model_id,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
prompt=rendered_prompt,
chat_history=self.history,
retrieved_docs=self.retrieved_docs,
decoded_token=self.decoded_token,
attachments=self.attachments,
json_schema=self.agent_config.get("json_schema"),
compressed_summary=self.compressed_summary,
)
agent_type = self.agent_config["agent_type"]
# Base agent kwargs
agent_kwargs = {
"endpoint": "stream",
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"api_key": system_api_key,
"user_api_key": self.agent_config["user_api_key"],
"prompt": rendered_prompt,
"chat_history": self.history,
"retrieved_docs": self.retrieved_docs,
"decoded_token": self.decoded_token,
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
}
# Workflow-specific kwargs for workflow agents
if agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config
elif isinstance(workflow_config, dict):
agent_kwargs["workflow"] = workflow_config
workflow_owner = self.agent_config.get("workflow_owner")
if workflow_owner:
agent_kwargs["workflow_owner"] = workflow_owner
agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
agent.conversation_id = self.conversation_id
agent.initial_user_id = self.initial_user_id
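
For reference, a sketch of the two payload shapes that reach this branch; only the "workflow" key is taken from the code above, and all values are placeholders:

# Embedded graph in the request body: agent_type is forced to "workflow" and
# the dict is forwarded as the workflow kwarg.
payload_embedded = {"workflow": {"name": "Echo", "nodes": [], "edges": []}}

# Reference by ID (e.g. stored on an agent/API key): the string becomes the
# workflow_id kwarg, and workflow_owner is taken from the key's user so
# _load_from_database can enforce ownership.
payload_reference = {"workflow": "665f000000000000000000aa"}  # hypothetical id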

View File

@@ -19,6 +19,9 @@ from application.api.user.base import (
resolve_tool_details,
storage,
users_collection,
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.core.settings import settings
from application.utils import (
@@ -31,6 +34,189 @@ from application.utils import (
agents_ns = Namespace("agents", description="Agent management operations", path="/api")
AGENT_TYPE_SCHEMAS = {
"classic": {
"required_published": [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
],
"required_draft": ["name"],
"validate_published": ["name", "description", "prompt_id"],
"validate_draft": [],
"require_source": True,
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"image",
"source",
"sources",
"chunks",
"retriever",
"prompt_id",
"tools",
"json_schema",
"models",
"default_model_id",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
"workflow": {
"required_published": ["name", "workflow"],
"required_draft": ["name"],
"validate_published": ["name", "workflow"],
"validate_draft": [],
"fields": [
"user",
"name",
"description",
"agent_type",
"status",
"key",
"workflow",
"folder_id",
"limited_token_mode",
"token_limit",
"limited_request_mode",
"request_limit",
"createdAt",
"updatedAt",
"lastUsedAt",
],
},
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
def normalize_workflow_reference(workflow_value):
"""Normalize workflow references from form/json payloads."""
if workflow_value is None:
return None
if isinstance(workflow_value, dict):
return (
workflow_value.get("id")
or workflow_value.get("_id")
or workflow_value.get("workflow_id")
)
if isinstance(workflow_value, str):
value = workflow_value.strip()
if not value:
return ""
try:
parsed = json.loads(value)
if isinstance(parsed, str):
return parsed.strip()
if isinstance(parsed, dict):
return (
parsed.get("id") or parsed.get("_id") or parsed.get("workflow_id")
)
except json.JSONDecodeError:
pass
return value
return str(workflow_value)
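# Illustrative inputs/outputs (hypothetical IDs):
#   normalize_workflow_reference({"id": "665f...aa"})     -> "665f...aa"
#   normalize_workflow_reference('{"_id": "665f...aa"}')  -> "665f...aa"  (JSON string)
#   normalize_workflow_reference(" 665f...aa ")           -> "665f...aa"  (stripped)
#   normalize_workflow_reference("")                      -> ""
#   normalize_workflow_reference(None)                    -> None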
def validate_workflow_access(workflow_value, user, required=False):
"""Validate workflow reference and ensure ownership."""
workflow_id = normalize_workflow_reference(workflow_value)
if not workflow_id:
if required:
return None, make_response(
jsonify({"success": False, "message": "Workflow is required"}), 400
)
return None, None
if not ObjectId.is_valid(workflow_id):
return None, make_response(
jsonify({"success": False, "message": "Invalid workflow ID format"}), 400
)
workflow = workflows_collection.find_one({"_id": ObjectId(workflow_id), "user": user})
if not workflow:
return None, make_response(
jsonify({"success": False, "message": "Workflow not found"}), 404
)
return workflow_id, None
def build_agent_document(
data, user, key, agent_type, image_url=None, source_field=None, sources_list=None
):
"""Build agent document based on agent type schema."""
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
agent_type = "classic"
schema = AGENT_TYPE_SCHEMAS.get(agent_type, AGENT_TYPE_SCHEMAS["classic"])
allowed_fields = set(schema["fields"])
now = datetime.datetime.now(datetime.timezone.utc)
base_doc = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"agent_type": agent_type,
"status": data.get("status"),
"key": key,
"createdAt": now,
"updatedAt": now,
"lastUsedAt": None,
}
if agent_type == "workflow":
base_doc["workflow"] = data.get("workflow")
base_doc["folder_id"] = data.get("folder_id")
else:
base_doc.update(
{
"image": image_url or "",
"source": source_field or "",
"sources": sources_list or [],
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"json_schema": data.get("json_schema"),
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
)
if "limited_token_mode" in allowed_fields:
base_doc["limited_token_mode"] = (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
)
if "token_limit" in allowed_fields:
base_doc["token_limit"] = int(
data.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
)
if "limited_request_mode" in allowed_fields:
base_doc["limited_request_mode"] = (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
)
if "request_limit" in allowed_fields:
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
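# Example (hypothetical values): building a published workflow agent keeps only
# the fields whitelisted in AGENT_TYPE_SCHEMAS["workflow"]:
#   doc = build_agent_document(
#       {"name": "Flow", "status": "published", "workflow": "665f...aa"},
#       user="u1", key=str(uuid.uuid4()), agent_type="workflow",
#   )
#   # -> doc has name/description/status/workflow/limits/timestamps, but no
#   #    image, source(s), chunks, retriever, prompt_id, tools, or models keys.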
@agents_ns.route("/get_agent")
class GetAgent(Resource):
@api.doc(params={"id": "Agent ID"}, description="Get agent by ID")
@@ -68,7 +254,7 @@ class GetAgent(Resource):
if (isinstance(source_ref, DBRef) and db.dereference(source_ref))
or source_ref == "default"
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -99,6 +285,7 @@ class GetAgent(Resource):
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -148,7 +335,7 @@ class GetAgents(Resource):
isinstance(source_ref, DBRef) and db.dereference(source_ref)
)
],
"chunks": agent["chunks"],
"chunks": agent.get("chunks", "2"),
"retriever": agent.get("retriever", ""),
"prompt_id": agent.get("prompt_id", ""),
"tools": agent.get("tools", []),
@@ -179,9 +366,12 @@ class GetAgents(Resource):
"models": agent.get("models", []),
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
}
for agent in agents
if "source" in agent or "retriever" in agent
if "source" in agent
or "retriever" in agent
or agent.get("agent_type") == "workflow"
]
except Exception as err:
current_app.logger.error(f"Error retrieving agents: {err}", exc_info=True)
@@ -209,16 +399,22 @@ class CreateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -330,18 +526,34 @@ class CreateAgent(Resource):
),
400,
)
if data.get("status") == "published":
required_fields = [
"name",
"description",
"chunks",
"retriever",
"prompt_id",
"agent_type",
]
# Require either source or sources (but not both)
agent_type = data.get("agent_type", "")
# Default to classic schema for empty or unknown agent types
if not data.get("source") and not data.get("sources"):
if not agent_type or agent_type not in AGENT_TYPE_SCHEMAS:
schema = AGENT_TYPE_SCHEMAS["classic"]
# Set agent_type to classic if it was empty
if not agent_type:
agent_type = "classic"
else:
schema = AGENT_TYPE_SCHEMAS[agent_type]
is_published = data.get("status") == "published"
if agent_type == "workflow":
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=is_published
)
if workflow_error:
return workflow_error
data["workflow"] = workflow_id
if data.get("status") == "published":
required_fields = schema["required_published"]
validate_fields = schema["validate_published"]
if (
schema.get("require_source")
and not data.get("source")
and not data.get("sources")
):
return make_response(
jsonify(
{
@@ -351,10 +563,9 @@ class CreateAgent(Resource):
),
400,
)
validate_fields = ["name", "description", "prompt_id", "agent_type"]
else:
required_fields = ["name"]
validate_fields = []
required_fields = schema["required_draft"]
validate_fields = schema["validate_draft"]
missing_fields = check_required_fields(data, required_fields)
invalid_fields = validate_required_fields(data, validate_fields)
if missing_fields:
@@ -366,7 +577,6 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Image upload failed"}), 400
)
folder_id = data.get("folder_id")
if folder_id:
if not ObjectId.is_valid(folder_id):
@@ -381,76 +591,36 @@ class CreateAgent(Resource):
return make_response(
jsonify({"success": False, "message": "Folder not found"}), 404
)
try:
key = str(uuid.uuid4()) if data.get("status") == "published" else ""
sources_list = []
source_field = ""
if data.get("sources") and len(data.get("sources", [])) > 0:
for source_id in data.get("sources", []):
if source_id == "default":
sources_list.append("default")
elif ObjectId.is_valid(source_id):
sources_list.append(DBRef("sources", ObjectId(source_id)))
source_field = ""
else:
source_value = data.get("source", "")
if source_value == "default":
source_field = "default"
elif ObjectId.is_valid(source_value):
source_field = DBRef("sources", ObjectId(source_value))
else:
source_field = ""
new_agent = {
"user": user,
"name": data.get("name"),
"description": data.get("description", ""),
"image": image_url,
"source": source_field,
"sources": sources_list,
"chunks": data.get("chunks", ""),
"retriever": data.get("retriever", ""),
"prompt_id": data.get("prompt_id", ""),
"tools": data.get("tools", []),
"agent_type": data.get("agent_type", ""),
"status": data.get("status"),
"json_schema": data.get("json_schema"),
"limited_token_mode": (
data.get("limited_token_mode") == "True"
if isinstance(data.get("limited_token_mode"), str)
else bool(data.get("limited_token_mode", False))
),
"token_limit": int(
data.get(
"token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]
)
),
"limited_request_mode": (
data.get("limited_request_mode") == "True"
if isinstance(data.get("limited_request_mode"), str)
else bool(data.get("limited_request_mode", False))
),
"request_limit": int(
data.get(
"request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]
)
),
"createdAt": datetime.datetime.now(datetime.timezone.utc),
"updatedAt": datetime.datetime.now(datetime.timezone.utc),
"lastUsedAt": None,
"key": key,
"models": data.get("models", []),
"default_model_id": data.get("default_model_id", ""),
"folder_id": data.get("folder_id"),
}
if new_agent["chunks"] == "":
new_agent["chunks"] = "2"
if (
new_agent["source"] == ""
and new_agent["retriever"] == ""
and not new_agent["sources"]
):
new_agent["retriever"] = "classic"
new_agent = build_agent_document(
data, user, key, agent_type, image_url, source_field, sources_list
)
if agent_type != "workflow":
if new_agent.get("chunks") == "":
new_agent["chunks"] = "2"
if (
new_agent.get("source") == ""
and new_agent.get("retriever") == ""
and not new_agent.get("sources")
):
new_agent["retriever"] = "classic"
resp = agents_collection.insert_one(new_agent)
new_id = str(resp.inserted_id)
except Exception as err:
@@ -479,16 +649,22 @@ class UpdateAgent(Resource):
required=False,
description="List of source identifiers for multiple sources",
),
"chunks": fields.Integer(required=True, description="Chunks count"),
"retriever": fields.String(required=True, description="Retriever ID"),
"prompt_id": fields.String(required=True, description="Prompt ID"),
"chunks": fields.Integer(required=False, description="Chunks count"),
"retriever": fields.String(required=False, description="Retriever ID"),
"prompt_id": fields.String(required=False, description="Prompt ID"),
"tools": fields.List(
fields.String, required=False, description="List of tool identifiers"
),
"agent_type": fields.String(required=True, description="Type of the agent"),
"agent_type": fields.String(
required=False,
description="Type of the agent (classic, react, workflow). Defaults to 'classic' for backwards compatibility.",
),
"status": fields.String(
required=True, description="Status of the agent (draft or published)"
),
"workflow": fields.String(
required=False, description="Workflow ID for workflow-type agents"
),
"json_schema": fields.Raw(
required=False,
description="JSON schema for enforcing structured output format",
@@ -612,6 +788,7 @@ class UpdateAgent(Resource):
"models",
"default_model_id",
"folder_id",
"workflow",
]
for field in allowed_fields:
@@ -768,10 +945,10 @@ class UpdateAgent(Resource):
)
elif field == "token_limit":
token_limit = data.get("token_limit")
# Convert to int and store
update_fields[field] = int(token_limit) if token_limit else 0
# Validate consistency with mode
if update_fields[field] > 0 and not data.get("limited_token_mode"):
return make_response(
jsonify(
@@ -814,14 +991,24 @@ class UpdateAgent(Resource):
)
if not folder:
return make_response(
jsonify(
{"success": False, "message": "Folder not found"}
),
jsonify({"success": False, "message": "Folder not found"}),
404,
)
update_fields[field] = folder_id
else:
update_fields[field] = None
elif field == "workflow":
workflow_required = (
data.get("status", existing_agent.get("status")) == "published"
and data.get("agent_type", existing_agent.get("agent_type"))
== "workflow"
)
workflow_id, workflow_error = validate_workflow_access(
data.get("workflow"), user, required=workflow_required
)
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:
@@ -850,46 +1037,82 @@ class UpdateAgent(Resource):
)
newly_generated_key = None
final_status = update_fields.get("status", existing_agent.get("status"))
agent_type = update_fields.get("agent_type", existing_agent.get("agent_type"))
if final_status == "published":
    if agent_type == "workflow":
        required_published_fields = {
            "name": "Agent name",
        }
        missing_published_fields = []
        for req_field, field_label in required_published_fields.items():
            final_value = update_fields.get(
                req_field, existing_agent.get(req_field)
            )
            if not final_value:
                missing_published_fields.append(field_label)
        workflow_id = update_fields.get("workflow", existing_agent.get("workflow"))
        if not workflow_id:
            missing_published_fields.append("Workflow")
        elif not ObjectId.is_valid(workflow_id):
            missing_published_fields.append("Valid workflow")
        else:
            workflow = workflows_collection.find_one(
                {"_id": ObjectId(workflow_id), "user": user}
            )
            if not workflow:
                missing_published_fields.append("Workflow access")
        if missing_published_fields:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": f"Cannot publish workflow agent. Missing required fields: {', '.join(missing_published_fields)}",
                    }
                ),
                400,
            )
    else:
        required_published_fields = {
            "name": "Agent name",
            "description": "Agent description",
            "chunks": "Chunks count",
            "prompt_id": "Prompt",
            "agent_type": "Agent type",
        }
        missing_published_fields = []
        for req_field, field_label in required_published_fields.items():
            final_value = update_fields.get(
                req_field, existing_agent.get(req_field)
            )
            if not final_value:
                missing_published_fields.append(field_label)
        source_val = update_fields.get("source", existing_agent.get("source"))
        sources_val = update_fields.get(
            "sources", existing_agent.get("sources", [])
        )
        has_valid_source = (
            isinstance(source_val, DBRef)
            or source_val == "default"
            or (isinstance(sources_val, list) and len(sources_val) > 0)
        )
        if not has_valid_source:
            missing_published_fields.append("Source")
        if missing_published_fields:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": f"Cannot publish agent. Missing or invalid required fields: {', '.join(missing_published_fields)}",
                    }
                ),
                400,
            )
if not existing_agent.get("key"):
newly_generated_key = str(uuid.uuid4())
update_fields["key"] = newly_generated_key
@@ -961,6 +1184,29 @@ class DeleteAgent(Resource):
jsonify({"success": False, "message": "Agent not found"}), 404
)
deleted_id = str(deleted_agent["_id"])
if deleted_agent.get("agent_type") == "workflow" and deleted_agent.get(
"workflow"
):
workflow_id = normalize_workflow_reference(deleted_agent.get("workflow"))
if workflow_id and ObjectId.is_valid(workflow_id):
workflow_oid = ObjectId(workflow_id)
owned_workflow = workflows_collection.find_one(
{"_id": workflow_oid, "user": user}, {"_id": 1}
)
if owned_workflow:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow_oid, "user": user})
else:
current_app.logger.warning(
f"Skipping workflow cleanup for non-owned workflow {workflow_id}"
)
elif workflow_id:
current_app.logger.warning(
f"Skipping workflow cleanup for invalid workflow id {workflow_id}"
)
except Exception as err:
current_app.logger.error(f"Error deleting agent: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -1069,19 +1315,16 @@ class AdoptAgent(Resource):
def post(self):
if not (decoded_token := request.decoded_token):
return make_response(jsonify({"success": False}), 401)
if not (agent_id := request.args.get("id")):
return make_response(
jsonify({"success": False, "message": "ID required"}), 400
)
try:
agent = agents_collection.find_one(
{"_id": ObjectId(agent_id), "user": "system"}
)
if not agent:
return make_response(jsonify({"status": "Not found"}), 404)
new_agent = agent.copy()
new_agent.pop("_id", None)
new_agent["user"] = decoded_token["sub"]

View File

@@ -38,6 +38,10 @@ users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]
try:
@@ -47,6 +51,25 @@ try:
background=True,
)
users_collection.create_index("user_id", unique=True)
workflows_collection.create_index(
[("user", 1)], name="workflow_user_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1)], name="node_workflow_index", background=True
)
workflow_nodes_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="node_workflow_graph_version_index",
background=True,
)
workflow_edges_collection.create_index(
[("workflow_id", 1)], name="edge_workflow_index", background=True
)
workflow_edges_collection.create_index(
[("workflow_id", 1), ("graph_version", 1)],
name="edge_workflow_graph_version_index",
background=True,
)
except Exception as e:
print("Error creating indexes:", e)
current_dir = os.path.dirname(

View File

@@ -6,7 +6,6 @@ from flask import Blueprint
from application.api import api
from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
from .analytics import analytics_ns
from .attachments import attachments_ns
from .conversations import conversations_ns
@@ -15,6 +14,7 @@ from .prompts import prompts_ns
from .sharing import sharing_ns
from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
from .tools import tools_mcp_ns, tools_ns
from .workflows import workflows_ns
user = Blueprint("user", __name__)
@@ -51,3 +51,6 @@ api.add_namespace(sources_upload_ns)
# Tools (main, MCP)
api.add_namespace(tools_ns)
api.add_namespace(tools_mcp_ns)
# Workflows
api.add_namespace(workflows_ns)

View File

@@ -0,0 +1,378 @@
"""Centralized utilities for API routes."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
def get_user_id() -> Optional[str]:
"""
Extract user ID from decoded JWT token.
Returns:
User ID string or None if not authenticated
"""
decoded_token = getattr(request, "decoded_token", None)
return decoded_token.get("sub") if decoded_token else None
def require_auth(func: Callable) -> Callable:
"""
Decorator to require authentication for route handlers.
Usage:
@require_auth
def get(self):
user_id = get_user_id()
...
"""
@wraps(func)
def wrapper(*args, **kwargs):
user_id = get_user_id()
if not user_id:
return error_response("Unauthorized", 401)
return func(*args, **kwargs)
return wrapper
def success_response(
data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
"""
Create a standardized success response.
Args:
data: Optional data dictionary to include in response
status: HTTP status code (default: 200)
Returns:
Flask Response object
Example:
return success_response({"users": [...], "total": 10})
"""
response = {"success": True}
if data:
response.update(data)
return make_response(jsonify(response), status)
def error_response(message: str, status: int = 400, **kwargs) -> Response:
"""
Create a standardized error response.
Args:
message: Error message string
status: HTTP status code (default: 400)
**kwargs: Additional fields to include in response
Returns:
Flask Response object
Example:
return error_response("Resource not found", 404)
return error_response("Invalid input", 400, errors=["field1", "field2"])
"""
response = {"success": False, "message": message}
response.update(kwargs)
return make_response(jsonify(response), status)
def validate_object_id(
id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
"""
Validate and convert string to ObjectId.
Args:
id_string: String to convert
resource_name: Name of resource for error message
Returns:
Tuple of (ObjectId or None, error_response or None)
Example:
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
"""
try:
return ObjectId(id_string), None
except (InvalidId, TypeError):
return None, error_response(f"Invalid {resource_name} ID format")
def validate_pagination(
default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
"""
Extract and validate pagination parameters from request.
Args:
default_limit: Default items per page
max_limit: Maximum allowed items per page
Returns:
Tuple of (limit, skip, error_response or None)
Example:
limit, skip, error = validate_pagination()
if error:
return error
"""
try:
limit = min(int(request.args.get("limit", default_limit)), max_limit)
skip = int(request.args.get("skip", 0))
if limit < 1 or skip < 0:
return 0, 0, error_response("Invalid pagination parameters")
return limit, skip, None
except ValueError:
return 0, 0, error_response("Invalid pagination parameters")
def check_resource_ownership(
collection: Collection,
resource_id: ObjectId,
user_id: str,
resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
"""
Check if resource exists and belongs to user.
Args:
collection: MongoDB collection
resource_id: Resource ObjectId
user_id: User ID string
resource_name: Name of resource for error messages
Returns:
Tuple of (resource_dict or None, error_response or None)
Example:
workflow, error = check_resource_ownership(
workflows_collection,
workflow_id,
user_id,
"Workflow"
)
if error:
return error
"""
resource = collection.find_one({"_id": resource_id, "user": user_id})
if not resource:
return None, error_response(f"{resource_name} not found", 404)
return resource, None
def serialize_object_id(
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
"""
Convert ObjectId to string in a dictionary.
Args:
obj: Dictionary containing ObjectId
id_field: Field name containing ObjectId
new_field: New field name for string ID
Returns:
Modified dictionary
Example:
user = serialize_object_id(user_doc)
# user["id"] = "507f1f77bcf86cd799439011"
"""
if id_field in obj:
obj[new_field] = str(obj[id_field])
if id_field != new_field:
obj.pop(id_field, None)
return obj
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
"""
Apply serializer function to list of items.
Args:
items: List of dictionaries
serializer: Function to apply to each item
Returns:
List of serialized items
Example:
workflows = serialize_list(workflow_docs, serialize_workflow)
"""
return [serializer(item) for item in items]
def paginated_response(
collection: Collection,
query: Dict[str, Any],
serializer: Callable[[Dict], Dict],
limit: int,
skip: int,
sort_field: str = "created_at",
sort_order: int = -1,
response_key: str = "items",
) -> Response:
"""
Create paginated response for collection query.
Args:
collection: MongoDB collection
query: Query dictionary
serializer: Function to serialize each item
limit: Items per page
skip: Number of items to skip
sort_field: Field to sort by
sort_order: Sort order (1=asc, -1=desc)
response_key: Key name for items in response
Returns:
Flask Response with paginated data
Example:
return paginated_response(
workflows_collection,
{"user": user_id},
serialize_workflow,
limit, skip,
response_key="workflows"
)
"""
items = list(
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
)
total = collection.count_documents(query)
return success_response(
{
response_key: serialize_list(items, serializer),
"total": total,
"limit": limit,
"skip": skip,
}
)
def require_fields(required: List[str]) -> Callable:
"""
Decorator to validate required fields in request JSON.
Args:
required: List of required field names
Returns:
Decorator function
Example:
@require_fields(["name", "description"])
def post(self):
data = request.get_json()
...
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
data = request.get_json()
if not data:
return error_response("Request body required")
missing = [field for field in required if not data.get(field)]
if missing:
return error_response(f"Missing required fields: {', '.join(missing)}")
return func(*args, **kwargs)
return wrapper
return decorator
def safe_db_operation(
operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
"""
Safely execute database operation with error handling.
Args:
operation: Function to execute
error_message: Error message if operation fails
Returns:
Tuple of (result or None, error_response or None)
Example:
result, error = safe_db_operation(
lambda: collection.insert_one(doc),
"Failed to create resource"
)
if error:
return error
"""
try:
result = operation()
return result, None
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(
value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
"""
Validate that value is in allowed list.
Args:
value: Value to validate
allowed: List of allowed values
field_name: Field name for error message
Returns:
error_response if invalid, None if valid
Example:
error = validate_enum(status, ["draft", "published"], "status")
if error:
return error
"""
if value not in allowed:
allowed_str = ", ".join(f"'{v}'" for v in allowed)
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
return None
def extract_sort_params(
default_field: str = "created_at",
default_order: str = "desc",
allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""
Extract and validate sort parameters from request.
Args:
default_field: Default sort field
default_order: Default sort order ("asc" or "desc")
allowed_fields: List of allowed sort fields (None = no validation)
Returns:
Tuple of (sort_field, sort_order)
Example:
sort_field, sort_order = extract_sort_params(
allowed_fields=["name", "date", "status"]
)
"""
sort_field = request.args.get("sort", default_field)
sort_order_str = request.args.get("order", default_order).lower()
if allowed_fields and sort_field not in allowed_fields:
sort_field = default_field
sort_order = -1 if sort_order_str == "desc" else 1
return sort_field, sort_order
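# Illustrative behavior (hypothetical request, not part of the module):
#   /api/items?sort=hacked&order=DESC with allowed_fields=["name", "date"]
#   -> ("created_at", -1): unknown sort fields fall back to the default, and
#   the order comparison is case-insensitive.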

View File

@@ -0,0 +1,3 @@
from .routes import workflows_ns
__all__ = ["workflows_ns"]

View File

@@ -0,0 +1,353 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Dict, List
from flask import current_app, request
from flask_restx import Namespace, Resource
from pymongo.collection import Collection
from application.api.user.base import (
workflow_edges_collection,
workflow_nodes_collection,
workflows_collection,
)
from application.api.user.utils import (
check_resource_ownership,
error_response,
get_user_id,
require_auth,
require_fields,
safe_db_operation,
success_response,
validate_object_id,
)
workflows_ns = Namespace("workflows", path="/api")
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
"id": str(w["_id"]),
"name": w.get("name"),
"description": w.get("description"),
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
}
def serialize_node(n: Dict) -> Dict:
"""Serialize workflow node document to API response format."""
return {
"id": n["id"],
"type": n["type"],
"title": n.get("title"),
"description": n.get("description"),
"position": n.get("position"),
"data": n.get("config", {}),
}
def serialize_edge(e: Dict) -> Dict:
"""Serialize workflow edge document to API response format."""
return {
"id": e["id"],
"source": e.get("source_id"),
"target": e.get("target_id"),
"sourceHandle": e.get("source_handle"),
"targetHandle": e.get("target_handle"),
}
def get_workflow_graph_version(workflow: Dict) -> int:
"""Get current graph version with legacy fallback."""
raw_version = workflow.get("current_graph_version", 1)
try:
version = int(raw_version)
return version if version > 0 else 1
except (ValueError, TypeError):
return 1
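# Sketch of the fallback behavior (illustrative values):
#   {"current_graph_version": 3}    -> 3
#   {"current_graph_version": "2"}  -> 2  (coerced via int())
#   {"current_graph_version": 0}    -> 1  (non-positive falls back)
#   {}                              -> 1  (legacy documents)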
def fetch_graph_documents(
    collection: Collection, workflow_id: str, graph_version: int
) -> List[Dict]:
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
docs = list(
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
)
if docs:
return docs
if graph_version == 1:
return list(
collection.find(
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
)
)
return docs
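# Why the fallback above: graphs created before versioning carry no
# graph_version field, so an empty lookup for version 1 retries with
# {"graph_version": {"$exists": False}} to pick up those legacy documents.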
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
if not nodes:
errors.append("Workflow must have at least one node")
return errors
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) != 1:
errors.append("Workflow must have exactly one start node")
end_nodes = [n for n in nodes if n.get("type") == "end"]
if not end_nodes:
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
if source_id not in node_ids:
errors.append(f"Edge references non-existent source: {source_id}")
if target_id not in node_ids:
errors.append(f"Edge references non-existent target: {target_id}")
if start_nodes:
start_id = start_nodes[0].get("id")
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
if not node.get("type"):
errors.append(f"Node {node.get('id', 'unknown')} must have a type")
return errors
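# A minimal graph that passes validation (illustrative only):
#   nodes = [{"id": "start-1", "type": "start"}, {"id": "end-1", "type": "end"}]
#   edges = [{"id": "e1", "source": "start-1", "target": "end-1"}]
#   assert validate_workflow_structure(nodes, edges) == []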
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow nodes into database."""
if nodes_data:
workflow_nodes_collection.insert_many(
[
{
"id": n["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"type": n["type"],
"title": n.get("title", ""),
"description": n.get("description", ""),
"position": n.get("position", {"x": 0, "y": 0}),
"config": n.get("data", {}),
}
for n in nodes_data
]
)
def create_workflow_edges(
workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
"""Insert workflow edges into database."""
if edges_data:
workflow_edges_collection.insert_many(
[
{
"id": e["id"],
"workflow_id": workflow_id,
"graph_version": graph_version,
"source_id": e.get("source"),
"target_id": e.get("target"),
"source_handle": e.get("sourceHandle"),
"target_handle": e.get("targetHandle"),
}
for e in edges_data
]
)
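# Field-name mapping between the API payload and the stored documents, as
# written above and reversed by serialize_node/serialize_edge:
#   API "data"         <-> DB "config"
#   API "source"       <-> DB "source_id"      (likewise "target"/"target_id")
#   API "sourceHandle" <-> DB "source_handle"  (likewise "targetHandle")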
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
@require_auth
@require_fields(["name"])
def post(self):
"""Create a new workflow with nodes and edges."""
user_id = get_user_id()
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
now = datetime.now(timezone.utc)
workflow_doc = {
"name": name,
"description": data.get("description", ""),
"user": user_id,
"created_at": now,
"updated_at": now,
"current_graph_version": 1,
}
result, error = safe_db_operation(
lambda: workflows_collection.insert_one(workflow_doc),
"Failed to create workflow",
)
if error:
return error
workflow_id = str(result.inserted_id)
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return error_response(f"Failed to create workflow structure: {str(e)}")
return success_response({"id": workflow_id}, 201)
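# Hypothetical request body for POST /api/workflows (field values are
# illustrative, not taken from the codebase):
#   {
#     "name": "Support triage",
#     "description": "Route incoming tickets",
#     "nodes": [{"id": "start-1", "type": "start", "position": {"x": 0, "y": 0}},
#               {"id": "end-1", "type": "end", "position": {"x": 200, "y": 0}}],
#     "edges": [{"id": "e1", "source": "start-1", "target": "end-1"}]
#   }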
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
@require_auth
def get(self, workflow_id: str):
"""Get workflow details with nodes and edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
graph_version = get_workflow_graph_version(workflow)
nodes = fetch_graph_documents(
workflow_nodes_collection, workflow_id, graph_version
)
edges = fetch_graph_documents(
workflow_edges_collection, workflow_id, graph_version
)
return success_response(
{
"workflow": serialize_workflow(workflow),
"nodes": [serialize_node(n) for n in nodes],
"edges": [serialize_edge(e) for e in edges],
}
)
@require_auth
@require_fields(["name"])
def put(self, workflow_id: str):
"""Update workflow and replace nodes/edges."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
data = request.get_json()
name = data.get("name", "").strip()
nodes_data = data.get("nodes", [])
edges_data = data.get("edges", [])
validation_errors = validate_workflow_structure(nodes_data, edges_data)
if validation_errors:
return error_response(
"Workflow validation failed", errors=validation_errors
)
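        # Copy-on-write update, summarizing the steps below (no new behavior):
        # write the graph under version N+1, flip current_graph_version on the
        # workflow doc, then best-effort delete of every other version; a
        # failure at either write step rolls back the new version and leaves
        # the old graph intact.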
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error_response(f"Failed to update workflow structure: {str(e)}")
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
lambda: workflows_collection.update_one(
{"_id": obj_id},
{
"$set": {
"name": name,
"description": data.get("description", ""),
"updated_at": now,
"current_graph_version": next_graph_version,
}
},
),
"Failed to update workflow",
)
if error:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return error
try:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
)
except Exception as cleanup_err:
current_app.logger.warning(
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
)
return success_response()
@require_auth
def delete(self, workflow_id: str):
"""Delete workflow and its graph."""
user_id = get_user_id()
obj_id, error = validate_object_id(workflow_id, "Workflow")
if error:
return error
workflow, error = check_resource_ownership(
workflows_collection, obj_id, user_id, "Workflow"
)
if error:
return error
try:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
return success_response()