mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-22 12:21:39 +00:00
feat: condition node functionality with CEL evaluation in Workflows (#2280)
* feat: add condition node functionality with CEL evaluation - Introduced ConditionNode to support conditional branching in workflows. - Implemented CEL evaluation for state updates and condition expressions. - Updated WorkflowEngine to handle condition nodes and their execution logic. - Enhanced validation for workflows to ensure condition nodes have at least two outgoing edges and valid expressions. - Modified frontend components to support new condition node type and its configuration. - Added necessary types and interfaces for condition cases and state operations. - Updated requirements to include cel-python for expression evaluation. * mini-fixes * feat(workflow): improve UX --------- Co-authored-by: Alex <a@tushynski.me>
This commit is contained in:
64
application/agents/workflows/cel_evaluator.py
Normal file
64
application/agents/workflows/cel_evaluator.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import celpy
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class CelEvaluationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _convert_value(value: Any) -> Any:
|
||||
if isinstance(value, bool):
|
||||
return celpy.celtypes.BoolType(value)
|
||||
if isinstance(value, int):
|
||||
return celpy.celtypes.IntType(value)
|
||||
if isinstance(value, float):
|
||||
return celpy.celtypes.DoubleType(value)
|
||||
if isinstance(value, str):
|
||||
return celpy.celtypes.StringType(value)
|
||||
if isinstance(value, list):
|
||||
return celpy.celtypes.ListType([_convert_value(item) for item in value])
|
||||
if isinstance(value, dict):
|
||||
return celpy.celtypes.MapType(
|
||||
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
|
||||
)
|
||||
if value is None:
|
||||
return celpy.celtypes.BoolType(False)
|
||||
return celpy.celtypes.StringType(str(value))
|
||||
|
||||
|
||||
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {k: _convert_value(v) for k, v in state.items()}
|
||||
|
||||
|
||||
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
|
||||
if not expression or not expression.strip():
|
||||
raise CelEvaluationError("Empty expression")
|
||||
try:
|
||||
env = celpy.Environment()
|
||||
ast = env.compile(expression)
|
||||
program = env.program(ast)
|
||||
activation = build_activation(state)
|
||||
result = program.evaluate(activation)
|
||||
except celpy.CELEvalError as exc:
|
||||
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise CelEvaluationError(f"CEL error: {exc}") from exc
|
||||
return cel_to_python(result)
|
||||
|
||||
|
||||
def cel_to_python(value: Any) -> Any:
|
||||
if isinstance(value, celpy.celtypes.BoolType):
|
||||
return bool(value)
|
||||
if isinstance(value, celpy.celtypes.IntType):
|
||||
return int(value)
|
||||
if isinstance(value, celpy.celtypes.DoubleType):
|
||||
return float(value)
|
||||
if isinstance(value, celpy.celtypes.StringType):
|
||||
return str(value)
|
||||
if isinstance(value, celpy.celtypes.ListType):
|
||||
return [cel_to_python(item) for item in value]
|
||||
if isinstance(value, celpy.celtypes.MapType):
|
||||
return {str(k): cel_to_python(v) for k, v in value.items()}
|
||||
return value
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@@ -12,6 +12,7 @@ class NodeType(str, Enum):
|
||||
AGENT = "agent"
|
||||
NOTE = "note"
|
||||
STATE = "state"
|
||||
CONDITION = "condition"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
@@ -48,6 +49,25 @@ class AgentNodeConfig(BaseModel):
|
||||
json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConditionCase(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
name: Optional[str] = None
|
||||
expression: str = ""
|
||||
source_handle: str = Field(..., alias="sourceHandle")
|
||||
|
||||
|
||||
class ConditionNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
mode: Literal["simple", "advanced"] = "simple"
|
||||
cases: List[ConditionCase] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StateOperation(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
expression: str = ""
|
||||
target_variable: str = ""
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
id: str
|
||||
|
||||
@@ -2,9 +2,11 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
@@ -28,6 +30,7 @@ class WorkflowEngine:
|
||||
self.agent = agent
|
||||
self.state: WorkflowState = {}
|
||||
self.execution_log: List[Dict[str, Any]] = []
|
||||
self._condition_result: Optional[str] = None
|
||||
|
||||
def execute(
|
||||
self, initial_inputs: WorkflowState, query: str
|
||||
@@ -98,6 +101,10 @@ class WorkflowEngine:
|
||||
if node.type == NodeType.END:
|
||||
break
|
||||
current_node_id = self._get_next_node_id(current_node_id)
|
||||
if current_node_id is None and node.type != NodeType.END:
|
||||
logger.warning(
|
||||
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
|
||||
)
|
||||
steps += 1
|
||||
if steps >= self.MAX_EXECUTION_STEPS:
|
||||
logger.warning(
|
||||
@@ -121,10 +128,20 @@ class WorkflowEngine:
|
||||
}
|
||||
|
||||
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
if edges:
|
||||
return edges[0].target_id
|
||||
return None
|
||||
if not edges:
|
||||
return None
|
||||
|
||||
if node and node.type == NodeType.CONDITION and self._condition_result:
|
||||
target_handle = self._condition_result
|
||||
self._condition_result = None
|
||||
for edge in edges:
|
||||
if edge.source_handle == target_handle:
|
||||
return edge.target_id
|
||||
return None
|
||||
|
||||
return edges[0].target_id
|
||||
|
||||
def _execute_node(
|
||||
self, node: WorkflowNode
|
||||
@@ -136,6 +153,7 @@ class WorkflowEngine:
|
||||
NodeType.NOTE: self._execute_note_node,
|
||||
NodeType.AGENT: self._execute_agent_node,
|
||||
NodeType.STATE: self._execute_state_node,
|
||||
NodeType.CONDITION: self._execute_condition_node,
|
||||
NodeType.END: self._execute_end_node,
|
||||
}
|
||||
|
||||
@@ -158,7 +176,7 @@ class WorkflowEngine:
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
|
||||
node_config = AgentNodeConfig(**node.config)
|
||||
node_config = AgentNodeConfig(**node.config.get("config", node.config))
|
||||
|
||||
if node_config.prompt_template:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
@@ -195,59 +213,42 @@ class WorkflowEngine:
|
||||
self._has_streamed = True
|
||||
|
||||
output_key = node_config.output_variable or f"node_{node.id}_output"
|
||||
self.state[output_key] = full_response
|
||||
self.state[output_key] = full_response.strip()
|
||||
|
||||
def _execute_state_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
operations = config.get("operations", [])
|
||||
config = node.config.get("config", node.config)
|
||||
for op in config.get("operations", []):
|
||||
expression = op.get("expression", "")
|
||||
target_variable = op.get("target_variable", "")
|
||||
if expression and target_variable:
|
||||
self.state[target_variable] = evaluate_cel(expression, self.state)
|
||||
yield from ()
|
||||
|
||||
if operations:
|
||||
for op in operations:
|
||||
key = op.get("key")
|
||||
operation = op.get("operation", "set")
|
||||
value = op.get("value")
|
||||
def _execute_condition_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = ConditionNodeConfig(**node.config.get("config", node.config))
|
||||
matched_handle = None
|
||||
|
||||
if not key:
|
||||
continue
|
||||
if operation == "set":
|
||||
formatted_value = (
|
||||
self._format_template(str(value))
|
||||
if isinstance(value, str)
|
||||
else value
|
||||
)
|
||||
self.state[key] = formatted_value
|
||||
elif operation == "increment":
|
||||
current = self.state.get(key, 0)
|
||||
try:
|
||||
self.state[key] = int(current) + int(value or 1)
|
||||
except (ValueError, TypeError):
|
||||
self.state[key] = 1
|
||||
elif operation == "append":
|
||||
if key not in self.state:
|
||||
self.state[key] = []
|
||||
if isinstance(self.state[key], list):
|
||||
self.state[key].append(value)
|
||||
else:
|
||||
updates = config.get("updates", {})
|
||||
if not updates:
|
||||
var_name = config.get("variable")
|
||||
var_value = config.get("value")
|
||||
if var_name and isinstance(var_name, str):
|
||||
updates = {var_name: var_value or ""}
|
||||
if isinstance(updates, dict):
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, str):
|
||||
self.state[key] = self._format_template(value)
|
||||
else:
|
||||
self.state[key] = value
|
||||
for case in config.cases:
|
||||
if not case.expression.strip():
|
||||
continue
|
||||
try:
|
||||
if evaluate_cel(case.expression, self.state):
|
||||
matched_handle = case.source_handle
|
||||
break
|
||||
except CelEvaluationError:
|
||||
continue
|
||||
|
||||
self._condition_result = matched_handle or "else"
|
||||
yield from ()
|
||||
|
||||
def _execute_end_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
config = node.config.get("config", node.config)
|
||||
output_template = str(config.get("output_template", ""))
|
||||
if output_template:
|
||||
formatted_output = self._format_template(output_template)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
@@ -102,6 +102,9 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
errors.append("Workflow must have at least one end node")
|
||||
|
||||
node_ids = {n.get("id") for n in nodes}
|
||||
node_map = {n.get("id"): n for n in nodes}
|
||||
end_ids = {n.get("id") for n in end_nodes}
|
||||
|
||||
for edge in edges:
|
||||
source_id = edge.get("source")
|
||||
target_id = edge.get("target")
|
||||
@@ -115,6 +118,104 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
if not any(e.get("source") == start_id for e in edges):
|
||||
errors.append("Start node must have at least one outgoing edge")
|
||||
|
||||
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
|
||||
for cnode in condition_nodes:
|
||||
cnode_id = cnode.get("id")
|
||||
cnode_title = cnode.get("title", cnode_id)
|
||||
outgoing = [e for e in edges if e.get("source") == cnode_id]
|
||||
if len(outgoing) < 2:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
|
||||
)
|
||||
node_data = cnode.get("data", {}) or {}
|
||||
cases = node_data.get("cases", [])
|
||||
if not isinstance(cases, list):
|
||||
cases = []
|
||||
if not cases or not any(
|
||||
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
|
||||
):
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' must have at least one case with an expression"
|
||||
)
|
||||
|
||||
case_handles: Set[str] = set()
|
||||
duplicate_case_handles: Set[str] = set()
|
||||
for case in cases:
|
||||
if not isinstance(case, dict):
|
||||
continue
|
||||
raw_handle = case.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
if not handle:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has a case without a branch handle"
|
||||
)
|
||||
continue
|
||||
if handle in case_handles:
|
||||
duplicate_case_handles.add(handle)
|
||||
case_handles.add(handle)
|
||||
|
||||
for handle in duplicate_case_handles:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
|
||||
)
|
||||
|
||||
outgoing_by_handle: Dict[str, List[Dict]] = {}
|
||||
for out_edge in outgoing:
|
||||
raw_handle = out_edge.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
outgoing_by_handle.setdefault(handle, []).append(out_edge)
|
||||
|
||||
for handle, handle_edges in outgoing_by_handle.items():
|
||||
if not handle:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
|
||||
)
|
||||
continue
|
||||
if handle != "else" and handle not in case_handles:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
|
||||
)
|
||||
if len(handle_edges) > 1:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
|
||||
)
|
||||
|
||||
if "else" not in outgoing_by_handle:
|
||||
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
|
||||
|
||||
for case in cases:
|
||||
if not isinstance(case, dict):
|
||||
continue
|
||||
raw_handle = case.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
if not handle:
|
||||
continue
|
||||
|
||||
raw_expression = case.get("expression", "")
|
||||
has_expression = isinstance(raw_expression, str) and bool(
|
||||
raw_expression.strip()
|
||||
)
|
||||
has_outgoing = bool(outgoing_by_handle.get(handle))
|
||||
if has_expression and not has_outgoing:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
|
||||
)
|
||||
if not has_expression and has_outgoing:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
|
||||
)
|
||||
|
||||
for handle, handle_edges in outgoing_by_handle.items():
|
||||
if not handle:
|
||||
continue
|
||||
for out_edge in handle_edges:
|
||||
target = out_edge.get("target")
|
||||
if target and not _can_reach_end(target, edges, node_map, end_ids):
|
||||
errors.append(
|
||||
f"Branch '{handle}' of condition '{cnode_title}' "
|
||||
f"must eventually reach an end node"
|
||||
)
|
||||
|
||||
for node in nodes:
|
||||
if not node.get("id"):
|
||||
errors.append("All nodes must have an id")
|
||||
@@ -124,6 +225,20 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
return errors
|
||||
|
||||
|
||||
def _can_reach_end(
|
||||
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
|
||||
) -> bool:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if node_id in end_ids:
|
||||
return True
|
||||
if node_id in visited or node_id not in node_map:
|
||||
return False
|
||||
visited.add(node_id)
|
||||
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
|
||||
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
|
||||
|
||||
|
||||
def create_workflow_nodes(
|
||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
anthropic==0.75.0
|
||||
boto3==1.42.17
|
||||
beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
celery==5.6.0
|
||||
cryptography==46.0.3
|
||||
dataclasses-json==0.6.7
|
||||
|
||||
Reference in New Issue
Block a user