This commit is contained in:
Alex
2026-02-17 18:46:35 +00:00
13 changed files with 1592 additions and 244 deletions

View File

@@ -0,0 +1,64 @@
from typing import Any, Dict
import celpy
import celpy.celtypes
class CelEvaluationError(Exception):
pass
def _convert_value(value: Any) -> Any:
if isinstance(value, bool):
return celpy.celtypes.BoolType(value)
if isinstance(value, int):
return celpy.celtypes.IntType(value)
if isinstance(value, float):
return celpy.celtypes.DoubleType(value)
if isinstance(value, str):
return celpy.celtypes.StringType(value)
if isinstance(value, list):
return celpy.celtypes.ListType([_convert_value(item) for item in value])
if isinstance(value, dict):
return celpy.celtypes.MapType(
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
)
if value is None:
return celpy.celtypes.BoolType(False)
return celpy.celtypes.StringType(str(value))
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
return {k: _convert_value(v) for k, v in state.items()}
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
if not expression or not expression.strip():
raise CelEvaluationError("Empty expression")
try:
env = celpy.Environment()
ast = env.compile(expression)
program = env.program(ast)
activation = build_activation(state)
result = program.evaluate(activation)
except celpy.CELEvalError as exc:
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
except Exception as exc:
raise CelEvaluationError(f"CEL error: {exc}") from exc
return cel_to_python(result)
def cel_to_python(value: Any) -> Any:
if isinstance(value, celpy.celtypes.BoolType):
return bool(value)
if isinstance(value, celpy.celtypes.IntType):
return int(value)
if isinstance(value, celpy.celtypes.DoubleType):
return float(value)
if isinstance(value, celpy.celtypes.StringType):
return str(value)
if isinstance(value, celpy.celtypes.ListType):
return [cel_to_python(item) for item in value]
if isinstance(value, celpy.celtypes.MapType):
return {str(k): cel_to_python(v) for k, v in value.items()}
return value

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -12,6 +12,7 @@ class NodeType(str, Enum):
AGENT = "agent"
NOTE = "note"
STATE = "state"
CONDITION = "condition"
class AgentType(str, Enum):
@@ -48,6 +49,25 @@ class AgentNodeConfig(BaseModel):
json_schema: Optional[Dict[str, Any]] = None
class ConditionCase(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
name: Optional[str] = None
expression: str = ""
source_handle: str = Field(..., alias="sourceHandle")
class ConditionNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal["simple", "advanced"] = "simple"
cases: List[ConditionCase] = Field(default_factory=list)
class StateOperation(BaseModel):
model_config = ConfigDict(extra="forbid")
expression: str = ""
target_variable: str = ""
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str

View File

@@ -2,9 +2,11 @@ import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
@@ -28,6 +30,7 @@ class WorkflowEngine:
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
def execute(
self, initial_inputs: WorkflowState, query: str
@@ -98,6 +101,10 @@ class WorkflowEngine:
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
if current_node_id is None and node.type != NodeType.END:
logger.warning(
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
@@ -121,10 +128,20 @@ class WorkflowEngine:
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
node = self.graph.get_node_by_id(current_node_id)
edges = self.graph.get_outgoing_edges(current_node_id)
if edges:
return edges[0].target_id
return None
if not edges:
return None
if node and node.type == NodeType.CONDITION and self._condition_result:
target_handle = self._condition_result
self._condition_result = None
for edge in edges:
if edge.source_handle == target_handle:
return edge.target_id
return None
return edges[0].target_id
def _execute_node(
self, node: WorkflowNode
@@ -136,6 +153,7 @@ class WorkflowEngine:
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.CONDITION: self._execute_condition_node,
NodeType.END: self._execute_end_node,
}
@@ -158,7 +176,7 @@ class WorkflowEngine:
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import get_api_key_for_provider
node_config = AgentNodeConfig(**node.config)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
@@ -195,59 +213,42 @@ class WorkflowEngine:
self._has_streamed = True
output_key = node_config.output_variable or f"node_{node.id}_output"
self.state[output_key] = full_response
self.state[output_key] = full_response.strip()
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
operations = config.get("operations", [])
config = node.config.get("config", node.config)
for op in config.get("operations", []):
expression = op.get("expression", "")
target_variable = op.get("target_variable", "")
if expression and target_variable:
self.state[target_variable] = evaluate_cel(expression, self.state)
yield from ()
if operations:
for op in operations:
key = op.get("key")
operation = op.get("operation", "set")
value = op.get("value")
def _execute_condition_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = ConditionNodeConfig(**node.config.get("config", node.config))
matched_handle = None
if not key:
continue
if operation == "set":
formatted_value = (
self._format_template(str(value))
if isinstance(value, str)
else value
)
self.state[key] = formatted_value
elif operation == "increment":
current = self.state.get(key, 0)
try:
self.state[key] = int(current) + int(value or 1)
except (ValueError, TypeError):
self.state[key] = 1
elif operation == "append":
if key not in self.state:
self.state[key] = []
if isinstance(self.state[key], list):
self.state[key].append(value)
else:
updates = config.get("updates", {})
if not updates:
var_name = config.get("variable")
var_value = config.get("value")
if var_name and isinstance(var_name, str):
updates = {var_name: var_value or ""}
if isinstance(updates, dict):
for key, value in updates.items():
if isinstance(value, str):
self.state[key] = self._format_template(value)
else:
self.state[key] = value
for case in config.cases:
if not case.expression.strip():
continue
try:
if evaluate_cel(case.expression, self.state):
matched_handle = case.source_handle
break
except CelEvaluationError:
continue
self._condition_result = matched_handle or "else"
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
config = node.config.get("config", node.config)
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)

View File

@@ -1,7 +1,7 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Dict, List
from typing import Dict, List, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
@@ -102,6 +102,9 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
node_map = {n.get("id"): n for n in nodes}
end_ids = {n.get("id") for n in end_nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
@@ -115,6 +118,104 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
for cnode in condition_nodes:
cnode_id = cnode.get("id")
cnode_title = cnode.get("title", cnode_id)
outgoing = [e for e in edges if e.get("source") == cnode_id]
if len(outgoing) < 2:
errors.append(
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
)
node_data = cnode.get("data", {}) or {}
cases = node_data.get("cases", [])
if not isinstance(cases, list):
cases = []
if not cases or not any(
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
):
errors.append(
f"Condition node '{cnode_title}' must have at least one case with an expression"
)
case_handles: Set[str] = set()
duplicate_case_handles: Set[str] = set()
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
errors.append(
f"Condition node '{cnode_title}' has a case without a branch handle"
)
continue
if handle in case_handles:
duplicate_case_handles.add(handle)
case_handles.add(handle)
for handle in duplicate_case_handles:
errors.append(
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
)
outgoing_by_handle: Dict[str, List[Dict]] = {}
for out_edge in outgoing:
raw_handle = out_edge.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
outgoing_by_handle.setdefault(handle, []).append(out_edge)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
errors.append(
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
)
continue
if handle != "else" and handle not in case_handles:
errors.append(
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
)
if len(handle_edges) > 1:
errors.append(
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
)
if "else" not in outgoing_by_handle:
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
continue
raw_expression = case.get("expression", "")
has_expression = isinstance(raw_expression, str) and bool(
raw_expression.strip()
)
has_outgoing = bool(outgoing_by_handle.get(handle))
if has_expression and not has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
)
if not has_expression and has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
continue
for out_edge in handle_edges:
target = out_edge.get("target")
if target and not _can_reach_end(target, edges, node_map, end_ids):
errors.append(
f"Branch '{handle}' of condition '{cnode_title}' "
f"must eventually reach an end node"
)
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
@@ -124,6 +225,20 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
return errors
def _can_reach_end(
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
) -> bool:
if visited is None:
visited = set()
if node_id in end_ids:
return True
if node_id in visited or node_id not in node_map:
return False
visited.add(node_id)
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:

View File

@@ -1,6 +1,7 @@
anthropic==0.75.0
boto3==1.42.17
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.0
cryptography==46.0.3
dataclasses-json==0.6.7

View File

@@ -1,4 +1,20 @@
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state';
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state' | 'condition';
export interface ConditionCase {
name?: string;
expression: string;
sourceHandle: string;
}
export interface ConditionNodeConfig {
mode: 'simple' | 'advanced';
cases: ConditionCase[];
}
export interface StateOperationConfig {
expression: string;
target_variable: string;
}
export interface WorkflowEdge {
id: string;

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ import {
Circle,
Database,
Flag,
GitBranch,
Loader2,
MessageSquare,
Play,
@@ -53,6 +54,7 @@ const NODE_ICONS: Record<string, React.ReactNode> = {
end: <Flag className="h-3 w-3" />,
note: <StickyNote className="h-3 w-3" />,
state: <Database className="h-3 w-3" />,
condition: <GitBranch className="h-3 w-3" />,
};
const NODE_COLORS: Record<string, string> = {
@@ -61,6 +63,7 @@ const NODE_COLORS: Record<string, string> = {
end: 'text-gray-600 dark:text-gray-400',
note: 'text-yellow-600 dark:text-yellow-400',
state: 'text-blue-600 dark:text-blue-400',
condition: 'text-orange-600 dark:text-orange-400',
};
function ExecutionDetails({
@@ -267,20 +270,17 @@ function WorkflowMiniMap({
case 'failed':
return 'bg-red-100 dark:bg-red-900/30 border-red-300 dark:border-red-700';
default:
if (nodeType === 'start') {
return 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800';
}
if (nodeType === 'agent') {
return 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800';
}
if (nodeType === 'end') {
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
}
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
}
};
const executedOrder = new Map(executionSteps.map((s, i) => [s.nodeId, i]));
const sortedNodes = [...nodes].sort((a, b) => {
const aIdx = executedOrder.get(a.id);
const bIdx = executedOrder.get(b.id);
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
if (aIdx !== undefined) return -1;
if (bIdx !== undefined) return 1;
if (a.type === 'start') return -1;
if (b.type === 'start') return 1;
if (a.type === 'end') return 1;

View File

@@ -5,7 +5,7 @@ interface BaseNodeProps {
title: string;
children?: ReactNode;
selected?: boolean;
type?: 'start' | 'end' | 'default' | 'state' | 'agent';
type?: 'start' | 'end' | 'default' | 'state' | 'agent' | 'condition';
icon?: ReactNode;
handles?: {
source?: boolean;
@@ -40,6 +40,9 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
} else if (type === 'state') {
iconBg = 'bg-gray-100 dark:bg-gray-800';
iconColor = 'text-gray-600 dark:text-gray-400';
} else if (type === 'condition') {
iconBg = 'bg-orange-100 dark:bg-orange-900/30';
iconColor = 'text-orange-600 dark:text-orange-400';
}
return (

View File

@@ -0,0 +1,118 @@
import { GitBranch } from 'lucide-react';
import { memo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow';
import { ConditionCase } from '../../types/workflow';
type ConditionNodeData = {
label?: string;
title?: string;
config?: {
mode?: 'simple' | 'advanced';
cases?: ConditionCase[];
};
};
const ROW_HEIGHT = 18;
const HEADER_HEIGHT = 52;
const PADDING_BOTTOM = 8;
function getNodeHeight(caseCount: number): number {
return (
HEADER_HEIGHT + Math.max(caseCount + 1, 2) * ROW_HEIGHT + PADDING_BOTTOM
);
}
function getHandleTop(index: number, total: number): string {
const offset = HEADER_HEIGHT;
return `${offset + ROW_HEIGHT * index + ROW_HEIGHT / 2}px`;
}
const ConditionNode = ({ data, selected }: NodeProps<ConditionNodeData>) => {
const title = data.title || data.label || 'If / Else';
const cases = data.config?.cases || [];
const totalOutputs = cases.length + 1;
const height = getNodeHeight(cases.length);
return (
<div
className={`relative rounded-2xl border bg-white shadow-md transition-all dark:bg-[#2C2C2C] ${
selected
? 'border-violets-are-blue dark:ring-violets-are-blue scale-105 ring-2 ring-purple-300'
: 'border-gray-200 hover:shadow-lg dark:border-[#3A3A3A]'
}`}
style={{ minWidth: 180, maxWidth: 220, height }}
>
<Handle
type="target"
position={Position.Left}
isConnectable
className="hover:bg-violets-are-blue! top-1/2! -left-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
/>
<div className="flex items-center gap-3 px-3 py-2">
<div className="flex h-9 w-9 shrink-0 items-center justify-center rounded-full bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400">
<GitBranch size={14} />
</div>
<div className="min-w-0 flex-1 pr-2">
<div
className="truncate text-sm font-semibold text-gray-900 dark:text-white"
title={title}
>
{title}
</div>
<div className="text-[10px] text-gray-500 uppercase">
{data.config?.mode || 'simple'}
</div>
</div>
</div>
<div className="flex flex-col px-3">
{cases.map((c, i) => (
<div
key={c.sourceHandle}
className="flex items-center gap-1"
style={{ height: ROW_HEIGHT }}
>
<span className="shrink-0 text-xs font-medium text-orange-600 dark:text-orange-400">
{i === 0 ? 'If' : 'Else if'}
</span>
{c.name && (
<span
className="truncate text-xs text-gray-600 dark:text-gray-400"
title={c.name}
>
{c.name}
</span>
)}
</div>
))}
<div className="flex items-center gap-1" style={{ height: ROW_HEIGHT }}>
<span className="text-xs font-medium text-gray-500">Else</span>
</div>
</div>
{cases.map((c, i) => (
<Handle
key={c.sourceHandle}
type="source"
position={Position.Right}
id={c.sourceHandle}
isConnectable
style={{ top: getHandleTop(i, totalOutputs) }}
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-orange-400! transition-colors dark:border-[#2C2C2C]!"
/>
))}
<Handle
type="source"
position={Position.Right}
id="else"
isConnectable
style={{ top: getHandleTop(cases.length, totalOutputs) }}
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
/>
</div>
);
};
export default memo(ConditionNode);

View File

@@ -2,6 +2,7 @@ import { Database } from 'lucide-react';
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import { StateOperationConfig } from '../../types/workflow';
import { BaseNode } from './BaseNode';
type SetStateNodeData = {
@@ -9,10 +10,16 @@ type SetStateNodeData = {
title?: string;
variable?: string;
value?: string;
config?: {
operations?: StateOperationConfig[];
};
};
const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
const title = data.title || data.label || 'Set State';
const operations = data.config?.operations || [];
const hasLegacy = !operations.length && data.variable;
return (
<BaseNode
title={title}
@@ -22,22 +29,31 @@ const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
handles={{ source: true, target: true }}
>
<div className="flex flex-col gap-1">
{data.variable && (
{operations.length > 0 ? (
<div
className="truncate text-[10px] text-gray-500 uppercase"
title={`Variable: ${data.variable}`}
className="truncate text-[10px] text-gray-500"
title={`${operations.length} operation(s)`}
>
{data.variable}
{operations.length} variable{operations.length !== 1 ? 's' : ''}
</div>
)}
{data.value && (
<div
className="truncate text-xs text-blue-600 dark:text-blue-400"
title={`Value: ${data.value}`}
>
{data.value}
</div>
)}
) : hasLegacy ? (
<>
<div
className="truncate text-[10px] text-gray-500 uppercase"
title={`Variable: ${data.variable}`}
>
{data.variable}
</div>
{data.value && (
<div
className="truncate text-xs text-blue-600 dark:text-blue-400"
title={`Value: ${data.value}`}
>
{data.value}
</div>
)}
</>
) : null}
</div>
</BaseNode>
);

View File

@@ -1,7 +1,9 @@
import React, { memo } from 'react';
import { Bot, Flag, Play, StickyNote } from 'lucide-react';
import { memo } from 'react';
import { BaseNode } from './BaseNode';
import ConditionNode from './ConditionNode';
import SetStateNode from './SetStateNode';
import { Play, Bot, StickyNote, Flag } from 'lucide-react';
export const StartNode = memo(function StartNode({
selected,
@@ -142,3 +144,4 @@ export const NoteNode = memo(function NoteNode({
});
export { SetStateNode };
export { ConditionNode };

View File

@@ -43,7 +43,7 @@ export default function WrapperModal({
const modalContent = (
<div
className="fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center"
className="fixed top-0 left-0 z-[100] flex h-screen w-screen items-center justify-center"
onClick={(e: React.MouseEvent) => e.stopPropagation()}
onMouseDown={(e: React.MouseEvent) => e.stopPropagation()}
>