fix: mini workflow fixes

This commit is contained in:
Alex
2026-02-22 11:10:42 +00:00
parent 1a2104f474
commit a6625ec5de
14 changed files with 1261 additions and 136 deletions

View File

@@ -7,6 +7,10 @@ from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
@@ -63,7 +67,12 @@ class BaseAgent(ABC):
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
self.json_schema = None
if json_schema is not None:
try:
self.json_schema = normalize_json_schema_payload(json_schema)
except JsonSchemaValidationError as exc:
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode

View File

@@ -211,8 +211,21 @@ class WorkflowAgent(BaseAgent):
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
if isinstance(value, (str, int, float, bool, type(None))):
serialized[key] = value
else:
serialized[key] = str(value)
serialized[key] = self._serialize_state_value(value)
return serialized
def _serialize_state_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {
str(dict_key): self._serialize_state_value(dict_value)
for dict_key, dict_value in value.items()
}
if isinstance(value, list):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, tuple):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)

View File

@@ -1,3 +1,4 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
@@ -13,6 +14,17 @@ from application.agents.workflows.schemas import (
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
@@ -20,6 +32,7 @@ logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
@@ -31,6 +44,8 @@ class WorkflowEngine:
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
@@ -174,7 +189,11 @@ class WorkflowEngine:
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import get_api_key_for_provider
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
@@ -182,27 +201,50 @@ class WorkflowEngine:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_llm_name = node_config.llm_name or self.agent.llm_name
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_config.model_id or self.agent.model_id,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_config.json_schema,
json_schema=node_json_schema,
)
full_response = ""
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
full_response += event["answer"]
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
@@ -212,8 +254,33 @@ class WorkflowEngine:
if node_config.stream_to_user:
self._has_streamed = True
output_key = node_config.output_variable or f"node_{node.id}_output"
self.state[output_key] = full_response.strip()
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
@@ -254,13 +321,122 @@ class WorkflowEngine:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
formatted = template
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
placeholder = f"{{{{{key}}}}}"
if placeholder in formatted and value is not None:
formatted = formatted.replace(placeholder, str(value))
return formatted
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [

View File

@@ -23,6 +23,10 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.utils import (
check_required_fields,
@@ -479,41 +483,15 @@ class CreateAgent(Resource):
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
# Validate and normalize JSON schema if provided
if "json_schema" in data:
try:
# Basic validation - ensure it's a valid JSON structure
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
current_app.logger.error(f"Invalid JSON schema: {e}")
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{"success": False, "message": "Invalid JSON schema format"}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -732,6 +710,8 @@ class UpdateAgent(Resource):
),
400,
)
if data.get("json_schema") == "":
data["json_schema"] = None
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
@@ -892,17 +872,15 @@ class UpdateAgent(Resource):
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
try:
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
elif field == "limited_token_mode":

View File

@@ -1,7 +1,7 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Dict, List, Set
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
@@ -11,6 +11,11 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
check_resource_ownership,
error_response,
@@ -85,6 +90,50 @@ def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> L
return docs
def validate_json_schema_payload(
json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
"""Validate and normalize optional JSON schema payload for structured output."""
if json_schema is None:
return None, None
try:
return normalize_json_schema_payload(json_schema), None
except JsonSchemaValidationError as exc:
return None, str(exc)
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
"""Normalize agent-node JSON schema payloads before persistence."""
normalized_nodes: List[Dict] = []
for node in nodes:
if not isinstance(node, dict):
normalized_nodes.append(node)
continue
normalized_node = dict(node)
if normalized_node.get("type") != "agent":
normalized_nodes.append(normalized_node)
continue
raw_config = normalized_node.get("data")
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
normalized_nodes.append(normalized_node)
continue
normalized_config = dict(raw_config)
try:
normalized_config["json_schema"] = normalize_json_schema_payload(
raw_config.get("json_schema")
)
except JsonSchemaValidationError:
# Validation runs before normalization; keep original on unexpected shape.
normalized_config["json_schema"] = raw_config.get("json_schema")
normalized_node["data"] = normalized_config
normalized_nodes.append(normalized_node)
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
@@ -216,6 +265,28 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
f"must eventually reach an end node"
)
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
for agent_node in agent_nodes:
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
raw_config = agent_node.get("data", {}) or {}
if not isinstance(raw_config, dict):
errors.append(f"Agent node '{agent_title}' has invalid configuration")
continue
normalized_schema, schema_error = validate_json_schema_payload(
raw_config.get("json_schema")
)
has_json_schema = normalized_schema is not None
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
)
if schema_error:
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
@@ -301,6 +372,7 @@ class WorkflowList(Resource):
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
now = datetime.now(timezone.utc)
workflow_doc = {
@@ -391,6 +463,7 @@ class WorkflowDetail(Resource):
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1

View File

@@ -0,0 +1,34 @@
from typing import Any, Dict, Optional
class JsonSchemaValidationError(ValueError):
"""Raised when a JSON schema payload is invalid."""
def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
"""
Normalize accepted JSON schema payload shapes to a plain schema object.
Accepted inputs:
- None
- A raw schema object with a top-level "type"
- A wrapped payload with a top-level "schema" object
"""
if json_schema is None:
return None
if not isinstance(json_schema, dict):
raise JsonSchemaValidationError("must be a valid JSON object")
wrapped_schema = json_schema.get("schema")
if wrapped_schema is not None:
if not isinstance(wrapped_schema, dict):
raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
return wrapped_schema
if "type" not in json_schema:
raise JsonSchemaValidationError(
'must include either a "type" or "schema" field'
)
return json_schema

View File

@@ -439,10 +439,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
const data = await response.json();
const transformed = modelService.transformModels(data.models || []);
setAvailableModels(transformed);
if (mode === 'new' && transformed.length > 0) {
const preferredDefaultModelId =
transformed.find((model) => model.id === data.default_model_id)?.id ||
transformed[0].id;
if (preferredDefaultModelId) {
setSelectedModelIds((prevSelectedModelIds) =>
prevSelectedModelIds.size > 0
? prevSelectedModelIds
: new Set([preferredDefaultModelId]),
);
}
}
};
getTools();
getModels();
}, [token]);
}, [token, mode]);
// Validate folder_id from URL against user's folders
useEffect(() => {

View File

@@ -94,6 +94,20 @@ interface UserTool {
displayName: string;
}
function validateJsonSchemaConfig(schema: unknown): string | null {
if (schema === undefined || schema === null) return null;
if (typeof schema !== 'object' || Array.isArray(schema)) {
return 'must be a valid JSON object';
}
const schemaObject = schema as Record<string, unknown>;
if (!('schema' in schemaObject) && !('type' in schemaObject)) {
return 'must include either a "type" or "schema" field';
}
return null;
}
function createEmptyWorkflowAgent(): Agent {
return {
id: '',
@@ -130,17 +144,46 @@ function parseSimpleCel(expression: string): {
operator: string;
value: string;
} {
let match = expression.match(
const trimmedExpression = expression.trim();
let match = trimmedExpression.match(
/^(\w+)\.(contains|startsWith)\(["'](.*)["']\)$/,
);
if (match) return { variable: match[1], operator: match[2], value: match[3] };
match = expression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*["'](.*)["']$/);
match = trimmedExpression.match(/^(\w+)\.(contains|startsWith)\((.*)\)$/);
if (match) {
const rawValue = match[3].trim();
const unquotedValue = rawValue.replace(/^["'](.*)["']$/, '$1');
return {
variable: match[1],
operator: match[2],
value: unquotedValue,
};
}
match = trimmedExpression.match(/^(contains|startsWith)\(["'](.*)["']\)$/);
if (match) return { variable: '', operator: match[1], value: match[2] };
match = trimmedExpression.match(/^(contains|startsWith)\((.*)\)$/);
if (match) {
const rawValue = match[2].trim();
const unquotedValue = rawValue.replace(/^["'](.*)["']$/, '$1');
return { variable: '', operator: match[1], value: unquotedValue };
}
match = trimmedExpression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*["'](.*)["']$/);
if (match) return { variable: match[1], operator: match[2], value: match[3] };
match = expression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*(.*)$/);
match = trimmedExpression.match(/^(==|!=|>=|<=|>|<)\s*["'](.*)["']$/);
if (match) return { variable: '', operator: match[1], value: match[2] };
match = trimmedExpression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*(.*)$/);
if (match) return { variable: match[1], operator: match[2], value: match[3] };
match = trimmedExpression.match(/^(==|!=|>=|<=|>|<)\s*(.*)$/);
if (match) return { variable: '', operator: match[1], value: match[2] };
return { variable: '', operator: '==', value: '' };
}
@@ -149,13 +192,24 @@ function buildSimpleCel(
operator: string,
value: string,
): string {
if (!variable) return '';
const isNumeric = value !== '' && !isNaN(Number(value));
const isBool = value === 'true' || value === 'false';
const quoted = isNumeric || isBool ? value : `"${value}"`;
if (operator === 'contains') return `${variable}.contains(${quoted})`;
if (operator === 'startsWith') return `${variable}.startsWith(${quoted})`;
return `${variable} ${operator} ${quoted}`;
const trimmedValue = value.trim();
const isNumeric = trimmedValue !== '' && !isNaN(Number(trimmedValue));
const isBool = trimmedValue === 'true' || trimmedValue === 'false';
const literalValue =
isNumeric || isBool ? trimmedValue : JSON.stringify(value);
const stringValue = JSON.stringify(value);
if (operator === 'contains') {
return variable
? `${variable}.contains(${stringValue})`
: `contains(${stringValue})`;
}
if (operator === 'startsWith') {
return variable
? `${variable}.startsWith(${stringValue})`
: `startsWith(${stringValue})`;
}
if (!variable) return `${operator} ${literalValue}`;
return `${variable} ${operator} ${literalValue}`;
}
function normalizeConditionCases(cases: ConditionCase[]): ConditionCase[] {
@@ -283,7 +337,14 @@ function WorkflowBuilderInner() {
>(null);
const workflowSettingsRef = useRef<HTMLDivElement>(null);
const [availableModels, setAvailableModels] = useState<Model[]>([]);
const [defaultAgentModelId, setDefaultAgentModelId] = useState('');
const [availableTools, setAvailableTools] = useState<UserTool[]>([]);
const [agentJsonSchemaDrafts, setAgentJsonSchemaDrafts] = useState<
Record<string, string>
>({});
const [agentJsonSchemaErrors, setAgentJsonSchemaErrors] = useState<
Record<string, string | null>
>({});
const nodeTypes = useMemo<NodeTypes>(
() => ({
@@ -404,8 +465,14 @@ function WorkflowBuilderInner() {
};
if (type === 'agent') {
const defaultModelId = defaultAgentModelId || availableModels[0]?.id;
const defaultModelProvider = availableModels.find(
(model) => model.id === defaultModelId,
)?.provider;
baseNode.data.config = {
agent_type: 'classic',
model_id: defaultModelId,
llm_name: defaultModelProvider || '',
system_prompt: 'You are a helpful assistant.',
prompt_template: '',
stream_to_user: true,
@@ -430,7 +497,7 @@ function WorkflowBuilderInner() {
setNodes((nds) => nds.concat(baseNode));
},
[reactFlowInstance],
[reactFlowInstance, availableModels, defaultAgentModelId],
);
const handleNodeClick = useCallback(
@@ -449,6 +516,18 @@ function WorkflowBuilderInner() {
(e) => e.source !== selectedNode.id && e.target !== selectedNode.id,
),
);
setAgentJsonSchemaDrafts((prev) => {
if (!(selectedNode.id in prev)) return prev;
const next = { ...prev };
delete next[selectedNode.id];
return next;
});
setAgentJsonSchemaErrors((prev) => {
if (!(selectedNode.id in prev)) return prev;
const next = { ...prev };
delete next[selectedNode.id];
return next;
});
setSelectedNode(null);
setShowNodeConfig(false);
}, [selectedNode]);
@@ -468,6 +547,49 @@ function WorkflowBuilderInner() {
[selectedNode],
);
const handleAgentJsonSchemaChange = useCallback(
(text: string) => {
if (!selectedNode || selectedNode.type !== 'agent') return;
const nodeId = selectedNode.id;
setAgentJsonSchemaDrafts((prev) => ({ ...prev, [nodeId]: text }));
if (text.trim() === '') {
setAgentJsonSchemaErrors((prev) => ({ ...prev, [nodeId]: null }));
handleUpdateNodeData({
config: {
...(selectedNode.data.config || {}),
json_schema: undefined,
},
});
return;
}
try {
const parsed = JSON.parse(text);
const validationError = validateJsonSchemaConfig(parsed);
setAgentJsonSchemaErrors((prev) => ({
...prev,
[nodeId]: validationError,
}));
if (!validationError) {
handleUpdateNodeData({
config: {
...(selectedNode.data.config || {}),
json_schema: parsed,
},
});
}
} catch {
setAgentJsonSchemaErrors((prev) => ({
...prev,
[nodeId]: 'must be valid JSON',
}));
}
},
[handleUpdateNodeData, selectedNode],
);
const handleUpload = useCallback((files: File[]) => {
if (files && files.length > 0) {
setImageFile(files[0]);
@@ -564,7 +686,17 @@ function WorkflowBuilderInner() {
const modelsResponse = await modelService.getModels(null);
if (modelsResponse.ok) {
const modelsData = await modelsResponse.json();
setAvailableModels(modelService.transformModels(modelsData.models));
const transformedModels = modelService.transformModels(
modelsData.models || [],
);
setAvailableModels(transformedModels);
const preferredDefaultModel =
transformedModels.find(
(model) => model.id === modelsData.default_model_id,
)?.id ||
transformedModels[0]?.id ||
'';
setDefaultAgentModelId(preferredDefaultModel);
}
const toolsResponse = await userService.getUserTools(null);
@@ -579,6 +711,51 @@ function WorkflowBuilderInner() {
loadModelsAndTools();
}, []);
useEffect(() => {
if (!selectedNode || selectedNode.type !== 'agent') return;
if (!defaultAgentModelId) return;
if (selectedNode.data.config?.model_id) return;
handleUpdateNodeData({
config: {
...(selectedNode.data.config || {}),
model_id: defaultAgentModelId,
llm_name:
availableModels.find((model) => model.id === defaultAgentModelId)
?.provider || '',
},
});
}, [
selectedNode,
defaultAgentModelId,
availableModels,
handleUpdateNodeData,
]);
useEffect(() => {
if (!selectedNode || selectedNode.type !== 'agent') return;
const nodeId = selectedNode.id;
const rawSchema = selectedNode.data.config?.json_schema;
setAgentJsonSchemaDrafts((prev) => {
if (prev[nodeId] !== undefined) return prev;
if (rawSchema === undefined || rawSchema === null) {
return { ...prev, [nodeId]: '' };
}
try {
return { ...prev, [nodeId]: JSON.stringify(rawSchema, null, 2) };
} catch {
return { ...prev, [nodeId]: String(rawSchema) };
}
});
setAgentJsonSchemaErrors((prev) => {
if (prev[nodeId] !== undefined) return prev;
return { ...prev, [nodeId]: validateJsonSchemaConfig(rawSchema) };
});
}, [selectedNode]);
useEffect(() => {
const loadAgentDetails = async () => {
if (!agentId) return;
@@ -655,6 +832,8 @@ function WorkflowBuilderInner() {
);
setWorkflowName(nextWorkflowName);
setWorkflowDescription(nextWorkflowDescription);
setAgentJsonSchemaDrafts({});
setAgentJsonSchemaErrors({});
setNodes(mappedNodes);
setEdges(mappedEdges);
setSavedWorkflowSignature(
@@ -711,6 +890,33 @@ function WorkflowBuilderInner() {
`Agent "${node.data?.title || node.id}" must have a model selected`,
);
}
const hasSchema =
config?.json_schema !== undefined && config?.json_schema !== null;
if (hasSchema && config?.model_id) {
const selectedModel = availableModels.find(
(model) => model.id === config.model_id,
);
if (selectedModel && !selectedModel.supports_structured_output) {
errors.push(
`Agent "${node.data?.title || node.id}" selected model does not support structured output`,
);
}
}
const schemaValidationError = validateJsonSchemaConfig(
config?.json_schema,
);
const draftSchemaError = agentJsonSchemaErrors[node.id];
const effectiveSchemaError =
draftSchemaError !== undefined
? draftSchemaError
: schemaValidationError;
if (effectiveSchemaError) {
errors.push(
`Agent "${node.data?.title || node.id}" JSON schema ${effectiveSchemaError}`,
);
}
});
if (startNodes.length === 1) {
@@ -743,6 +949,7 @@ function WorkflowBuilderInner() {
const conditionNodes = nodes.filter((n) => n.type === 'condition');
conditionNodes.forEach((node) => {
const conditionTitle = node.data?.title || node.id;
const conditionMode = node.data?.config?.mode || 'simple';
const cases = (node.data?.config?.cases || []) as ConditionCase[];
if (
!cases.length ||
@@ -831,6 +1038,16 @@ function WorkflowBuilderInner() {
`Condition "${conditionTitle}" case "${handle}" has a branch connection but no expression`,
);
}
if (conditionMode === 'simple' && hasExpression) {
const parsedCondition = parseSimpleCel(
conditionCase.expression || '',
);
if (!parsedCondition.variable.trim()) {
errors.push(
`Condition "${conditionTitle}" case "${handle}" must specify a variable in Simple mode`,
);
}
}
});
outgoing.forEach((edge) => {
@@ -844,7 +1061,7 @@ function WorkflowBuilderInner() {
});
return errors;
}, [workflowName, nodes, edges]);
}, [workflowName, nodes, edges, agentJsonSchemaErrors, availableModels]);
const canManageAgent = Boolean(currentAgentId || currentAgent.id);
const effectiveAgentId = currentAgentId || currentAgent.id || '';
@@ -1077,6 +1294,42 @@ function WorkflowBuilderInner() {
],
);
const selectedAgentJsonSchemaText = useMemo(() => {
if (!selectedNode || selectedNode.type !== 'agent') return '';
const draft = agentJsonSchemaDrafts[selectedNode.id];
if (draft !== undefined) return draft;
const schema = selectedNode.data.config?.json_schema;
if (schema === undefined || schema === null) return '';
try {
return JSON.stringify(schema, null, 2);
} catch {
return String(schema);
}
}, [selectedNode, agentJsonSchemaDrafts]);
const selectedAgentJsonSchemaError = useMemo(() => {
if (!selectedNode || selectedNode.type !== 'agent') return null;
const cachedError = agentJsonSchemaErrors[selectedNode.id];
if (cachedError !== undefined) return cachedError;
return validateJsonSchemaConfig(selectedNode.data.config?.json_schema);
}, [selectedNode, agentJsonSchemaErrors]);
const selectedAgentModelSupportsStructuredOutput = useMemo(() => {
if (!selectedNode || selectedNode.type !== 'agent') return true;
const modelId = selectedNode.data.config?.model_id;
if (!modelId) return true;
const selectedModel = availableModels.find((model) => model.id === modelId);
if (!selectedModel) return true;
return selectedModel.supports_structured_output;
}, [selectedNode, availableModels]);
return (
<>
<MobileBlocker />
@@ -1577,7 +1830,7 @@ function WorkflowBuilderInner() {
nodes={nodes}
edges={edges}
selectedNodeId={selectedNode.id}
placeholder="Use {{variable}} for dynamic content"
placeholder="Use {{ agent.variable }} for dynamic content"
/>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
@@ -1589,14 +1842,15 @@ function WorkflowBuilderInner() {
selectedNode.data.config
?.output_variable || ''
}
onChange={(e) =>
onChange={(e) => {
const nextOutputVariable = e.target.value;
handleUpdateNodeData({
config: {
...(selectedNode.data.config || {}),
output_variable: e.target.value,
output_variable: nextOutputVariable,
},
})
}
});
}}
className="border-light-silver focus:ring-purple-30 w-full rounded-xl border bg-white px-3 py-2 text-sm transition-all outline-none focus:ring-2 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:text-white"
placeholder="Variable name for output"
/>
@@ -1651,6 +1905,48 @@ function WorkflowBuilderInner() {
emptyText="No tools available"
/>
</div>
<div>
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
Structured Output (JSON Schema)
</label>
{!selectedAgentModelSupportsStructuredOutput && (
<p className="mb-2 text-xs text-red-600 dark:text-red-400">
Selected model does not support structured
output.
</p>
)}
<textarea
value={selectedAgentJsonSchemaText}
onChange={(e) =>
handleAgentJsonSchemaChange(
e.target.value,
)
}
className="border-light-silver focus:ring-purple-30 w-full rounded-xl border bg-white px-3 py-2 font-mono text-xs transition-all outline-none focus:ring-2 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:text-white"
rows={8}
placeholder={`{
"type": "object",
"properties": {
"summary": { "type": "string" }
},
"required": ["summary"]
}`}
/>
{selectedAgentJsonSchemaText.trim() !==
'' && (
<p
className={`mt-2 text-xs ${
selectedAgentJsonSchemaError
? 'text-red-600 dark:text-red-400'
: 'text-green-600 dark:text-green-400'
}`}
>
{selectedAgentJsonSchemaError
? `Invalid JSON schema: ${selectedAgentJsonSchemaError}`
: 'Valid JSON schema'}
</p>
)}
</div>
</>
)}

View File

@@ -87,7 +87,9 @@ function ExecutionDetails({
const formatValue = (value: unknown): string => {
if (typeof value === 'string') return value;
return JSON.stringify(value, null, 2);
if (value === undefined) return '';
const formatted = JSON.stringify(value, null, 2);
return formatted ?? String(value);
};
return (
@@ -136,6 +138,11 @@ function ExecutionDetails({
if (text.length <= maxLength) return text;
return text.slice(0, maxLength) + '...';
};
const hasOutput =
step.output !== undefined &&
step.output !== null &&
formatValue(step.output) !== '';
const formattedOutput = hasOutput ? formatValue(step.output) : '';
return (
<div
@@ -171,15 +178,15 @@ function ExecutionDetails({
)}
</div>
</div>
{(step.output || step.error || stateVars.length > 0) && (
{(hasOutput || step.error || stateVars.length > 0) && (
<div className="mt-3 space-y-2 text-sm">
{step.output && (
{hasOutput && (
<div className="rounded-lg bg-white p-2 dark:bg-[#2A2A2A]">
<span className="font-medium text-gray-600 dark:text-gray-400">
Output:{' '}
</span>
<span className="wrap-break-word whitespace-pre-wrap text-gray-900 dark:text-gray-100">
{truncateText(step.output, 300)}
{truncateText(formattedOutput, 300)}
</span>
</div>
)}
@@ -254,7 +261,7 @@ function WorkflowMiniMap({
return step?.status || 'pending';
};
const getStatusColor = (nodeId: string, nodeType: string) => {
const getStatusColor = (nodeId: string) => {
const status = getNodeStatus(nodeId);
const isActive = nodeId === activeNodeId;
@@ -275,18 +282,28 @@ function WorkflowMiniMap({
};
const executedOrder = new Map(executionSteps.map((s, i) => [s.nodeId, i]));
const sortedNodes = [...nodes].sort((a, b) => {
const aIdx = executedOrder.get(a.id);
const bIdx = executedOrder.get(b.id);
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
if (aIdx !== undefined) return -1;
if (bIdx !== undefined) return 1;
if (a.type === 'start') return -1;
if (b.type === 'start') return 1;
if (a.type === 'end') return 1;
if (b.type === 'end') return -1;
return (a.position?.y || 0) - (b.position?.y || 0);
});
const startNode = nodes.find((node) => node.type === 'start');
const visibleNodeIds = new Set(executionSteps.map((step) => step.nodeId));
if (activeNodeId) {
visibleNodeIds.add(activeNodeId);
}
if (startNode) {
visibleNodeIds.add(startNode.id);
}
const sortedNodes = nodes
.filter((node) => visibleNodeIds.has(node.id))
.sort((a, b) => {
if (a.type === 'start') return -1;
if (b.type === 'start') return 1;
const aIdx = executedOrder.get(a.id);
const bIdx = executedOrder.get(b.id);
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
if (aIdx !== undefined) return -1;
if (bIdx !== undefined) return 1;
return (a.position?.y || 0) - (b.position?.y || 0);
});
const hasStepData = (nodeId: string) => {
const step = executionSteps.find((s) => s.nodeId === nodeId);
@@ -306,7 +323,7 @@ function WorkflowMiniMap({
disabled={!hasStepData(node.id)}
className={cn(
'flex h-12 w-full items-center gap-2 rounded-lg border px-3 text-xs transition-all',
getStatusColor(node.id, node.type),
getStatusColor(node.id),
hasStepData(node.id) && 'cursor-pointer hover:opacity-80',
)}
>
@@ -533,6 +550,10 @@ export default function WorkflowPreview({
const querySteps = query.executionSteps || [];
const hasResponse = !!(query.response || query.error);
const isLastQuery = index === queries.length - 1;
const isStreamingLastQuery =
status === 'loading' && isLastQuery;
const shouldShowThought =
!isStreamingLastQuery && Boolean(query.thought);
const isOpen =
openDetailsIndex === index ||
(!hasResponse && isLastQuery && querySteps.length > 0);
@@ -567,17 +588,19 @@ export default function WorkflowPreview({
{/* Response bubble */}
{(query.response ||
query.thought ||
shouldShowThought ||
query.tool_calls) && (
<ConversationBubble
className={isLastQuery ? 'mb-32' : 'mb-7'}
message={query.response}
type="ANSWER"
thought={query.thought}
thought={
shouldShowThought ? query.thought : undefined
}
sources={query.sources}
toolCalls={query.tool_calls}
feedback={query.feedback}
isStreaming={status === 'loading' && isLastQuery}
isStreaming={isStreamingLastQuery}
/>
)}

View File

@@ -9,10 +9,71 @@ import {
} from '@/components/ui/popover';
interface WorkflowVariable {
name: string;
label: string;
templatePath: string;
section: string;
}
const GLOBAL_CONTEXT_VARIABLES: WorkflowVariable[] = [
{
label: 'source.content',
templatePath: 'source.content',
section: 'Global context',
},
{
label: 'source.summaries',
templatePath: 'source.summaries',
section: 'Global context',
},
{
label: 'source.documents',
templatePath: 'source.documents',
section: 'Global context',
},
{
label: 'source.count',
templatePath: 'source.count',
section: 'Global context',
},
{
label: 'system.date',
templatePath: 'system.date',
section: 'Global context',
},
{
label: 'system.time',
templatePath: 'system.time',
section: 'Global context',
},
{
label: 'system.timestamp',
templatePath: 'system.timestamp',
section: 'Global context',
},
{
label: 'system.request_id',
templatePath: 'system.request_id',
section: 'Global context',
},
{
label: 'system.user_id',
templatePath: 'system.user_id',
section: 'Global context',
},
];
function toAgentTemplatePath(variableName: string): string {
const trimmed = variableName.trim();
if (!trimmed) return 'agent';
if (/^[A-Za-z_][A-Za-z0-9_]*$/.test(trimmed)) {
return `agent.${trimmed}`;
}
const escaped = trimmed.replace(/\\/g, '\\\\').replace(/'/g, "\\'");
return `agent['${escaped}']`;
}
function getUpstreamNodeIds(nodeId: string, edges: Edge[]): Set<string> {
const upstream = new Set<string>();
const queue = [nodeId];
@@ -36,32 +97,69 @@ function extractUpstreamVariables(
selectedNodeId: string,
): WorkflowVariable[] {
const variables: WorkflowVariable[] = [
{ name: 'query', section: 'Workflow input' },
{ name: 'chat_history', section: 'Workflow input' },
{
label: 'agent.query',
templatePath: 'agent.query',
section: 'Workflow input',
},
{
label: 'agent.chat_history',
templatePath: 'agent.chat_history',
section: 'Workflow input',
},
...GLOBAL_CONTEXT_VARIABLES,
];
const seen = new Set(['query', 'chat_history']);
const seen = new Set(variables.map((variable) => variable.templatePath));
const upstreamIds = getUpstreamNodeIds(selectedNodeId, edges);
for (const node of nodes) {
if (!upstreamIds.has(node.id)) continue;
if (node.type === 'agent' && node.data?.config?.output_variable) {
const name = node.data.config.output_variable;
if (!seen.has(name)) {
seen.add(name);
if (node.type === 'agent') {
const defaultOutputTemplatePath = toAgentTemplatePath(
`node_${node.id}_output`,
);
if (!seen.has(defaultOutputTemplatePath)) {
seen.add(defaultOutputTemplatePath);
variables.push({
name,
label: defaultOutputTemplatePath,
templatePath: defaultOutputTemplatePath,
section: node.data.title || node.data.label || 'Agent',
});
}
const outputVariable = String(
node.data?.config?.output_variable || '',
).trim();
if (outputVariable) {
const templatePath = toAgentTemplatePath(outputVariable);
if (!seen.has(templatePath)) {
seen.add(templatePath);
variables.push({
label: templatePath,
templatePath,
section: node.data.title || node.data.label || 'Agent',
});
}
}
}
if (node.type === 'state' && node.data?.variable) {
const name = node.data.variable;
if (!seen.has(name)) {
seen.add(name);
if (node.type === 'state') {
const operations = node.data?.config?.operations;
if (!Array.isArray(operations)) continue;
for (const operation of operations) {
const targetVariable = String(operation?.target_variable || '').trim();
if (!targetVariable) continue;
const templatePath = toAgentTemplatePath(targetVariable);
if (seen.has(templatePath)) continue;
seen.add(templatePath);
variables.push({
name,
section: 'Set State',
label: templatePath,
templatePath,
section: node.data.title || node.data.label || 'Set State',
});
}
}
@@ -106,14 +204,16 @@ function VariableListWithSearch({
onSelect,
}: {
variables: WorkflowVariable[];
onSelect: (name: string) => void;
onSelect: (templatePath: string) => void;
}) {
const [search, setSearch] = useState('');
const filtered = useMemo(
() =>
variables.filter((v) =>
v.name.toLowerCase().includes(search.toLowerCase()),
`${v.label} ${v.templatePath}`
.toLowerCase()
.includes(search.toLowerCase()),
),
[variables, search],
);
@@ -146,17 +246,17 @@ function VariableListWithSearch({
</div>
{vars.map((v) => (
<button
key={v.name}
key={`${section}-${v.templatePath}`}
onMouseDown={(e) => {
e.preventDefault();
e.stopPropagation();
onSelect(v.name);
onSelect(v.templatePath);
}}
className="flex w-full cursor-pointer items-center gap-2 px-3 py-1.5 text-left text-sm transition-colors hover:bg-gray-50 dark:hover:bg-[#383838]"
>
<Braces className="text-violets-are-blue h-3.5 w-3.5 shrink-0" />
<span className="truncate font-medium text-gray-800 dark:text-gray-200">
{v.name}
{v.label}
</span>
</button>
))}
@@ -206,7 +306,9 @@ export default function PromptTextArea({
const filtered = useMemo(
() =>
variables.filter((v) =>
v.name.toLowerCase().includes(filterText.toLowerCase()),
`${v.label} ${v.templatePath}`
.toLowerCase()
.includes(filterText.toLowerCase()),
),
[variables, filterText],
);
@@ -217,10 +319,12 @@ export default function PromptTextArea({
const cursorPos = textarea.selectionStart;
const textBeforeCursor = value.slice(0, cursorPos);
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
const triggerMatch = textBeforeCursor.match(
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
);
if (triggerMatch) {
setFilterText(triggerMatch[1]);
setFilterText(triggerMatch[1].trim());
setCursorInsertPos(cursorPos);
const wrapper = wrapperRef.current;
@@ -237,15 +341,17 @@ export default function PromptTextArea({
}, [value]);
const insertVariable = useCallback(
(varName: string) => {
(templatePath: string) => {
if (cursorInsertPos === null) return;
const textBeforeCursor = value.slice(0, cursorInsertPos);
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
const triggerMatch = textBeforeCursor.match(
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
);
if (!triggerMatch) return;
const startPos = cursorInsertPos - triggerMatch[0].length;
const insertion = `{{${varName}}}`;
const insertion = `{{ ${templatePath} }}`;
const newValue =
value.slice(0, startPos) + insertion + value.slice(cursorInsertPos);
@@ -262,10 +368,10 @@ export default function PromptTextArea({
);
const insertVariableFromButton = useCallback(
(varName: string) => {
(templatePath: string) => {
const textarea = textareaRef.current;
const cursorPos = textarea?.selectionStart ?? value.length;
const insertion = `{{${varName}}}`;
const insertion = `{{ ${templatePath} }}`;
const newValue =
value.slice(0, cursorPos) + insertion + value.slice(cursorPos);

View File

@@ -85,10 +85,10 @@ export const AgentNode = memo(function AgentNode({
)}
{config.output_variable && (
<div
className="truncate text-xs text-green-600 dark:text-green-400"
title={`Output ${config.output_variable}`}
className="truncate text-xs text-gray-500 dark:text-gray-400"
title={`Output: ${config.output_variable}`}
>
Output {config.output_variable}
Output: {config.output_variable}
</div>
)}
</div>

View File

@@ -13,7 +13,7 @@ export interface WorkflowExecutionStep {
startedAt?: number;
completedAt?: number;
stateSnapshot?: Record<string, unknown>;
output?: string;
output?: unknown;
error?: string;
}
@@ -321,7 +321,9 @@ export const workflowPreviewSlice = createSlice({
}
const querySteps = state.queries[index].executionSteps!;
const existingIndex = querySteps.findIndex((s) => s.nodeId === step.nodeId);
const existingIndex = querySteps.findIndex(
(s) => s.nodeId === step.nodeId,
);
const updatedStep: WorkflowExecutionStep = {
nodeId: step.nodeId,
@@ -332,7 +334,10 @@ export const workflowPreviewSlice = createSlice({
stateSnapshot: step.stateSnapshot,
output: step.output,
error: step.error,
startedAt: existingIndex !== -1 ? querySteps[existingIndex].startedAt : Date.now(),
startedAt:
existingIndex !== -1
? querySteps[existingIndex].startedAt
: Date.now(),
completedAt:
step.status === 'completed' || step.status === 'failed'
? Date.now()
@@ -342,7 +347,8 @@ export const workflowPreviewSlice = createSlice({
};
if (existingIndex !== -1) {
updatedStep.stateSnapshot = step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
updatedStep.stateSnapshot =
step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
updatedStep.output = step.output ?? querySteps[existingIndex].output;
updatedStep.error = step.error ?? querySteps[existingIndex].error;
querySteps[existingIndex] = updatedStep;
@@ -350,7 +356,9 @@ export const workflowPreviewSlice = createSlice({
querySteps.push(updatedStep);
}
const globalIndex = state.executionSteps.findIndex((s) => s.nodeId === step.nodeId);
const globalIndex = state.executionSteps.findIndex(
(s) => s.nodeId === step.nodeId,
);
if (globalIndex !== -1) {
state.executionSteps[globalIndex] = updatedStep;
} else {

View File

@@ -0,0 +1,332 @@
from types import SimpleNamespace
from typing import Any, Dict, Optional
import pytest
from application.api.user.workflows import routes as workflow_routes
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
NodeType,
Workflow,
WorkflowGraph,
WorkflowNode,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.api.user.workflows.routes import validate_workflow_structure
class StubNodeAgent:
def __init__(self, events):
self.events = events
def gen(self, _prompt):
yield from self.events
def create_engine() -> WorkflowEngine:
graph = WorkflowGraph(workflow=Workflow(name="Engine Test"), nodes=[], edges=[])
agent = SimpleNamespace(
endpoint="stream",
llm_name="openai",
model_id="gpt-4o-mini",
api_key="test-key",
chat_history=[],
decoded_token={"sub": "user-1"},
)
return WorkflowEngine(graph, agent)
def create_agent_node(
node_id: str,
output_variable: str = "",
json_schema: Optional[Dict[str, Any]] = None,
) -> WorkflowNode:
config = {
"agent_type": "classic",
"system_prompt": "You are a helpful assistant.",
"prompt_template": "",
"stream_to_user": False,
"tools": [],
}
if output_variable:
config["output_variable"] = output_variable
if json_schema is not None:
config["json_schema"] = json_schema
return WorkflowNode(
id=node_id,
workflow_id="workflow-1",
type=NodeType.AGENT,
title="Agent",
position={"x": 0, "y": 0},
config=config,
)
def test_execute_agent_node_saves_structured_output_as_json(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_1",
output_variable="result",
json_schema={"type": "object"},
)
node_events = [
{"answer": '{"summary":"ok",', "structured": True},
{"answer": '"score":2}', "structured": True},
]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
expected_output = {"summary": "ok", "score": 2}
assert engine.state["node_agent_1_output"] == expected_output
assert engine.state["result"] == expected_output
def test_execute_agent_node_normalizes_wrapped_schema_before_agent_create(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_wrapped",
json_schema={"schema": {"type": "object"}},
)
node_events = [{"answer": '{"summary":"ok"}', "structured": True}]
captured: Dict[str, Any] = {}
def create_node_agent(**kwargs):
captured["json_schema"] = kwargs.get("json_schema")
return StubNodeAgent(node_events)
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(create_node_agent),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
)
list(engine._execute_agent_node(node))
assert captured["json_schema"] == {"type": "object"}
assert engine.state["node_agent_wrapped_output"] == {"summary": "ok"}
def test_execute_agent_node_falls_back_to_text_when_schema_not_configured(monkeypatch):
engine = create_engine()
node = create_agent_node(node_id="agent_2", output_variable="result")
node_events = [{"answer": "plain text answer"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
assert engine.state["node_agent_2_output"] == "plain text answer"
assert engine.state["result"] == "plain text answer"
def test_validate_workflow_structure_rejects_invalid_agent_json_schema():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {"json_schema": "invalid"},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert any(
"Agent node 'Agent' JSON schema must be a valid JSON object" in err
for err in errors
)
def test_validate_workflow_structure_accepts_valid_agent_json_schema():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {"json_schema": {"type": "object"}},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert errors == []
def test_validate_workflow_structure_accepts_wrapped_agent_json_schema():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {"json_schema": {"schema": {"type": "object"}}},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert errors == []
def test_validate_workflow_structure_accepts_output_variable_and_schema_together():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {
"output_variable": "answer",
"json_schema": {"type": "object"},
},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert errors == []
def test_validate_workflow_structure_rejects_unsupported_structured_output_model(
monkeypatch,
):
monkeypatch.setattr(
workflow_routes,
"get_model_capabilities",
lambda _model_id: {"supports_structured_output": False},
)
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {
"model_id": "some-model",
"json_schema": {"type": "object"},
},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert any(
"Agent node 'Agent' selected model does not support structured output"
in err
for err in errors
)
def test_execute_agent_node_raises_when_structured_output_violates_schema(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_3",
json_schema={
"type": "object",
"properties": {"summary": {"type": "string"}},
"required": ["summary"],
"additionalProperties": False,
},
)
node_events = [{"answer": '{"score":2}', "structured": True}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
)
with pytest.raises(ValueError, match="Structured output did not match schema"):
list(engine._execute_agent_node(node))
def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_4",
json_schema={"type": "object"},
)
node_events = [{"answer": "not-json"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
)
with pytest.raises(
ValueError,
match="Structured output was expected but response was not valid JSON",
):
list(engine._execute_agent_node(node))

View File

@@ -0,0 +1,63 @@
from types import SimpleNamespace
from application.agents.workflows.schemas import Workflow, WorkflowGraph
from application.agents.workflows.workflow_engine import WorkflowEngine
def create_engine() -> WorkflowEngine:
graph = WorkflowGraph(workflow=Workflow(name="Template Test"), nodes=[], edges=[])
agent = SimpleNamespace(
user="user-1",
request_id="req-1",
retrieved_docs=[
{"title": "Doc A", "text": "Summary A"},
{"title": "Doc B", "text": "Summary B"},
],
)
return WorkflowEngine(graph, agent)
def test_workflow_template_supports_agent_namespace_and_legacy_variables():
engine = create_engine()
engine.state = {"query": "Hello", "chat_history": "[]", "ticket_id": 42}
rendered = engine._format_template(
"{{ agent.query }}|{{ agent.ticket_id }}|{{ query }}|{{ ticket_id }}"
)
assert rendered == "Hello|42|Hello|42"
def test_workflow_template_supports_global_namespaces():
engine = create_engine()
engine.state = {"query": "Hello"}
rendered = engine._format_template(
"{{ source.count }}|{{ source.summaries }}|{{ system.request_id }}"
)
assert rendered.startswith("2|")
assert "Doc A" in rendered
assert "Summary A" in rendered
assert rendered.endswith("|req-1")
def test_workflow_template_handles_namespace_conflicts_with_agent_prefix():
engine = create_engine()
engine.state = {"source": "user-defined-source"}
rendered = engine._format_template(
"{{ agent.source }}|{{ agent_source }}|{{ source.count }}"
)
assert rendered.startswith("user-defined-source|user-defined-source|")
def test_workflow_template_gracefully_handles_invalid_template_syntax():
engine = create_engine()
engine.state = {"query": "Hello"}
invalid_template = "{{ agent.query "
rendered = engine._format_template(invalid_template)
assert rendered == invalid_template