Files
DocsGPT/application/agents/workflows/workflow_engine.py
2026-02-22 11:10:42 +00:00

454 lines
17 KiB
Python

import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
MAX_EXECUTION_STEPS = 50
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
self.graph = graph
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
) -> Generator[Dict[str, str], None, None]:
self._initialize_state(initial_inputs, query)
start_node = self.graph.get_start_node()
if not start_node:
yield {"type": "error", "error": "No start node found in workflow."}
return
current_node_id: Optional[str] = start_node.id
steps = 0
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
node = self.graph.get_node_by_id(current_node_id)
if not node:
yield {"type": "error", "error": f"Node {current_node_id} not found."}
break
log_entry = self._create_log_entry(node)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "running",
}
try:
yield from self._execute_node(node)
log_entry["status"] = ExecutionStatus.COMPLETED.value
log_entry["completed_at"] = datetime.now(timezone.utc)
output_key = f"node_{node.id}_output"
node_output = self.state.get(output_key)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "completed",
"state_snapshot": dict(self.state),
"output": node_output,
}
except Exception as e:
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
log_entry["status"] = ExecutionStatus.FAILED.value
log_entry["error"] = str(e)
log_entry["completed_at"] = datetime.now(timezone.utc)
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": str(e),
}
yield {"type": "error", "error": str(e)}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
if current_node_id is None and node.type != NodeType.END:
logger.warning(
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
)
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
self.state.update(initial_inputs)
self.state["query"] = query
self.state["chat_history"] = str(self.agent.chat_history)
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
return {
"node_id": node.id,
"node_type": node.type.value,
"started_at": datetime.now(timezone.utc),
"completed_at": None,
"status": ExecutionStatus.RUNNING.value,
"error": None,
"state_snapshot": {},
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
node = self.graph.get_node_by_id(current_node_id)
edges = self.graph.get_outgoing_edges(current_node_id)
if not edges:
return None
if node and node.type == NodeType.CONDITION and self._condition_result:
target_handle = self._condition_result
self._condition_result = None
for edge in edges:
if edge.source_handle == target_handle:
return edge.target_id
return None
return edges[0].target_id
def _execute_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
logger.info(f"Executing node {node.id} ({node.type.value})")
node_handlers = {
NodeType.START: self._execute_start_node,
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.CONDITION: self._execute_condition_node,
NodeType.END: self._execute_end_node,
}
handler = node_handlers.get(node.type)
if handler:
yield from handler(node)
def _execute_start_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_note_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_json_schema,
)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
first_chunk = False
yield event
if node_config.stream_to_user:
self._has_streamed = True
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
for op in config.get("operations", []):
expression = op.get("expression", "")
target_variable = op.get("target_variable", "")
if expression and target_variable:
self.state[target_variable] = evaluate_cel(expression, self.state)
yield from ()
def _execute_condition_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = ConditionNodeConfig(**node.config.get("config", node.config))
matched_handle = None
for case in config.cases:
if not case.expression.strip():
continue
try:
if evaluate_cel(case.expression, self.state):
matched_handle = case.source_handle
break
except CelEvaluationError:
continue
self._condition_result = matched_handle or "else"
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(
node_id=log["node_id"],
node_type=log["node_type"],
status=ExecutionStatus(log["status"]),
started_at=log["started_at"],
completed_at=log.get("completed_at"),
error=log.get("error"),
state_snapshot=log.get("state_snapshot", {}),
)
for log in self.execution_log
]