mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-22 04:11:41 +00:00
* feat: add condition node functionality with CEL evaluation
  - Introduced ConditionNode to support conditional branching in workflows.
  - Implemented CEL evaluation for state updates and condition expressions.
  - Updated WorkflowEngine to handle condition nodes and their execution logic.
  - Enhanced validation for workflows to ensure condition nodes have at least two outgoing edges and valid expressions.
  - Modified frontend components to support new condition node type and its configuration.
  - Added necessary types and interfaces for condition cases and state operations.
  - Updated requirements to include cel-python for expression evaluation.
* mini-fixes
* feat(workflow): improve UX
---------
Co-authored-by: Alex <a@tushynski.me>
236 lines
6.9 KiB
Python
236 lines
6.9 KiB
Python
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Literal, Optional, Union
|
|
|
|
from bson import ObjectId
|
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
|
|
|
|
class NodeType(str, Enum):
    """Kinds of node a workflow can contain.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    START = "start"
    END = "end"
    AGENT = "agent"
    NOTE = "note"
    STATE = "state"
    CONDITION = "condition"
|
|
|
|
|
|
class AgentType(str, Enum):
    """Agent flavours selectable for an agent node (string-valued enum)."""

    CLASSIC = "classic"
    REACT = "react"
|
|
|
|
|
|
class ExecutionStatus(str, Enum):
    """Lifecycle states recorded for a workflow run and for each node step."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
|
|
|
|
|
class Position(BaseModel):
    """Two-dimensional coordinate attached to a workflow node.

    Unknown keys are rejected (``extra="forbid"``); both axes default to 0.0.
    """

    model_config = ConfigDict(extra="forbid")

    x: float = 0.0
    y: float = 0.0
|
|
|
|
|
|
class AgentNodeConfig(BaseModel):
    """Configuration payload for an agent node.

    Every field has a default, and keys not declared here are kept on the
    model rather than rejected (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    # Agent and model selection.
    agent_type: AgentType = AgentType.CLASSIC
    llm_name: Optional[str] = None
    model_id: Optional[str] = None

    # Prompting.
    system_prompt: str = "You are a helpful assistant."
    prompt_template: str = ""

    # Output handling.
    output_variable: Optional[str] = None
    stream_to_user: bool = True
    json_schema: Optional[Dict[str, Any]] = None

    # Tools and retrieval settings; note that ``chunks`` is typed as a string.
    tools: List[str] = Field(default_factory=list)
    sources: List[str] = Field(default_factory=list)
    chunks: str = "2"
    retriever: str = ""
|
|
|
|
|
|
class ConditionCase(BaseModel):
    """One case of a condition node: an expression tied to a source handle.

    Unknown keys are rejected. ``source_handle`` is required and may be given
    either under its alias ``sourceHandle`` or, because ``populate_by_name``
    is enabled, under its field name.
    """

    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    name: Optional[str] = None
    expression: str = ""
    source_handle: str = Field(default=..., alias="sourceHandle")
|
|
|
|
|
|
class ConditionNodeConfig(BaseModel):
    """Configuration payload for a condition node.

    Holds a mode (``"simple"`` or ``"advanced"``) and a list of cases that
    defaults to empty. Undeclared keys are retained (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    mode: Literal["simple", "advanced"] = "simple"
    cases: List[ConditionCase] = Field(default_factory=list)
|
|
|
|
|
|
class StateOperation(BaseModel):
    """An expression paired with the name of the variable it targets.

    Both fields default to the empty string; unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    expression: str = ""
    target_variable: str = ""
|
|
|
|
|
|
class WorkflowEdgeCreate(BaseModel):
    """Input schema for a directed edge between two workflow nodes.

    The endpoint and handle fields carry aliases (``source``, ``target``,
    ``sourceHandle``, ``targetHandle``). ``populate_by_name`` lets the
    snake_case field names be used as input keys as well.
    """

    model_config = ConfigDict(populate_by_name=True)

    id: str
    workflow_id: str

    # Required endpoints of the edge.
    source_id: str = Field(default=..., alias="source")
    target_id: str = Field(default=..., alias="target")

    # Optional handles on either end.
    source_handle: Optional[str] = Field(default=None, alias="sourceHandle")
    target_handle: Optional[str] = Field(default=None, alias="targetHandle")
|
|
|
|
|
|
class WorkflowEdge(WorkflowEdgeCreate):
    """A workflow edge plus an optional ``_id`` exposed as ``mongo_id``."""

    mongo_id: Optional[str] = Field(default=None, alias="_id")

    @field_validator("mongo_id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; pass any other value through unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the edge as a plain dict keyed by field name.

        ``mongo_id`` is not part of the returned document.
        """
        document: Dict[str, Any] = {
            "id": self.id,
            "workflow_id": self.workflow_id,
            "source_id": self.source_id,
            "target_id": self.target_id,
            "source_handle": self.source_handle,
            "target_handle": self.target_handle,
        }
        return document
|
|
|
|
|
|
class WorkflowNodeCreate(BaseModel):
    """Input schema for a workflow node.

    ``id``, ``workflow_id`` and ``type`` are required; the remaining fields
    have defaults. Undeclared keys are retained (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    id: str
    workflow_id: str
    type: NodeType
    title: str = "Node"
    description: Optional[str] = None
    position: Position = Field(default_factory=Position)
    # Free-form settings for the node; no schema is enforced at this level.
    config: Dict[str, Any] = Field(default_factory=dict)

    @field_validator("position", mode="before")
    @classmethod
    def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
        """Build a ``Position`` from a dict; return any other value as is."""
        return Position(**v) if isinstance(v, dict) else v
|
|
|
|
|
|
class WorkflowNode(WorkflowNodeCreate):
    """A workflow node plus an optional ``_id`` exposed as ``mongo_id``."""

    mongo_id: Optional[str] = Field(default=None, alias="_id")

    @field_validator("mongo_id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; pass any other value through unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the node as a plain dict.

        ``type`` is written as its string value and ``position`` as a dumped
        dict. ``mongo_id`` is not part of the returned document.
        """
        document: Dict[str, Any] = {
            "id": self.id,
            "workflow_id": self.workflow_id,
            "type": self.type.value,
            "title": self.title,
            "description": self.description,
            "position": self.position.model_dump(),
            "config": self.config,
        }
        return document
|
|
|
|
|
|
class WorkflowCreate(BaseModel):
    """Input schema for a workflow's top-level metadata.

    All fields are optional or defaulted; undeclared keys are retained
    (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    name: str = "New Workflow"
    description: Optional[str] = None
    user: Optional[str] = None
|
|
|
|
|
|
class Workflow(WorkflowCreate):
    """Workflow metadata with an optional ``_id`` and two timestamps.

    Both timestamps default to the current timezone-aware UTC time at
    construction.
    """

    id: Optional[str] = Field(default=None, alias="_id")
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )

    @field_validator("id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; pass any other value through unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the workflow as a plain dict; ``id`` is not included."""
        document: Dict[str, Any] = {
            "name": self.name,
            "description": self.description,
            "user": self.user,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }
        return document
|
|
|
|
|
|
class WorkflowGraph(BaseModel):
    """A workflow with its nodes and edges, plus simple lookup helpers."""

    workflow: Workflow
    nodes: List[WorkflowNode] = Field(default_factory=list)
    edges: List[WorkflowEdge] = Field(default_factory=list)

    def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
        """Return the first node whose ``id`` equals *node_id*, else ``None``."""
        return next(
            (candidate for candidate in self.nodes if candidate.id == node_id),
            None,
        )

    def get_start_node(self) -> Optional[WorkflowNode]:
        """Return the first node of type ``NodeType.START``, else ``None``."""
        return next(
            (
                candidate
                for candidate in self.nodes
                if candidate.type == NodeType.START
            ),
            None,
        )

    def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
        """Return, in stored order, every edge whose source is *node_id*."""
        outgoing: List[WorkflowEdge] = []
        for candidate in self.edges:
            if candidate.source_id == node_id:
                outgoing.append(candidate)
        return outgoing
|
|
|
|
|
|
class NodeExecutionLog(BaseModel):
    """Record of one node's execution within a run.

    Unknown keys are rejected. ``node_type`` is stored as a plain string
    rather than as a ``NodeType`` member.
    """

    model_config = ConfigDict(extra="forbid")

    node_id: str
    node_type: str
    status: ExecutionStatus

    # Timing: only the start time is required.
    started_at: datetime
    completed_at: Optional[datetime] = None

    error: Optional[str] = None
    state_snapshot: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class WorkflowRunCreate(BaseModel):
    """Input schema for a run: a workflow id plus string-to-string inputs."""

    workflow_id: str
    inputs: Dict[str, str] = Field(default_factory=dict)
|
|
|
|
|
|
class WorkflowRun(BaseModel):
    """State of a single workflow run.

    Holds the run's status, inputs, outputs, per-node step logs and
    timestamps. Undeclared keys are retained (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    id: Optional[str] = Field(default=None, alias="_id")
    workflow_id: str
    status: ExecutionStatus = ExecutionStatus.PENDING
    inputs: Dict[str, str] = Field(default_factory=dict)
    outputs: Dict[str, Any] = Field(default_factory=dict)
    steps: List[NodeExecutionLog] = Field(default_factory=list)
    # Defaults to the current timezone-aware UTC time at construction.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
    completed_at: Optional[datetime] = None

    @field_validator("id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; pass any other value through unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the run as a plain dict; ``id`` is not included.

        ``status`` is written as its string value and each step is dumped
        with ``model_dump()``.
        """
        dumped_steps = [entry.model_dump() for entry in self.steps]
        document: Dict[str, Any] = {
            "workflow_id": self.workflow_id,
            "status": self.status.value,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "steps": dumped_steps,
            "created_at": self.created_at,
            "completed_at": self.completed_at,
        }
        return document
|