mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-21 20:01:26 +00:00
feat: condition node functionality with CEL evaluation in Workflows (#2280)
* feat: add condition node functionality with CEL evaluation - Introduced ConditionNode to support conditional branching in workflows. - Implemented CEL evaluation for state updates and condition expressions. - Updated WorkflowEngine to handle condition nodes and their execution logic. - Enhanced validation for workflows to ensure condition nodes have at least two outgoing edges and valid expressions. - Modified frontend components to support new condition node type and its configuration. - Added necessary types and interfaces for condition cases and state operations. - Updated requirements to include cel-python for expression evaluation. * mini-fixes * feat(workflow): improve UX --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
64
application/agents/workflows/cel_evaluator.py
Normal file
64
application/agents/workflows/cel_evaluator.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import celpy
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class CelEvaluationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _convert_value(value: Any) -> Any:
|
||||
if isinstance(value, bool):
|
||||
return celpy.celtypes.BoolType(value)
|
||||
if isinstance(value, int):
|
||||
return celpy.celtypes.IntType(value)
|
||||
if isinstance(value, float):
|
||||
return celpy.celtypes.DoubleType(value)
|
||||
if isinstance(value, str):
|
||||
return celpy.celtypes.StringType(value)
|
||||
if isinstance(value, list):
|
||||
return celpy.celtypes.ListType([_convert_value(item) for item in value])
|
||||
if isinstance(value, dict):
|
||||
return celpy.celtypes.MapType(
|
||||
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
|
||||
)
|
||||
if value is None:
|
||||
return celpy.celtypes.BoolType(False)
|
||||
return celpy.celtypes.StringType(str(value))
|
||||
|
||||
|
||||
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {k: _convert_value(v) for k, v in state.items()}
|
||||
|
||||
|
||||
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
|
||||
if not expression or not expression.strip():
|
||||
raise CelEvaluationError("Empty expression")
|
||||
try:
|
||||
env = celpy.Environment()
|
||||
ast = env.compile(expression)
|
||||
program = env.program(ast)
|
||||
activation = build_activation(state)
|
||||
result = program.evaluate(activation)
|
||||
except celpy.CELEvalError as exc:
|
||||
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise CelEvaluationError(f"CEL error: {exc}") from exc
|
||||
return cel_to_python(result)
|
||||
|
||||
|
||||
def cel_to_python(value: Any) -> Any:
|
||||
if isinstance(value, celpy.celtypes.BoolType):
|
||||
return bool(value)
|
||||
if isinstance(value, celpy.celtypes.IntType):
|
||||
return int(value)
|
||||
if isinstance(value, celpy.celtypes.DoubleType):
|
||||
return float(value)
|
||||
if isinstance(value, celpy.celtypes.StringType):
|
||||
return str(value)
|
||||
if isinstance(value, celpy.celtypes.ListType):
|
||||
return [cel_to_python(item) for item in value]
|
||||
if isinstance(value, celpy.celtypes.MapType):
|
||||
return {str(k): cel_to_python(v) for k, v in value.items()}
|
||||
return value
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@@ -12,6 +12,7 @@ class NodeType(str, Enum):
|
||||
AGENT = "agent"
|
||||
NOTE = "note"
|
||||
STATE = "state"
|
||||
CONDITION = "condition"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
@@ -48,6 +49,25 @@ class AgentNodeConfig(BaseModel):
|
||||
json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConditionCase(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
name: Optional[str] = None
|
||||
expression: str = ""
|
||||
source_handle: str = Field(..., alias="sourceHandle")
|
||||
|
||||
|
||||
class ConditionNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
mode: Literal["simple", "advanced"] = "simple"
|
||||
cases: List[ConditionCase] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StateOperation(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
expression: str = ""
|
||||
target_variable: str = ""
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
id: str
|
||||
|
||||
@@ -2,9 +2,11 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
@@ -28,6 +30,7 @@ class WorkflowEngine:
|
||||
self.agent = agent
|
||||
self.state: WorkflowState = {}
|
||||
self.execution_log: List[Dict[str, Any]] = []
|
||||
self._condition_result: Optional[str] = None
|
||||
|
||||
def execute(
|
||||
self, initial_inputs: WorkflowState, query: str
|
||||
@@ -98,6 +101,10 @@ class WorkflowEngine:
|
||||
if node.type == NodeType.END:
|
||||
break
|
||||
current_node_id = self._get_next_node_id(current_node_id)
|
||||
if current_node_id is None and node.type != NodeType.END:
|
||||
logger.warning(
|
||||
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
|
||||
)
|
||||
steps += 1
|
||||
if steps >= self.MAX_EXECUTION_STEPS:
|
||||
logger.warning(
|
||||
@@ -121,10 +128,20 @@ class WorkflowEngine:
|
||||
}
|
||||
|
||||
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
if edges:
|
||||
return edges[0].target_id
|
||||
return None
|
||||
if not edges:
|
||||
return None
|
||||
|
||||
if node and node.type == NodeType.CONDITION and self._condition_result:
|
||||
target_handle = self._condition_result
|
||||
self._condition_result = None
|
||||
for edge in edges:
|
||||
if edge.source_handle == target_handle:
|
||||
return edge.target_id
|
||||
return None
|
||||
|
||||
return edges[0].target_id
|
||||
|
||||
def _execute_node(
|
||||
self, node: WorkflowNode
|
||||
@@ -136,6 +153,7 @@ class WorkflowEngine:
|
||||
NodeType.NOTE: self._execute_note_node,
|
||||
NodeType.AGENT: self._execute_agent_node,
|
||||
NodeType.STATE: self._execute_state_node,
|
||||
NodeType.CONDITION: self._execute_condition_node,
|
||||
NodeType.END: self._execute_end_node,
|
||||
}
|
||||
|
||||
@@ -158,7 +176,7 @@ class WorkflowEngine:
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
node_config = AgentNodeConfig(**node.config)
|
||||
node_config = AgentNodeConfig(**node.config.get("config", node.config))
|
||||
|
||||
if node_config.prompt_template:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
@@ -195,59 +213,42 @@ class WorkflowEngine:
|
||||
self._has_streamed = True
|
||||
|
||||
output_key = node_config.output_variable or f"node_{node.id}_output"
|
||||
self.state[output_key] = full_response
|
||||
self.state[output_key] = full_response.strip()
|
||||
|
||||
def _execute_state_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
operations = config.get("operations", [])
|
||||
config = node.config.get("config", node.config)
|
||||
for op in config.get("operations", []):
|
||||
expression = op.get("expression", "")
|
||||
target_variable = op.get("target_variable", "")
|
||||
if expression and target_variable:
|
||||
self.state[target_variable] = evaluate_cel(expression, self.state)
|
||||
yield from ()
|
||||
|
||||
if operations:
|
||||
for op in operations:
|
||||
key = op.get("key")
|
||||
operation = op.get("operation", "set")
|
||||
value = op.get("value")
|
||||
def _execute_condition_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = ConditionNodeConfig(**node.config.get("config", node.config))
|
||||
matched_handle = None
|
||||
|
||||
if not key:
|
||||
continue
|
||||
if operation == "set":
|
||||
formatted_value = (
|
||||
self._format_template(str(value))
|
||||
if isinstance(value, str)
|
||||
else value
|
||||
)
|
||||
self.state[key] = formatted_value
|
||||
elif operation == "increment":
|
||||
current = self.state.get(key, 0)
|
||||
try:
|
||||
self.state[key] = int(current) + int(value or 1)
|
||||
except (ValueError, TypeError):
|
||||
self.state[key] = 1
|
||||
elif operation == "append":
|
||||
if key not in self.state:
|
||||
self.state[key] = []
|
||||
if isinstance(self.state[key], list):
|
||||
self.state[key].append(value)
|
||||
else:
|
||||
updates = config.get("updates", {})
|
||||
if not updates:
|
||||
var_name = config.get("variable")
|
||||
var_value = config.get("value")
|
||||
if var_name and isinstance(var_name, str):
|
||||
updates = {var_name: var_value or ""}
|
||||
if isinstance(updates, dict):
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, str):
|
||||
self.state[key] = self._format_template(value)
|
||||
else:
|
||||
self.state[key] = value
|
||||
for case in config.cases:
|
||||
if not case.expression.strip():
|
||||
continue
|
||||
try:
|
||||
if evaluate_cel(case.expression, self.state):
|
||||
matched_handle = case.source_handle
|
||||
break
|
||||
except CelEvaluationError:
|
||||
continue
|
||||
|
||||
self._condition_result = matched_handle or "else"
|
||||
yield from ()
|
||||
|
||||
def _execute_end_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
config = node.config.get("config", node.config)
|
||||
output_template = str(config.get("output_template", ""))
|
||||
if output_template:
|
||||
formatted_output = self._format_template(output_template)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
@@ -102,6 +102,9 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
errors.append("Workflow must have at least one end node")
|
||||
|
||||
node_ids = {n.get("id") for n in nodes}
|
||||
node_map = {n.get("id"): n for n in nodes}
|
||||
end_ids = {n.get("id") for n in end_nodes}
|
||||
|
||||
for edge in edges:
|
||||
source_id = edge.get("source")
|
||||
target_id = edge.get("target")
|
||||
@@ -115,6 +118,104 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
if not any(e.get("source") == start_id for e in edges):
|
||||
errors.append("Start node must have at least one outgoing edge")
|
||||
|
||||
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
|
||||
for cnode in condition_nodes:
|
||||
cnode_id = cnode.get("id")
|
||||
cnode_title = cnode.get("title", cnode_id)
|
||||
outgoing = [e for e in edges if e.get("source") == cnode_id]
|
||||
if len(outgoing) < 2:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
|
||||
)
|
||||
node_data = cnode.get("data", {}) or {}
|
||||
cases = node_data.get("cases", [])
|
||||
if not isinstance(cases, list):
|
||||
cases = []
|
||||
if not cases or not any(
|
||||
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
|
||||
):
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' must have at least one case with an expression"
|
||||
)
|
||||
|
||||
case_handles: Set[str] = set()
|
||||
duplicate_case_handles: Set[str] = set()
|
||||
for case in cases:
|
||||
if not isinstance(case, dict):
|
||||
continue
|
||||
raw_handle = case.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
if not handle:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has a case without a branch handle"
|
||||
)
|
||||
continue
|
||||
if handle in case_handles:
|
||||
duplicate_case_handles.add(handle)
|
||||
case_handles.add(handle)
|
||||
|
||||
for handle in duplicate_case_handles:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
|
||||
)
|
||||
|
||||
outgoing_by_handle: Dict[str, List[Dict]] = {}
|
||||
for out_edge in outgoing:
|
||||
raw_handle = out_edge.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
outgoing_by_handle.setdefault(handle, []).append(out_edge)
|
||||
|
||||
for handle, handle_edges in outgoing_by_handle.items():
|
||||
if not handle:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
|
||||
)
|
||||
continue
|
||||
if handle != "else" and handle not in case_handles:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
|
||||
)
|
||||
if len(handle_edges) > 1:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
|
||||
)
|
||||
|
||||
if "else" not in outgoing_by_handle:
|
||||
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
|
||||
|
||||
for case in cases:
|
||||
if not isinstance(case, dict):
|
||||
continue
|
||||
raw_handle = case.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
if not handle:
|
||||
continue
|
||||
|
||||
raw_expression = case.get("expression", "")
|
||||
has_expression = isinstance(raw_expression, str) and bool(
|
||||
raw_expression.strip()
|
||||
)
|
||||
has_outgoing = bool(outgoing_by_handle.get(handle))
|
||||
if has_expression and not has_outgoing:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
|
||||
)
|
||||
if not has_expression and has_outgoing:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
|
||||
)
|
||||
|
||||
for handle, handle_edges in outgoing_by_handle.items():
|
||||
if not handle:
|
||||
continue
|
||||
for out_edge in handle_edges:
|
||||
target = out_edge.get("target")
|
||||
if target and not _can_reach_end(target, edges, node_map, end_ids):
|
||||
errors.append(
|
||||
f"Branch '{handle}' of condition '{cnode_title}' "
|
||||
f"must eventually reach an end node"
|
||||
)
|
||||
|
||||
for node in nodes:
|
||||
if not node.get("id"):
|
||||
errors.append("All nodes must have an id")
|
||||
@@ -124,6 +225,20 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
return errors
|
||||
|
||||
|
||||
def _can_reach_end(
|
||||
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
|
||||
) -> bool:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if node_id in end_ids:
|
||||
return True
|
||||
if node_id in visited or node_id not in node_map:
|
||||
return False
|
||||
visited.add(node_id)
|
||||
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
|
||||
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
|
||||
|
||||
|
||||
def create_workflow_nodes(
|
||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
anthropic==0.75.0
|
||||
boto3==1.42.17
|
||||
beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
celery==5.6.0
|
||||
cryptography==46.0.3
|
||||
dataclasses-json==0.6.7
|
||||
|
||||
@@ -1,4 +1,20 @@
|
||||
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state';
|
||||
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state' | 'condition';
|
||||
|
||||
export interface ConditionCase {
|
||||
name?: string;
|
||||
expression: string;
|
||||
sourceHandle: string;
|
||||
}
|
||||
|
||||
export interface ConditionNodeConfig {
|
||||
mode: 'simple' | 'advanced';
|
||||
cases: ConditionCase[];
|
||||
}
|
||||
|
||||
export interface StateOperationConfig {
|
||||
expression: string;
|
||||
target_variable: string;
|
||||
}
|
||||
|
||||
export interface WorkflowEdge {
|
||||
id: string;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ import {
|
||||
Circle,
|
||||
Database,
|
||||
Flag,
|
||||
GitBranch,
|
||||
Loader2,
|
||||
MessageSquare,
|
||||
Play,
|
||||
@@ -53,6 +54,7 @@ const NODE_ICONS: Record<string, React.ReactNode> = {
|
||||
end: <Flag className="h-3 w-3" />,
|
||||
note: <StickyNote className="h-3 w-3" />,
|
||||
state: <Database className="h-3 w-3" />,
|
||||
condition: <GitBranch className="h-3 w-3" />,
|
||||
};
|
||||
|
||||
const NODE_COLORS: Record<string, string> = {
|
||||
@@ -61,6 +63,7 @@ const NODE_COLORS: Record<string, string> = {
|
||||
end: 'text-gray-600 dark:text-gray-400',
|
||||
note: 'text-yellow-600 dark:text-yellow-400',
|
||||
state: 'text-blue-600 dark:text-blue-400',
|
||||
condition: 'text-orange-600 dark:text-orange-400',
|
||||
};
|
||||
|
||||
function ExecutionDetails({
|
||||
@@ -267,20 +270,17 @@ function WorkflowMiniMap({
|
||||
case 'failed':
|
||||
return 'bg-red-100 dark:bg-red-900/30 border-red-300 dark:border-red-700';
|
||||
default:
|
||||
if (nodeType === 'start') {
|
||||
return 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800';
|
||||
}
|
||||
if (nodeType === 'agent') {
|
||||
return 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800';
|
||||
}
|
||||
if (nodeType === 'end') {
|
||||
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
|
||||
}
|
||||
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
|
||||
}
|
||||
};
|
||||
|
||||
const executedOrder = new Map(executionSteps.map((s, i) => [s.nodeId, i]));
|
||||
const sortedNodes = [...nodes].sort((a, b) => {
|
||||
const aIdx = executedOrder.get(a.id);
|
||||
const bIdx = executedOrder.get(b.id);
|
||||
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
|
||||
if (aIdx !== undefined) return -1;
|
||||
if (bIdx !== undefined) return 1;
|
||||
if (a.type === 'start') return -1;
|
||||
if (b.type === 'start') return 1;
|
||||
if (a.type === 'end') return 1;
|
||||
|
||||
@@ -5,7 +5,7 @@ interface BaseNodeProps {
|
||||
title: string;
|
||||
children?: ReactNode;
|
||||
selected?: boolean;
|
||||
type?: 'start' | 'end' | 'default' | 'state' | 'agent';
|
||||
type?: 'start' | 'end' | 'default' | 'state' | 'agent' | 'condition';
|
||||
icon?: ReactNode;
|
||||
handles?: {
|
||||
source?: boolean;
|
||||
@@ -40,6 +40,9 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
|
||||
} else if (type === 'state') {
|
||||
iconBg = 'bg-gray-100 dark:bg-gray-800';
|
||||
iconColor = 'text-gray-600 dark:text-gray-400';
|
||||
} else if (type === 'condition') {
|
||||
iconBg = 'bg-orange-100 dark:bg-orange-900/30';
|
||||
iconColor = 'text-orange-600 dark:text-orange-400';
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
118
frontend/src/agents/workflow/nodes/ConditionNode.tsx
Normal file
118
frontend/src/agents/workflow/nodes/ConditionNode.tsx
Normal file
@@ -0,0 +1,118 @@
|
||||
import { GitBranch } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
import { Handle, NodeProps, Position } from 'reactflow';
|
||||
|
||||
import { ConditionCase } from '../../types/workflow';
|
||||
|
||||
type ConditionNodeData = {
|
||||
label?: string;
|
||||
title?: string;
|
||||
config?: {
|
||||
mode?: 'simple' | 'advanced';
|
||||
cases?: ConditionCase[];
|
||||
};
|
||||
};
|
||||
|
||||
const ROW_HEIGHT = 18;
|
||||
const HEADER_HEIGHT = 52;
|
||||
const PADDING_BOTTOM = 8;
|
||||
|
||||
function getNodeHeight(caseCount: number): number {
|
||||
return (
|
||||
HEADER_HEIGHT + Math.max(caseCount + 1, 2) * ROW_HEIGHT + PADDING_BOTTOM
|
||||
);
|
||||
}
|
||||
|
||||
function getHandleTop(index: number, total: number): string {
|
||||
const offset = HEADER_HEIGHT;
|
||||
return `${offset + ROW_HEIGHT * index + ROW_HEIGHT / 2}px`;
|
||||
}
|
||||
|
||||
const ConditionNode = ({ data, selected }: NodeProps<ConditionNodeData>) => {
|
||||
const title = data.title || data.label || 'If / Else';
|
||||
const cases = data.config?.cases || [];
|
||||
const totalOutputs = cases.length + 1;
|
||||
const height = getNodeHeight(cases.length);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`relative rounded-2xl border bg-white shadow-md transition-all dark:bg-[#2C2C2C] ${
|
||||
selected
|
||||
? 'border-violets-are-blue dark:ring-violets-are-blue scale-105 ring-2 ring-purple-300'
|
||||
: 'border-gray-200 hover:shadow-lg dark:border-[#3A3A3A]'
|
||||
}`}
|
||||
style={{ minWidth: 180, maxWidth: 220, height }}
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
position={Position.Left}
|
||||
isConnectable
|
||||
className="hover:bg-violets-are-blue! top-1/2! -left-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
|
||||
/>
|
||||
|
||||
<div className="flex items-center gap-3 px-3 py-2">
|
||||
<div className="flex h-9 w-9 shrink-0 items-center justify-center rounded-full bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400">
|
||||
<GitBranch size={14} />
|
||||
</div>
|
||||
<div className="min-w-0 flex-1 pr-2">
|
||||
<div
|
||||
className="truncate text-sm font-semibold text-gray-900 dark:text-white"
|
||||
title={title}
|
||||
>
|
||||
{title}
|
||||
</div>
|
||||
<div className="text-[10px] text-gray-500 uppercase">
|
||||
{data.config?.mode || 'simple'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col px-3">
|
||||
{cases.map((c, i) => (
|
||||
<div
|
||||
key={c.sourceHandle}
|
||||
className="flex items-center gap-1"
|
||||
style={{ height: ROW_HEIGHT }}
|
||||
>
|
||||
<span className="shrink-0 text-xs font-medium text-orange-600 dark:text-orange-400">
|
||||
{i === 0 ? 'If' : 'Else if'}
|
||||
</span>
|
||||
{c.name && (
|
||||
<span
|
||||
className="truncate text-xs text-gray-600 dark:text-gray-400"
|
||||
title={c.name}
|
||||
>
|
||||
{c.name}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
<div className="flex items-center gap-1" style={{ height: ROW_HEIGHT }}>
|
||||
<span className="text-xs font-medium text-gray-500">Else</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{cases.map((c, i) => (
|
||||
<Handle
|
||||
key={c.sourceHandle}
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
id={c.sourceHandle}
|
||||
isConnectable
|
||||
style={{ top: getHandleTop(i, totalOutputs) }}
|
||||
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-orange-400! transition-colors dark:border-[#2C2C2C]!"
|
||||
/>
|
||||
))}
|
||||
<Handle
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
id="else"
|
||||
isConnectable
|
||||
style={{ top: getHandleTop(cases.length, totalOutputs) }}
|
||||
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ConditionNode);
|
||||
@@ -2,6 +2,7 @@ import { Database } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
|
||||
import { StateOperationConfig } from '../../types/workflow';
|
||||
import { BaseNode } from './BaseNode';
|
||||
|
||||
type SetStateNodeData = {
|
||||
@@ -9,10 +10,16 @@ type SetStateNodeData = {
|
||||
title?: string;
|
||||
variable?: string;
|
||||
value?: string;
|
||||
config?: {
|
||||
operations?: StateOperationConfig[];
|
||||
};
|
||||
};
|
||||
|
||||
const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
|
||||
const title = data.title || data.label || 'Set State';
|
||||
const operations = data.config?.operations || [];
|
||||
const hasLegacy = !operations.length && data.variable;
|
||||
|
||||
return (
|
||||
<BaseNode
|
||||
title={title}
|
||||
@@ -22,22 +29,31 @@ const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
|
||||
handles={{ source: true, target: true }}
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
{data.variable && (
|
||||
{operations.length > 0 ? (
|
||||
<div
|
||||
className="truncate text-[10px] text-gray-500 uppercase"
|
||||
title={`Variable: ${data.variable}`}
|
||||
className="truncate text-[10px] text-gray-500"
|
||||
title={`${operations.length} operation(s)`}
|
||||
>
|
||||
{data.variable}
|
||||
{operations.length} variable{operations.length !== 1 ? 's' : ''}
|
||||
</div>
|
||||
)}
|
||||
{data.value && (
|
||||
<div
|
||||
className="truncate text-xs text-blue-600 dark:text-blue-400"
|
||||
title={`Value: ${data.value}`}
|
||||
>
|
||||
{data.value}
|
||||
</div>
|
||||
)}
|
||||
) : hasLegacy ? (
|
||||
<>
|
||||
<div
|
||||
className="truncate text-[10px] text-gray-500 uppercase"
|
||||
title={`Variable: ${data.variable}`}
|
||||
>
|
||||
{data.variable}
|
||||
</div>
|
||||
{data.value && (
|
||||
<div
|
||||
className="truncate text-xs text-blue-600 dark:text-blue-400"
|
||||
title={`Value: ${data.value}`}
|
||||
>
|
||||
{data.value}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : null}
|
||||
</div>
|
||||
</BaseNode>
|
||||
);
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import React, { memo } from 'react';
|
||||
import { Bot, Flag, Play, StickyNote } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { BaseNode } from './BaseNode';
|
||||
import ConditionNode from './ConditionNode';
|
||||
import SetStateNode from './SetStateNode';
|
||||
import { Play, Bot, StickyNote, Flag } from 'lucide-react';
|
||||
|
||||
export const StartNode = memo(function StartNode({
|
||||
selected,
|
||||
@@ -142,3 +144,4 @@ export const NoteNode = memo(function NoteNode({
|
||||
});
|
||||
|
||||
export { SetStateNode };
|
||||
export { ConditionNode };
|
||||
|
||||
@@ -43,7 +43,7 @@ export default function WrapperModal({
|
||||
|
||||
const modalContent = (
|
||||
<div
|
||||
className="fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center"
|
||||
className="fixed top-0 left-0 z-[100] flex h-screen w-screen items-center justify-center"
|
||||
onClick={(e: React.MouseEvent) => e.stopPropagation()}
|
||||
onMouseDown={(e: React.MouseEvent) => e.stopPropagation()}
|
||||
>
|
||||
|
||||
Reference in New Issue
Block a user