mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-02-22 04:11:41 +00:00
* feat: add condition node functionality with CEL evaluation - Introduced ConditionNode to support conditional branching in workflows. - Implemented CEL evaluation for state updates and condition expressions. - Updated WorkflowEngine to handle condition nodes and their execution logic. - Enhanced validation for workflows to ensure condition nodes have at least two outgoing edges and valid expressions. - Modified frontend components to support new condition node type and its configuration. - Added necessary types and interfaces for condition cases and state operations. - Updated requirements to include cel-python for expression evaluation. * mini-fixes * feat(workflow): improve UX --------- Co-authored-by: Alex <a@tushynski.me>
278 lines
10 KiB
Python
278 lines
10 KiB
Python
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING

from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
    AgentNodeConfig,
    ConditionNodeConfig,
    ExecutionStatus,
    NodeExecutionLog,
    NodeType,
    WorkflowGraph,
    WorkflowNode,
)

if TYPE_CHECKING:
    from application.agents.base import BaseAgent
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)

# A single workflow variable may hold any value.
StateValue = Any

# The workflow state maps variable names to their current values.
WorkflowState = Dict[str, StateValue]
|
|
|
|
|
|
class WorkflowEngine:
    """Execute a workflow graph node by node, streaming progress events.

    The engine starts at the graph's start node and follows outgoing edges
    until it reaches an end node, runs out of edges, hits an error, or
    executes ``MAX_EXECUTION_STEPS`` nodes. All nodes share one mutable
    ``state`` mapping; every executed node is recorded in ``execution_log``.
    """

    # Hard cap on the number of nodes executed in one run, so a workflow
    # that never reaches an end node (e.g. a cycle) still terminates.
    MAX_EXECUTION_STEPS = 50

    def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
        """Bind the engine to a workflow ``graph`` and the parent ``agent``.

        Args:
            graph: Workflow definition queried for nodes and edges.
            agent: Parent agent whose chat history, endpoint, model settings
                and credentials are used as defaults for agent nodes.
        """
        self.graph = graph
        self.agent = agent
        self.state: WorkflowState = {}
        self.execution_log: List[Dict[str, Any]] = []
        # Source handle picked by the most recently executed condition node.
        self._condition_result: Optional[str] = None
        # True once any agent node has streamed its answer to the user.
        self._has_streamed = False
        # node id -> state key that node actually wrote its output to.
        self._output_keys: Dict[str, str] = {}

    def execute(
        self, initial_inputs: WorkflowState, query: str
    ) -> Generator[Dict[str, Any], None, None]:
        """Run the workflow and yield events as it progresses.

        Yields ``workflow_step`` events (status ``running``, then
        ``completed`` or ``failed``) around every node, any events produced
        by the node itself, and an ``error`` event when the run aborts.

        Args:
            initial_inputs: Variables copied into the state before the run.
            query: User query, stored in the state under ``"query"``.
        """
        self._initialize_state(initial_inputs, query)

        start_node = self.graph.get_start_node()
        if not start_node:
            yield {"type": "error", "error": "No start node found in workflow."}
            return

        current_node_id: Optional[str] = start_node.id
        steps = 0

        while current_node_id and steps < self.MAX_EXECUTION_STEPS:
            node = self.graph.get_node_by_id(current_node_id)
            if not node:
                yield {"type": "error", "error": f"Node {current_node_id} not found."}
                break

            log_entry = self._create_log_entry(node)
            yield self._step_event(node, "running")

            try:
                yield from self._execute_node(node)
            except Exception as e:
                # Boundary handler: any node failure aborts the run and is
                # reported to the consumer instead of propagating.
                logger.exception("Error executing node %s: %s", node.id, e)
                log_entry["status"] = ExecutionStatus.FAILED.value
                log_entry["error"] = str(e)
                log_entry["completed_at"] = datetime.now(timezone.utc)
                log_entry["state_snapshot"] = dict(self.state)
                self.execution_log.append(log_entry)

                yield self._step_event(
                    node,
                    "failed",
                    state_snapshot=dict(self.state),
                    error=str(e),
                )
                yield {"type": "error", "error": str(e)}
                break

            log_entry["status"] = ExecutionStatus.COMPLETED.value
            log_entry["completed_at"] = datetime.now(timezone.utc)

            # Agent nodes may store their output under a custom variable;
            # fall back to the default key for every other node.
            output_key = self._output_keys.get(node.id, f"node_{node.id}_output")
            yield self._step_event(
                node,
                "completed",
                state_snapshot=dict(self.state),
                output=self.state.get(output_key),
            )

            log_entry["state_snapshot"] = dict(self.state)
            self.execution_log.append(log_entry)

            if node.type == NodeType.END:
                break

            current_node_id = self._get_next_node_id(current_node_id)
            if current_node_id is None:
                logger.warning(
                    "Branch ended at node '%s' (%s) without reaching an end node",
                    node.title,
                    node.id,
                )
            steps += 1

        # Only warn when a node was still pending, i.e. the cap really cut
        # the run short rather than the workflow ending on the last step.
        if current_node_id and steps >= self.MAX_EXECUTION_STEPS:
            logger.warning(
                "Workflow reached max steps limit (%s)", self.MAX_EXECUTION_STEPS
            )

    @staticmethod
    def _step_event(node: WorkflowNode, status: str, **extra: Any) -> Dict[str, Any]:
        """Build a ``workflow_step`` event for ``node`` with the given status."""
        return {
            "type": "workflow_step",
            "node_id": node.id,
            "node_type": node.type.value,
            "node_title": node.title,
            "status": status,
            **extra,
        }

    def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
        """Seed the state with the inputs, the query and the chat history.

        ``initial_inputs`` may be ``None`` or empty. ``"query"`` and
        ``"chat_history"`` are written last, so they override same-named
        inputs.
        """
        self.state.update(initial_inputs or {})
        self.state["query"] = query
        self.state["chat_history"] = str(self.agent.chat_history)

    def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
        """Return a fresh log record for ``node`` in the ``RUNNING`` state."""
        return {
            "node_id": node.id,
            "node_type": node.type.value,
            "started_at": datetime.now(timezone.utc),
            "completed_at": None,
            "status": ExecutionStatus.RUNNING.value,
            "error": None,
            "state_snapshot": {},
        }

    def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
        """Return the id of the node to run after ``current_node_id``.

        For a condition node with a pending result, the edge whose
        ``source_handle`` equals that result is followed (``None`` if no
        such edge exists) and the pending result is cleared. For every
        other node the first outgoing edge is followed. Returns ``None``
        when the node has no outgoing edges.
        """
        node = self.graph.get_node_by_id(current_node_id)
        edges = self.graph.get_outgoing_edges(current_node_id)
        if not edges:
            return None

        if (
            node is not None
            and node.type == NodeType.CONDITION
            and self._condition_result
        ):
            # Consume the result so it cannot leak into a later node.
            target_handle, self._condition_result = self._condition_result, None
            for edge in edges:
                if edge.source_handle == target_handle:
                    return edge.target_id
            return None

        return edges[0].target_id

    def _execute_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Dispatch ``node`` to the handler registered for its type.

        Node types without a handler are skipped with a warning.
        """
        logger.info("Executing node %s (%s)", node.id, node.type.value)

        node_handlers = {
            NodeType.START: self._execute_start_node,
            NodeType.NOTE: self._execute_note_node,
            NodeType.AGENT: self._execute_agent_node,
            NodeType.STATE: self._execute_state_node,
            NodeType.CONDITION: self._execute_condition_node,
            NodeType.END: self._execute_end_node,
        }

        handler = node_handlers.get(node.type)
        if handler is None:
            logger.warning(
                "No handler for node type %s (node %s); skipping",
                node.type,
                node.id,
            )
            return
        yield from handler(node)

    def _execute_start_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Start nodes do no work and emit no events."""
        yield from ()

    def _execute_note_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Note nodes do no work and emit no events."""
        yield from ()

    def _execute_agent_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Run an agent for ``node`` and store its answer in the state.

        The prompt is the node's ``prompt_template`` rendered against the
        state, or the state's ``"query"`` when no template is configured.
        Only events containing ``"answer"`` are considered; they are
        forwarded to the caller when ``stream_to_user`` is set. The stripped,
        concatenated answer is stored under ``output_variable`` or, if that
        is unset, ``node_<id>_output``.
        """
        # Imported here rather than at module level, as in the original
        # code (presumably to avoid an import cycle - TODO confirm).
        from application.core.model_utils import get_api_key_for_provider

        node_config = AgentNodeConfig(**node.config.get("config", node.config))

        if node_config.prompt_template:
            formatted_prompt = self._format_template(node_config.prompt_template)
        else:
            formatted_prompt = self.state.get("query", "")

        node_llm_name = node_config.llm_name or self.agent.llm_name
        node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key

        node_agent = WorkflowNodeAgentFactory.create(
            agent_type=node_config.agent_type,
            endpoint=self.agent.endpoint,
            llm_name=node_llm_name,
            model_id=node_config.model_id or self.agent.model_id,
            api_key=node_api_key,
            tool_ids=node_config.tools,
            prompt=node_config.system_prompt,
            chat_history=self.agent.chat_history,
            decoded_token=self.agent.decoded_token,
            json_schema=node_config.json_schema,
        )

        answer_parts: List[str] = []
        # A blank-line separator precedes this node's first streamed chunk
        # when an earlier node already streamed to the user.
        separator_pending = self._has_streamed
        for event in node_agent.gen(formatted_prompt):
            if "answer" not in event:
                continue
            answer_parts.append(event["answer"])
            if node_config.stream_to_user:
                if separator_pending:
                    yield {"answer": "\n\n"}
                    separator_pending = False
                yield event

        if node_config.stream_to_user:
            self._has_streamed = True

        output_key = node_config.output_variable or f"node_{node.id}_output"
        self.state[output_key] = "".join(answer_parts).strip()
        # Remember where the output went so ``execute`` can report it.
        self._output_keys[node.id] = output_key

    def _execute_state_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Apply the node's state operations in order.

        Each operation with both an ``expression`` and a ``target_variable``
        evaluates the CEL expression against the current state and assigns
        the result; later operations see earlier assignments. Evaluation
        errors propagate to the caller.
        """
        config = node.config.get("config", node.config)
        for op in config.get("operations", []):
            expression = op.get("expression", "")
            target_variable = op.get("target_variable", "")
            if expression and target_variable:
                self.state[target_variable] = evaluate_cel(expression, self.state)
        yield from ()

    def _execute_condition_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Pick the branch of a condition node.

        Cases are tried in order; the first whose expression evaluates
        truthy wins. Blank expressions and expressions that fail to evaluate
        are skipped. When no case matches, the ``"else"`` handle is chosen.
        The choice is stored for ``_get_next_node_id`` to consume.
        """
        config = ConditionNodeConfig(**node.config.get("config", node.config))
        matched_handle = None

        for case in config.cases:
            if not case.expression.strip():
                continue
            try:
                matched = evaluate_cel(case.expression, self.state)
            except CelEvaluationError as exc:
                # Best effort: a broken case must not abort the workflow,
                # but it should not vanish without a trace either.
                logger.warning(
                    "Skipping condition case %r on node %s: %s",
                    case.expression,
                    node.id,
                    exc,
                )
                continue
            if matched:
                matched_handle = case.source_handle
                break

        self._condition_result = matched_handle or "else"
        yield from ()

    def _execute_end_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Emit the rendered ``output_template`` as the final answer, if any.

        A missing, ``None`` or empty template produces no event.
        """
        config = node.config.get("config", node.config)
        raw_template = config.get("output_template")
        # ``str(None)`` would yield the literal answer "None".
        output_template = "" if raw_template is None else str(raw_template)
        if output_template:
            yield {"answer": self._format_template(output_template)}

    def _format_template(self, template: str) -> str:
        """Replace ``{{name}}`` placeholders with values from the state.

        Placeholders whose variable is missing or ``None`` are left as-is.
        Substitution is done in a single pass, so placeholder-like text
        inside a substituted value is never expanded again.
        """
        replacements = {
            f"{{{{{key}}}}}": str(value)
            for key, value in self.state.items()
            if value is not None
        }
        if not replacements:
            return template

        # Longest placeholders first so the alternation is deterministic.
        pattern = re.compile(
            "|".join(
                re.escape(placeholder)
                for placeholder in sorted(replacements, key=len, reverse=True)
            )
        )
        return pattern.sub(lambda match: replacements[match.group(0)], template)

    def get_execution_summary(self) -> List[NodeExecutionLog]:
        """Return the execution log as ``NodeExecutionLog`` records."""
        return [
            NodeExecutionLog(
                node_id=log["node_id"],
                node_type=log["node_type"],
                status=ExecutionStatus(log["status"]),
                started_at=log["started_at"],
                completed_at=log.get("completed_at"),
                error=log.get("error"),
                state_snapshot=log.get("state_snapshot", {}),
            )
            for log in self.execution_log
        ]
|