fix: mini workflow fixes

This commit is contained in:
Alex
2026-02-22 11:10:42 +00:00
parent 1a2104f474
commit a6625ec5de
14 changed files with 1261 additions and 136 deletions

View File

@@ -7,6 +7,10 @@ from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
@@ -63,7 +67,12 @@ class BaseAgent(ABC):
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
self.json_schema = None
if json_schema is not None:
try:
self.json_schema = normalize_json_schema_payload(json_schema)
except JsonSchemaValidationError as exc:
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode

View File

@@ -211,8 +211,21 @@ class WorkflowAgent(BaseAgent):
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
if isinstance(value, (str, int, float, bool, type(None))):
serialized[key] = value
else:
serialized[key] = str(value)
serialized[key] = self._serialize_state_value(value)
return serialized
def _serialize_state_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {
str(dict_key): self._serialize_state_value(dict_value)
for dict_key, dict_value in value.items()
}
if isinstance(value, list):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, tuple):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)

View File

@@ -1,3 +1,4 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
@@ -13,6 +14,17 @@ from application.agents.workflows.schemas import (
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
@@ -20,6 +32,7 @@ logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
@@ -31,6 +44,8 @@ class WorkflowEngine:
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
@@ -174,7 +189,11 @@ class WorkflowEngine:
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import get_api_key_for_provider
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
@@ -182,27 +201,50 @@ class WorkflowEngine:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_llm_name = node_config.llm_name or self.agent.llm_name
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_config.model_id or self.agent.model_id,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_config.json_schema,
json_schema=node_json_schema,
)
full_response = ""
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
full_response += event["answer"]
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
@@ -212,8 +254,33 @@ class WorkflowEngine:
if node_config.stream_to_user:
self._has_streamed = True
output_key = node_config.output_variable or f"node_{node.id}_output"
self.state[output_key] = full_response.strip()
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
@@ -254,13 +321,122 @@ class WorkflowEngine:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
formatted = template
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
placeholder = f"{{{{{key}}}}}"
if placeholder in formatted and value is not None:
formatted = formatted.replace(placeholder, str(value))
return formatted
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [