mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-22 12:21:39 +00:00
feat: agent workflow builder (#2264)
* feat: implement WorkflowAgent and GraphExecutor for workflow management and execution * refactor: workflow schemas and introduce WorkflowEngine - Updated schemas in `schemas.py` to include new agent types and configurations. - Created `WorkflowEngine` class in `workflow_engine.py` to manage workflow execution. - Enhanced `StreamProcessor` to handle workflow-related data. - Added new routes and utilities for managing workflows in the user API. - Implemented validation and serialization functions for workflows. - Established MongoDB collections and indexes for workflows and related entities. * refactor: improve WorkflowAgent documentation and update type hints in WorkflowEngine * feat: workflow builder and managing in frontend - Added new endpoints for workflows in `endpoints.ts`. - Implemented `getWorkflow`, `createWorkflow`, and `updateWorkflow` methods in `userService.ts`. - Introduced new UI components for alerts, buttons, commands, dialogs, multi-select, popovers, and selects. - Enhanced styling in `index.css` with new theme variables and animations. - Refactored modal components for better layout and styling. - Configured TypeScript paths and Vite aliases for cleaner imports. * feat: add workflow preview component and related state management - Implemented WorkflowPreview component for displaying workflow execution. - Created WorkflowPreviewSlice for managing workflow preview state, including queries and execution steps. - Added WorkflowMiniMap for visual representation of workflow nodes and their statuses. - Integrated conversation handling with the ability to fetch answers and manage query states. - Introduced reusable Sheet component for UI overlays. - Updated Redux store to include workflowPreview reducer. * feat: enhance workflow execution details and state management in WorkflowEngine and WorkflowPreview * feat: enhance workflow components with improved UI and functionality - Updated WorkflowPreview to allow text truncation for better display of long names. - Enhanced BaseNode with connectable handles and improved styling for better visibility. - Added MobileBlocker component to inform users about desktop requirements for the Workflow Builder. - Introduced PromptTextArea component for improved variable insertion and search functionality, including upstream variable extraction and context addition. * feat(workflow): add owner validation and graph version support * fix: ruff lint --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
import logging
|
||||
from application.agents.workflow_agent import WorkflowAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -9,6 +11,7 @@ class AgentCreator:
|
||||
agents = {
|
||||
"classic": ClassicAgent,
|
||||
"react": ReActAgent,
|
||||
"workflow": WorkflowAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -16,5 +19,4 @@ class AgentCreator:
|
||||
agent_class = cls.agents.get(type.lower())
|
||||
if not agent_class:
|
||||
raise ValueError(f"No agent class found for type {type}")
|
||||
|
||||
return agent_class(*args, **kwargs)
|
||||
|
||||
@@ -367,7 +367,9 @@ class BaseAgent(ABC):
|
||||
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%). Model: {self.model_id}"
|
||||
)
|
||||
elif current_tokens >= int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE):
|
||||
elif current_tokens >= int(
|
||||
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||
):
|
||||
logger.info(
|
||||
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%)"
|
||||
|
||||
218
application/agents/workflow_agent.py
Normal file
218
application/agents/workflow_agent.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
Workflow,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
WorkflowRun,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.logging import log_activity, LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAgent(BaseAgent):
|
||||
"""A specialized agent that executes predefined workflows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
workflow_id: Optional[str] = None,
|
||||
workflow: Optional[Dict[str, Any]] = None,
|
||||
workflow_owner: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.workflow_id = workflow_id
|
||||
self.workflow_owner = workflow_owner
|
||||
self._workflow_data = workflow
|
||||
self._engine: Optional[WorkflowEngine] = None
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, log_context: LogContext = None
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from self._gen_inner(query, log_context)
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
graph = self._load_workflow_graph()
|
||||
if not graph:
|
||||
yield {"type": "error", "error": "Failed to load workflow configuration."}
|
||||
return
|
||||
self._engine = WorkflowEngine(graph, self)
|
||||
yield from self._engine.execute({}, query)
|
||||
self._save_workflow_run(query)
|
||||
|
||||
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
|
||||
if self._workflow_data:
|
||||
return self._parse_embedded_workflow()
|
||||
if self.workflow_id:
|
||||
return self._load_from_database()
|
||||
return None
|
||||
|
||||
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
|
||||
try:
|
||||
nodes_data = self._workflow_data.get("nodes", [])
|
||||
edges_data = self._workflow_data.get("edges", [])
|
||||
|
||||
workflow = Workflow(
|
||||
name=self._workflow_data.get("name", "Embedded Workflow"),
|
||||
description=self._workflow_data.get("description"),
|
||||
)
|
||||
|
||||
nodes = []
|
||||
for n in nodes_data:
|
||||
node_config = n.get("data", {})
|
||||
nodes.append(
|
||||
WorkflowNode(
|
||||
id=n["id"],
|
||||
workflow_id=self.workflow_id or "embedded",
|
||||
type=n["type"],
|
||||
title=n.get("title", "Node"),
|
||||
description=n.get("description"),
|
||||
position=n.get("position", {"x": 0, "y": 0}),
|
||||
config=node_config,
|
||||
)
|
||||
)
|
||||
edges = []
|
||||
for e in edges_data:
|
||||
edges.append(
|
||||
WorkflowEdge(
|
||||
id=e["id"],
|
||||
workflow_id=self.workflow_id or "embedded",
|
||||
source=e.get("source") or e.get("source_id"),
|
||||
target=e.get("target") or e.get("target_id"),
|
||||
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
|
||||
targetHandle=e.get("targetHandle") or e.get("target_handle"),
|
||||
)
|
||||
)
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid embedded workflow: {e}")
|
||||
return None
|
||||
|
||||
def _load_from_database(self) -> Optional[WorkflowGraph]:
|
||||
try:
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
|
||||
logger.error(f"Invalid workflow ID: {self.workflow_id}")
|
||||
return None
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
owner_id = self.decoded_token.get("sub")
|
||||
if not owner_id:
|
||||
logger.error(
|
||||
f"Workflow owner not available for workflow load: {self.workflow_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
workflows_coll = db["workflows"]
|
||||
workflow_nodes_coll = db["workflow_nodes"]
|
||||
workflow_edges_coll = db["workflow_edges"]
|
||||
|
||||
workflow_doc = workflows_coll.find_one(
|
||||
{"_id": ObjectId(self.workflow_id), "user": owner_id}
|
||||
)
|
||||
if not workflow_doc:
|
||||
logger.error(
|
||||
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
|
||||
)
|
||||
return None
|
||||
workflow = Workflow(**workflow_doc)
|
||||
graph_version = workflow_doc.get("current_graph_version", 1)
|
||||
try:
|
||||
graph_version = int(graph_version)
|
||||
if graph_version <= 0:
|
||||
graph_version = 1
|
||||
except (ValueError, TypeError):
|
||||
graph_version = 1
|
||||
|
||||
nodes_docs = list(
|
||||
workflow_nodes_coll.find(
|
||||
{"workflow_id": self.workflow_id, "graph_version": graph_version}
|
||||
)
|
||||
)
|
||||
if not nodes_docs and graph_version == 1:
|
||||
nodes_docs = list(
|
||||
workflow_nodes_coll.find(
|
||||
{
|
||||
"workflow_id": self.workflow_id,
|
||||
"graph_version": {"$exists": False},
|
||||
}
|
||||
)
|
||||
)
|
||||
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
|
||||
|
||||
edges_docs = list(
|
||||
workflow_edges_coll.find(
|
||||
{"workflow_id": self.workflow_id, "graph_version": graph_version}
|
||||
)
|
||||
)
|
||||
if not edges_docs and graph_version == 1:
|
||||
edges_docs = list(
|
||||
workflow_edges_coll.find(
|
||||
{
|
||||
"workflow_id": self.workflow_id,
|
||||
"graph_version": {"$exists": False},
|
||||
}
|
||||
)
|
||||
)
|
||||
edges = [WorkflowEdge(**doc) for doc in edges_docs]
|
||||
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load workflow from database: {e}")
|
||||
return None
|
||||
|
||||
def _save_workflow_run(self, query: str) -> None:
|
||||
if not self._engine:
|
||||
return
|
||||
try:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
workflow_runs_coll = db["workflow_runs"]
|
||||
|
||||
run = WorkflowRun(
|
||||
workflow_id=self.workflow_id or "unknown",
|
||||
status=self._determine_run_status(),
|
||||
inputs={"query": query},
|
||||
outputs=self._serialize_state(self._engine.state),
|
||||
steps=self._engine.get_execution_summary(),
|
||||
created_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
workflow_runs_coll.insert_one(run.to_mongo_doc())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save workflow run: {e}")
|
||||
|
||||
def _determine_run_status(self) -> ExecutionStatus:
|
||||
if not self._engine or not self._engine.execution_log:
|
||||
return ExecutionStatus.COMPLETED
|
||||
for log in self._engine.execution_log:
|
||||
if log.get("status") == ExecutionStatus.FAILED.value:
|
||||
return ExecutionStatus.FAILED
|
||||
return ExecutionStatus.COMPLETED
|
||||
|
||||
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
serialized: Dict[str, Any] = {}
|
||||
for key, value in state.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
serialized[key] = value
|
||||
else:
|
||||
serialized[key] = str(value)
|
||||
return serialized
|
||||
109
application/agents/workflows/node_agent.py
Normal file
109
application/agents/workflows/node_agent.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.react_agent import ReActAgent
|
||||
from application.agents.workflows.schemas import AgentType
|
||||
|
||||
|
||||
class ToolFilterMixin:
|
||||
"""Mixin that filters fetched tools to only those specified in tool_ids."""
|
||||
|
||||
_allowed_tool_ids: List[str]
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
|
||||
all_tools = super()._get_user_tools(user)
|
||||
if not self._allowed_tool_ids:
|
||||
return {}
|
||||
filtered_tools = {
|
||||
tool_id: tool
|
||||
for tool_id, tool in all_tools.items()
|
||||
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||
}
|
||||
return filtered_tools
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
|
||||
all_tools = super()._get_tools(api_key)
|
||||
if not self._allowed_tool_ids:
|
||||
return {}
|
||||
filtered_tools = {
|
||||
tool_id: tool
|
||||
for tool_id, tool in all_tools.items()
|
||||
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||
}
|
||||
return filtered_tools
|
||||
|
||||
|
||||
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
endpoint=endpoint,
|
||||
llm_name=llm_name,
|
||||
model_id=model_id,
|
||||
api_key=api_key,
|
||||
**kwargs,
|
||||
)
|
||||
self._allowed_tool_ids = tool_ids or []
|
||||
|
||||
|
||||
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
endpoint=endpoint,
|
||||
llm_name=llm_name,
|
||||
model_id=model_id,
|
||||
api_key=api_key,
|
||||
**kwargs,
|
||||
)
|
||||
self._allowed_tool_ids = tool_ids or []
|
||||
|
||||
|
||||
class WorkflowNodeAgentFactory:
|
||||
|
||||
_agents: Dict[AgentType, Type[BaseAgent]] = {
|
||||
AgentType.CLASSIC: WorkflowNodeClassicAgent,
|
||||
AgentType.REACT: WorkflowNodeReActAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
agent_type: AgentType,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> BaseAgent:
|
||||
agent_class = cls._agents.get(agent_type)
|
||||
if not agent_class:
|
||||
raise ValueError(f"Unsupported agent type: {agent_type}")
|
||||
return agent_class(
|
||||
endpoint=endpoint,
|
||||
llm_name=llm_name,
|
||||
model_id=model_id,
|
||||
api_key=api_key,
|
||||
tool_ids=tool_ids,
|
||||
**kwargs,
|
||||
)
|
||||
215
application/agents/workflows/schemas.py
Normal file
215
application/agents/workflows/schemas.py
Normal file
@@ -0,0 +1,215 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class NodeType(str, Enum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
AGENT = "agent"
|
||||
NOTE = "note"
|
||||
STATE = "state"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
CLASSIC = "classic"
|
||||
REACT = "react"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
x: float = 0.0
|
||||
y: float = 0.0
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
agent_type: AgentType = AgentType.CLASSIC
|
||||
llm_name: Optional[str] = None
|
||||
system_prompt: str = "You are a helpful assistant."
|
||||
prompt_template: str = ""
|
||||
output_variable: Optional[str] = None
|
||||
stream_to_user: bool = True
|
||||
tools: List[str] = Field(default_factory=list)
|
||||
sources: List[str] = Field(default_factory=list)
|
||||
chunks: str = "2"
|
||||
retriever: str = ""
|
||||
model_id: Optional[str] = None
|
||||
json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
id: str
|
||||
workflow_id: str
|
||||
source_id: str = Field(..., alias="source")
|
||||
target_id: str = Field(..., alias="target")
|
||||
source_handle: Optional[str] = Field(None, alias="sourceHandle")
|
||||
target_handle: Optional[str] = Field(None, alias="targetHandle")
|
||||
|
||||
|
||||
class WorkflowEdge(WorkflowEdgeCreate):
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"source_id": self.source_id,
|
||||
"target_id": self.target_id,
|
||||
"source_handle": self.source_handle,
|
||||
"target_handle": self.target_handle,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowNodeCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: str
|
||||
workflow_id: str
|
||||
type: NodeType
|
||||
title: str = "Node"
|
||||
description: Optional[str] = None
|
||||
position: Position = Field(default_factory=Position)
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("position", mode="before")
|
||||
@classmethod
|
||||
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
|
||||
if isinstance(v, dict):
|
||||
return Position(**v)
|
||||
return v
|
||||
|
||||
|
||||
class WorkflowNode(WorkflowNodeCreate):
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"type": self.type.value,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"position": self.position.model_dump(),
|
||||
"config": self.config,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
name: str = "New Workflow"
|
||||
description: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class Workflow(WorkflowCreate):
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"user": self.user,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowGraph(BaseModel):
|
||||
workflow: Workflow
|
||||
nodes: List[WorkflowNode] = Field(default_factory=list)
|
||||
edges: List[WorkflowEdge] = Field(default_factory=list)
|
||||
|
||||
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
|
||||
for node in self.nodes:
|
||||
if node.id == node_id:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_start_node(self) -> Optional[WorkflowNode]:
|
||||
for node in self.nodes:
|
||||
if node.type == NodeType.START:
|
||||
return node
|
||||
return None
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
|
||||
return [edge for edge in self.edges if edge.source_id == node_id]
|
||||
|
||||
|
||||
class NodeExecutionLog(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
node_id: str
|
||||
node_type: str
|
||||
status: ExecutionStatus
|
||||
started_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
error: Optional[str] = None
|
||||
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRunCreate(BaseModel):
|
||||
workflow_id: str
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRun(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
workflow_id: str
|
||||
status: ExecutionStatus = ExecutionStatus.PENDING
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
outputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
steps: List[NodeExecutionLog] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"workflow_id": self.workflow_id,
|
||||
"status": self.status.value,
|
||||
"inputs": self.inputs,
|
||||
"outputs": self.outputs,
|
||||
"steps": [step.model_dump() for step in self.steps],
|
||||
"created_at": self.created_at,
|
||||
"completed_at": self.completed_at,
|
||||
}
|
||||
276
application/agents/workflows/workflow_engine.py
Normal file
276
application/agents/workflows/workflow_engine.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.agents.base import BaseAgent
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
StateValue = Any
|
||||
WorkflowState = Dict[str, StateValue]
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
MAX_EXECUTION_STEPS = 50
|
||||
|
||||
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
|
||||
self.graph = graph
|
||||
self.agent = agent
|
||||
self.state: WorkflowState = {}
|
||||
self.execution_log: List[Dict[str, Any]] = []
|
||||
|
||||
def execute(
|
||||
self, initial_inputs: WorkflowState, query: str
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
self._initialize_state(initial_inputs, query)
|
||||
|
||||
start_node = self.graph.get_start_node()
|
||||
if not start_node:
|
||||
yield {"type": "error", "error": "No start node found in workflow."}
|
||||
return
|
||||
current_node_id: Optional[str] = start_node.id
|
||||
steps = 0
|
||||
|
||||
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
if not node:
|
||||
yield {"type": "error", "error": f"Node {current_node_id} not found."}
|
||||
break
|
||||
log_entry = self._create_log_entry(node)
|
||||
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
try:
|
||||
yield from self._execute_node(node)
|
||||
log_entry["status"] = ExecutionStatus.COMPLETED.value
|
||||
log_entry["completed_at"] = datetime.now(timezone.utc)
|
||||
|
||||
output_key = f"node_{node.id}_output"
|
||||
node_output = self.state.get(output_key)
|
||||
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "completed",
|
||||
"state_snapshot": dict(self.state),
|
||||
"output": node_output,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
|
||||
log_entry["status"] = ExecutionStatus.FAILED.value
|
||||
log_entry["error"] = str(e)
|
||||
log_entry["completed_at"] = datetime.now(timezone.utc)
|
||||
log_entry["state_snapshot"] = dict(self.state)
|
||||
self.execution_log.append(log_entry)
|
||||
|
||||
yield {
|
||||
"type": "workflow_step",
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"node_title": node.title,
|
||||
"status": "failed",
|
||||
"state_snapshot": dict(self.state),
|
||||
"error": str(e),
|
||||
}
|
||||
yield {"type": "error", "error": str(e)}
|
||||
break
|
||||
log_entry["state_snapshot"] = dict(self.state)
|
||||
self.execution_log.append(log_entry)
|
||||
|
||||
if node.type == NodeType.END:
|
||||
break
|
||||
current_node_id = self._get_next_node_id(current_node_id)
|
||||
steps += 1
|
||||
if steps >= self.MAX_EXECUTION_STEPS:
|
||||
logger.warning(
|
||||
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
|
||||
)
|
||||
|
||||
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
|
||||
self.state.update(initial_inputs)
|
||||
self.state["query"] = query
|
||||
self.state["chat_history"] = str(self.agent.chat_history)
|
||||
|
||||
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
|
||||
return {
|
||||
"node_id": node.id,
|
||||
"node_type": node.type.value,
|
||||
"started_at": datetime.now(timezone.utc),
|
||||
"completed_at": None,
|
||||
"status": ExecutionStatus.RUNNING.value,
|
||||
"error": None,
|
||||
"state_snapshot": {},
|
||||
}
|
||||
|
||||
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
|
||||
edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
if edges:
|
||||
return edges[0].target_id
|
||||
return None
|
||||
|
||||
def _execute_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
logger.info(f"Executing node {node.id} ({node.type.value})")
|
||||
|
||||
node_handlers = {
|
||||
NodeType.START: self._execute_start_node,
|
||||
NodeType.NOTE: self._execute_note_node,
|
||||
NodeType.AGENT: self._execute_agent_node,
|
||||
NodeType.STATE: self._execute_state_node,
|
||||
NodeType.END: self._execute_end_node,
|
||||
}
|
||||
|
||||
handler = node_handlers.get(node.type)
|
||||
if handler:
|
||||
yield from handler(node)
|
||||
|
||||
def _execute_start_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from ()
|
||||
|
||||
def _execute_note_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
yield from ()
|
||||
|
||||
def _execute_agent_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
node_config = AgentNodeConfig(**node.config)
|
||||
|
||||
if node_config.prompt_template:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
else:
|
||||
formatted_prompt = self.state.get("query", "")
|
||||
node_llm_name = node_config.llm_name or self.agent.llm_name
|
||||
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
|
||||
|
||||
node_agent = WorkflowNodeAgentFactory.create(
|
||||
agent_type=node_config.agent_type,
|
||||
endpoint=self.agent.endpoint,
|
||||
llm_name=node_llm_name,
|
||||
model_id=node_config.model_id or self.agent.model_id,
|
||||
api_key=node_api_key,
|
||||
tool_ids=node_config.tools,
|
||||
prompt=node_config.system_prompt,
|
||||
chat_history=self.agent.chat_history,
|
||||
decoded_token=self.agent.decoded_token,
|
||||
json_schema=node_config.json_schema,
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
first_chunk = True
|
||||
for event in node_agent.gen(formatted_prompt):
|
||||
if "answer" in event:
|
||||
full_response += event["answer"]
|
||||
if node_config.stream_to_user:
|
||||
if first_chunk and hasattr(self, "_has_streamed"):
|
||||
yield {"answer": "\n\n"}
|
||||
first_chunk = False
|
||||
yield event
|
||||
|
||||
if node_config.stream_to_user:
|
||||
self._has_streamed = True
|
||||
|
||||
output_key = node_config.output_variable or f"node_{node.id}_output"
|
||||
self.state[output_key] = full_response
|
||||
|
||||
def _execute_state_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
operations = config.get("operations", [])
|
||||
|
||||
if operations:
|
||||
for op in operations:
|
||||
key = op.get("key")
|
||||
operation = op.get("operation", "set")
|
||||
value = op.get("value")
|
||||
|
||||
if not key:
|
||||
continue
|
||||
if operation == "set":
|
||||
formatted_value = (
|
||||
self._format_template(str(value))
|
||||
if isinstance(value, str)
|
||||
else value
|
||||
)
|
||||
self.state[key] = formatted_value
|
||||
elif operation == "increment":
|
||||
current = self.state.get(key, 0)
|
||||
try:
|
||||
self.state[key] = int(current) + int(value or 1)
|
||||
except (ValueError, TypeError):
|
||||
self.state[key] = 1
|
||||
elif operation == "append":
|
||||
if key not in self.state:
|
||||
self.state[key] = []
|
||||
if isinstance(self.state[key], list):
|
||||
self.state[key].append(value)
|
||||
else:
|
||||
updates = config.get("updates", {})
|
||||
if not updates:
|
||||
var_name = config.get("variable")
|
||||
var_value = config.get("value")
|
||||
if var_name and isinstance(var_name, str):
|
||||
updates = {var_name: var_value or ""}
|
||||
if isinstance(updates, dict):
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, str):
|
||||
self.state[key] = self._format_template(value)
|
||||
else:
|
||||
self.state[key] = value
|
||||
yield from ()
|
||||
|
||||
def _execute_end_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
output_template = str(config.get("output_template", ""))
|
||||
if output_template:
|
||||
formatted_output = self._format_template(output_template)
|
||||
yield {"answer": formatted_output}
|
||||
|
||||
def _format_template(self, template: str) -> str:
|
||||
formatted = template
|
||||
for key, value in self.state.items():
|
||||
placeholder = f"{{{{{key}}}}}"
|
||||
if placeholder in formatted and value is not None:
|
||||
formatted = formatted.replace(placeholder, str(value))
|
||||
return formatted
|
||||
|
||||
def get_execution_summary(self) -> List[NodeExecutionLog]:
|
||||
return [
|
||||
NodeExecutionLog(
|
||||
node_id=log["node_id"],
|
||||
node_type=log["node_type"],
|
||||
status=ExecutionStatus(log["status"]),
|
||||
started_at=log["started_at"],
|
||||
completed_at=log.get("completed_at"),
|
||||
error=log.get("error"),
|
||||
state_snapshot=log.get("state_snapshot", {}),
|
||||
)
|
||||
for log in self.execution_log
|
||||
]
|
||||
Reference in New Issue
Block a user