mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-03-04 21:03:41 +00:00
fix: mini workflow fixes
This commit is contained in:
@@ -7,6 +7,10 @@ from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
@@ -63,7 +67,12 @@ class BaseAgent(ABC):
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = json_schema
|
||||
self.json_schema = None
|
||||
if json_schema is not None:
|
||||
try:
|
||||
self.json_schema = normalize_json_schema_payload(json_schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
|
||||
self.limited_token_mode = limited_token_mode
|
||||
self.token_limit = token_limit
|
||||
self.limited_request_mode = limited_request_mode
|
||||
|
||||
@@ -211,8 +211,21 @@ class WorkflowAgent(BaseAgent):
|
||||
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
serialized: Dict[str, Any] = {}
|
||||
for key, value in state.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
serialized[key] = value
|
||||
else:
|
||||
serialized[key] = str(value)
|
||||
serialized[key] = self._serialize_state_value(value)
|
||||
return serialized
|
||||
|
||||
def _serialize_state_value(self, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
str(dict_key): self._serialize_state_value(dict_value)
|
||||
for dict_key, dict_value in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
@@ -13,6 +14,17 @@ from application.agents.workflows.schemas import (
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
try:
|
||||
import jsonschema
|
||||
except ImportError: # pragma: no cover - optional dependency in some deployments.
|
||||
jsonschema = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.agents.base import BaseAgent
|
||||
@@ -20,6 +32,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
StateValue = Any
|
||||
WorkflowState = Dict[str, StateValue]
|
||||
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
@@ -31,6 +44,8 @@ class WorkflowEngine:
|
||||
self.state: WorkflowState = {}
|
||||
self.execution_log: List[Dict[str, Any]] = []
|
||||
self._condition_result: Optional[str] = None
|
||||
self._template_engine = TemplateEngine()
|
||||
self._namespace_manager = NamespaceManager()
|
||||
|
||||
def execute(
|
||||
self, initial_inputs: WorkflowState, query: str
|
||||
@@ -174,7 +189,11 @@ class WorkflowEngine:
|
||||
def _execute_agent_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_model_capabilities,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
node_config = AgentNodeConfig(**node.config.get("config", node.config))
|
||||
|
||||
@@ -182,27 +201,50 @@ class WorkflowEngine:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
else:
|
||||
formatted_prompt = self.state.get("query", "")
|
||||
node_llm_name = node_config.llm_name or self.agent.llm_name
|
||||
node_json_schema = self._normalize_node_json_schema(
|
||||
node_config.json_schema, node.title
|
||||
)
|
||||
node_model_id = node_config.model_id or self.agent.model_id
|
||||
node_llm_name = (
|
||||
node_config.llm_name
|
||||
or get_provider_from_model_id(node_model_id or "")
|
||||
or self.agent.llm_name
|
||||
)
|
||||
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
|
||||
|
||||
if node_json_schema and node_model_id:
|
||||
model_capabilities = get_model_capabilities(node_model_id)
|
||||
if model_capabilities and not model_capabilities.get(
|
||||
"supports_structured_output", False
|
||||
):
|
||||
raise ValueError(
|
||||
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
|
||||
)
|
||||
|
||||
node_agent = WorkflowNodeAgentFactory.create(
|
||||
agent_type=node_config.agent_type,
|
||||
endpoint=self.agent.endpoint,
|
||||
llm_name=node_llm_name,
|
||||
model_id=node_config.model_id or self.agent.model_id,
|
||||
model_id=node_model_id,
|
||||
api_key=node_api_key,
|
||||
tool_ids=node_config.tools,
|
||||
prompt=node_config.system_prompt,
|
||||
chat_history=self.agent.chat_history,
|
||||
decoded_token=self.agent.decoded_token,
|
||||
json_schema=node_config.json_schema,
|
||||
json_schema=node_json_schema,
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
full_response_parts: List[str] = []
|
||||
structured_response_parts: List[str] = []
|
||||
has_structured_response = False
|
||||
first_chunk = True
|
||||
for event in node_agent.gen(formatted_prompt):
|
||||
if "answer" in event:
|
||||
full_response += event["answer"]
|
||||
chunk = str(event["answer"])
|
||||
full_response_parts.append(chunk)
|
||||
if event.get("structured"):
|
||||
has_structured_response = True
|
||||
structured_response_parts.append(chunk)
|
||||
if node_config.stream_to_user:
|
||||
if first_chunk and hasattr(self, "_has_streamed"):
|
||||
yield {"answer": "\n\n"}
|
||||
@@ -212,8 +254,33 @@ class WorkflowEngine:
|
||||
if node_config.stream_to_user:
|
||||
self._has_streamed = True
|
||||
|
||||
output_key = node_config.output_variable or f"node_{node.id}_output"
|
||||
self.state[output_key] = full_response.strip()
|
||||
full_response = "".join(full_response_parts).strip()
|
||||
output_value: Any = full_response
|
||||
if has_structured_response:
|
||||
structured_response = "".join(structured_response_parts).strip()
|
||||
response_to_parse = structured_response or full_response
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
response_to_parse
|
||||
)
|
||||
output_value = parsed_structured if parsed_success else response_to_parse
|
||||
if node_json_schema:
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
elif node_json_schema:
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
full_response
|
||||
)
|
||||
if not parsed_success:
|
||||
raise ValueError(
|
||||
"Structured output was expected but response was not valid JSON"
|
||||
)
|
||||
output_value = parsed_structured
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
|
||||
default_output_key = f"node_{node.id}_output"
|
||||
self.state[default_output_key] = output_value
|
||||
|
||||
if node_config.output_variable:
|
||||
self.state[node_config.output_variable] = output_value
|
||||
|
||||
def _execute_state_node(
|
||||
self, node: WorkflowNode
|
||||
@@ -254,13 +321,122 @@ class WorkflowEngine:
|
||||
formatted_output = self._format_template(output_template)
|
||||
yield {"answer": formatted_output}
|
||||
|
||||
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
|
||||
normalized_response = raw_response.strip()
|
||||
if not normalized_response:
|
||||
return False, None
|
||||
|
||||
try:
|
||||
return True, json.loads(normalized_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Workflow agent returned structured output that was not valid JSON"
|
||||
)
|
||||
return False, None
|
||||
|
||||
def _normalize_node_json_schema(
|
||||
self, schema: Optional[Dict[str, Any]], node_title: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
return normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(
|
||||
f'Invalid JSON schema for node "{node_title}": {exc}'
|
||||
) from exc
|
||||
|
||||
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
|
||||
if jsonschema is None:
|
||||
logger.warning(
|
||||
"jsonschema package is not available, skipping structured output validation"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
normalized_schema = normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc}") from exc
|
||||
|
||||
try:
|
||||
jsonschema.validate(instance=output_value, schema=normalized_schema)
|
||||
except jsonschema.exceptions.ValidationError as exc:
|
||||
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
|
||||
except jsonschema.exceptions.SchemaError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
|
||||
|
||||
def _format_template(self, template: str) -> str:
|
||||
formatted = template
|
||||
context = self._build_template_context()
|
||||
try:
|
||||
return self._template_engine.render(template, context)
|
||||
except TemplateRenderError as e:
|
||||
logger.warning(
|
||||
"Workflow template rendering failed, using raw template: %s", str(e)
|
||||
)
|
||||
return template
|
||||
|
||||
def _build_template_context(self) -> Dict[str, Any]:
|
||||
docs, docs_together = self._get_source_template_data()
|
||||
passthrough_data = (
|
||||
self.state.get("passthrough")
|
||||
if isinstance(self.state.get("passthrough"), dict)
|
||||
else None
|
||||
)
|
||||
tools_data = (
|
||||
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
|
||||
)
|
||||
|
||||
context = self._namespace_manager.build_context(
|
||||
user_id=getattr(self.agent, "user", None),
|
||||
request_id=getattr(self.agent, "request_id", None),
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
agent_context: Dict[str, Any] = {}
|
||||
for key, value in self.state.items():
|
||||
placeholder = f"{{{{{key}}}}}"
|
||||
if placeholder in formatted and value is not None:
|
||||
formatted = formatted.replace(placeholder, str(value))
|
||||
return formatted
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
normalized_key = key.strip()
|
||||
if not normalized_key:
|
||||
continue
|
||||
agent_context[normalized_key] = value
|
||||
|
||||
context["agent"] = agent_context
|
||||
|
||||
# Keep legacy top-level variables working while namespaced variables are adopted.
|
||||
for key, value in agent_context.items():
|
||||
if key in TEMPLATE_RESERVED_NAMESPACES:
|
||||
context[f"agent_{key}"] = value
|
||||
continue
|
||||
if key not in context:
|
||||
context[key] = value
|
||||
|
||||
return context
|
||||
|
||||
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
|
||||
docs = getattr(self.agent, "retrieved_docs", None)
|
||||
if not isinstance(docs, list) or len(docs) == 0:
|
||||
return None, None
|
||||
|
||||
docs_together_parts: List[str] = []
|
||||
for doc in docs:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
text = doc.get("text")
|
||||
if not isinstance(text, str):
|
||||
continue
|
||||
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if isinstance(filename, str) and filename.strip():
|
||||
docs_together_parts.append(f"{filename}\n{text}")
|
||||
else:
|
||||
docs_together_parts.append(text)
|
||||
|
||||
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
|
||||
return docs, docs_together
|
||||
|
||||
def get_execution_summary(self) -> List[NodeExecutionLog]:
|
||||
return [
|
||||
|
||||
@@ -23,6 +23,10 @@ from application.api.user.base import (
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.utils import (
|
||||
check_required_fields,
|
||||
@@ -479,41 +483,15 @@ class CreateAgent(Resource):
|
||||
data["models"] = []
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
|
||||
if data.get("json_schema"):
|
||||
# Validate and normalize JSON schema if provided
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
# Basic validation - ensure it's a valid JSON structure
|
||||
|
||||
json_schema = data.get("json_schema")
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid JSON object",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Validate that it has either a 'schema' property or is itself a schema
|
||||
|
||||
if "schema" not in json_schema and "type" not in json_schema:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Invalid JSON schema: {e}")
|
||||
data["json_schema"] = normalize_json_schema_payload(
|
||||
data.get("json_schema")
|
||||
)
|
||||
except JsonSchemaValidationError as exc:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid JSON schema format"}
|
||||
),
|
||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
||||
400,
|
||||
)
|
||||
if data.get("status") not in ["draft", "published"]:
|
||||
@@ -732,6 +710,8 @@ class UpdateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
if data.get("json_schema") == "":
|
||||
data["json_schema"] = None
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error parsing request data: {err}", exc_info=True
|
||||
@@ -892,17 +872,15 @@ class UpdateAgent(Resource):
|
||||
elif field == "json_schema":
|
||||
json_schema = data.get("json_schema")
|
||||
if json_schema is not None:
|
||||
if not isinstance(json_schema, dict):
|
||||
try:
|
||||
update_fields[field] = normalize_json_schema_payload(
|
||||
json_schema
|
||||
)
|
||||
except JsonSchemaValidationError as exc:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid object",
|
||||
}
|
||||
),
|
||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = json_schema
|
||||
else:
|
||||
update_fields[field] = None
|
||||
elif field == "limited_token_mode":
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Set
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
@@ -11,6 +11,11 @@ from application.api.user.base import (
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
from application.api.user.utils import (
|
||||
check_resource_ownership,
|
||||
error_response,
|
||||
@@ -85,6 +90,50 @@ def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> L
|
||||
return docs
|
||||
|
||||
|
||||
def validate_json_schema_payload(
|
||||
json_schema: Any,
|
||||
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""Validate and normalize optional JSON schema payload for structured output."""
|
||||
if json_schema is None:
|
||||
return None, None
|
||||
try:
|
||||
return normalize_json_schema_payload(json_schema), None
|
||||
except JsonSchemaValidationError as exc:
|
||||
return None, str(exc)
|
||||
|
||||
|
||||
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
|
||||
"""Normalize agent-node JSON schema payloads before persistence."""
|
||||
normalized_nodes: List[Dict] = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, dict):
|
||||
normalized_nodes.append(node)
|
||||
continue
|
||||
|
||||
normalized_node = dict(node)
|
||||
if normalized_node.get("type") != "agent":
|
||||
normalized_nodes.append(normalized_node)
|
||||
continue
|
||||
|
||||
raw_config = normalized_node.get("data")
|
||||
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
|
||||
normalized_nodes.append(normalized_node)
|
||||
continue
|
||||
|
||||
normalized_config = dict(raw_config)
|
||||
try:
|
||||
normalized_config["json_schema"] = normalize_json_schema_payload(
|
||||
raw_config.get("json_schema")
|
||||
)
|
||||
except JsonSchemaValidationError:
|
||||
# Validation runs before normalization; keep original on unexpected shape.
|
||||
normalized_config["json_schema"] = raw_config.get("json_schema")
|
||||
normalized_node["data"] = normalized_config
|
||||
normalized_nodes.append(normalized_node)
|
||||
|
||||
return normalized_nodes
|
||||
|
||||
|
||||
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
|
||||
"""Validate workflow graph structure."""
|
||||
errors = []
|
||||
@@ -216,6 +265,28 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
f"must eventually reach an end node"
|
||||
)
|
||||
|
||||
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
|
||||
for agent_node in agent_nodes:
|
||||
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
|
||||
raw_config = agent_node.get("data", {}) or {}
|
||||
if not isinstance(raw_config, dict):
|
||||
errors.append(f"Agent node '{agent_title}' has invalid configuration")
|
||||
continue
|
||||
normalized_schema, schema_error = validate_json_schema_payload(
|
||||
raw_config.get("json_schema")
|
||||
)
|
||||
has_json_schema = normalized_schema is not None
|
||||
|
||||
model_id = raw_config.get("model_id")
|
||||
if has_json_schema and isinstance(model_id, str) and model_id.strip():
|
||||
capabilities = get_model_capabilities(model_id.strip())
|
||||
if capabilities and not capabilities.get("supports_structured_output", False):
|
||||
errors.append(
|
||||
f"Agent node '{agent_title}' selected model does not support structured output"
|
||||
)
|
||||
if schema_error:
|
||||
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
|
||||
|
||||
for node in nodes:
|
||||
if not node.get("id"):
|
||||
errors.append("All nodes must have an id")
|
||||
@@ -301,6 +372,7 @@ class WorkflowList(Resource):
|
||||
return error_response(
|
||||
"Workflow validation failed", errors=validation_errors
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
workflow_doc = {
|
||||
@@ -391,6 +463,7 @@ class WorkflowDetail(Resource):
|
||||
return error_response(
|
||||
"Workflow validation failed", errors=validation_errors
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
|
||||
34
application/core/json_schema_utils.py
Normal file
34
application/core/json_schema_utils.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class JsonSchemaValidationError(ValueError):
|
||||
"""Raised when a JSON schema payload is invalid."""
|
||||
|
||||
|
||||
def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Normalize accepted JSON schema payload shapes to a plain schema object.
|
||||
|
||||
Accepted inputs:
|
||||
- None
|
||||
- A raw schema object with a top-level "type"
|
||||
- A wrapped payload with a top-level "schema" object
|
||||
"""
|
||||
if json_schema is None:
|
||||
return None
|
||||
|
||||
if not isinstance(json_schema, dict):
|
||||
raise JsonSchemaValidationError("must be a valid JSON object")
|
||||
|
||||
wrapped_schema = json_schema.get("schema")
|
||||
if wrapped_schema is not None:
|
||||
if not isinstance(wrapped_schema, dict):
|
||||
raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
|
||||
return wrapped_schema
|
||||
|
||||
if "type" not in json_schema:
|
||||
raise JsonSchemaValidationError(
|
||||
'must include either a "type" or "schema" field'
|
||||
)
|
||||
|
||||
return json_schema
|
||||
@@ -439,10 +439,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const data = await response.json();
|
||||
const transformed = modelService.transformModels(data.models || []);
|
||||
setAvailableModels(transformed);
|
||||
|
||||
if (mode === 'new' && transformed.length > 0) {
|
||||
const preferredDefaultModelId =
|
||||
transformed.find((model) => model.id === data.default_model_id)?.id ||
|
||||
transformed[0].id;
|
||||
|
||||
if (preferredDefaultModelId) {
|
||||
setSelectedModelIds((prevSelectedModelIds) =>
|
||||
prevSelectedModelIds.size > 0
|
||||
? prevSelectedModelIds
|
||||
: new Set([preferredDefaultModelId]),
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
getTools();
|
||||
getModels();
|
||||
}, [token]);
|
||||
}, [token, mode]);
|
||||
|
||||
// Validate folder_id from URL against user's folders
|
||||
useEffect(() => {
|
||||
|
||||
@@ -94,6 +94,20 @@ interface UserTool {
|
||||
displayName: string;
|
||||
}
|
||||
|
||||
function validateJsonSchemaConfig(schema: unknown): string | null {
|
||||
if (schema === undefined || schema === null) return null;
|
||||
if (typeof schema !== 'object' || Array.isArray(schema)) {
|
||||
return 'must be a valid JSON object';
|
||||
}
|
||||
|
||||
const schemaObject = schema as Record<string, unknown>;
|
||||
if (!('schema' in schemaObject) && !('type' in schemaObject)) {
|
||||
return 'must include either a "type" or "schema" field';
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function createEmptyWorkflowAgent(): Agent {
|
||||
return {
|
||||
id: '',
|
||||
@@ -130,17 +144,46 @@ function parseSimpleCel(expression: string): {
|
||||
operator: string;
|
||||
value: string;
|
||||
} {
|
||||
let match = expression.match(
|
||||
const trimmedExpression = expression.trim();
|
||||
|
||||
let match = trimmedExpression.match(
|
||||
/^(\w+)\.(contains|startsWith)\(["'](.*)["']\)$/,
|
||||
);
|
||||
if (match) return { variable: match[1], operator: match[2], value: match[3] };
|
||||
|
||||
match = expression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*["'](.*)["']$/);
|
||||
match = trimmedExpression.match(/^(\w+)\.(contains|startsWith)\((.*)\)$/);
|
||||
if (match) {
|
||||
const rawValue = match[3].trim();
|
||||
const unquotedValue = rawValue.replace(/^["'](.*)["']$/, '$1');
|
||||
return {
|
||||
variable: match[1],
|
||||
operator: match[2],
|
||||
value: unquotedValue,
|
||||
};
|
||||
}
|
||||
|
||||
match = trimmedExpression.match(/^(contains|startsWith)\(["'](.*)["']\)$/);
|
||||
if (match) return { variable: '', operator: match[1], value: match[2] };
|
||||
|
||||
match = trimmedExpression.match(/^(contains|startsWith)\((.*)\)$/);
|
||||
if (match) {
|
||||
const rawValue = match[2].trim();
|
||||
const unquotedValue = rawValue.replace(/^["'](.*)["']$/, '$1');
|
||||
return { variable: '', operator: match[1], value: unquotedValue };
|
||||
}
|
||||
|
||||
match = trimmedExpression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*["'](.*)["']$/);
|
||||
if (match) return { variable: match[1], operator: match[2], value: match[3] };
|
||||
|
||||
match = expression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*(.*)$/);
|
||||
match = trimmedExpression.match(/^(==|!=|>=|<=|>|<)\s*["'](.*)["']$/);
|
||||
if (match) return { variable: '', operator: match[1], value: match[2] };
|
||||
|
||||
match = trimmedExpression.match(/^(\w+)\s*(==|!=|>=|<=|>|<)\s*(.*)$/);
|
||||
if (match) return { variable: match[1], operator: match[2], value: match[3] };
|
||||
|
||||
match = trimmedExpression.match(/^(==|!=|>=|<=|>|<)\s*(.*)$/);
|
||||
if (match) return { variable: '', operator: match[1], value: match[2] };
|
||||
|
||||
return { variable: '', operator: '==', value: '' };
|
||||
}
|
||||
|
||||
@@ -149,13 +192,24 @@ function buildSimpleCel(
|
||||
operator: string,
|
||||
value: string,
|
||||
): string {
|
||||
if (!variable) return '';
|
||||
const isNumeric = value !== '' && !isNaN(Number(value));
|
||||
const isBool = value === 'true' || value === 'false';
|
||||
const quoted = isNumeric || isBool ? value : `"${value}"`;
|
||||
if (operator === 'contains') return `${variable}.contains(${quoted})`;
|
||||
if (operator === 'startsWith') return `${variable}.startsWith(${quoted})`;
|
||||
return `${variable} ${operator} ${quoted}`;
|
||||
const trimmedValue = value.trim();
|
||||
const isNumeric = trimmedValue !== '' && !isNaN(Number(trimmedValue));
|
||||
const isBool = trimmedValue === 'true' || trimmedValue === 'false';
|
||||
const literalValue =
|
||||
isNumeric || isBool ? trimmedValue : JSON.stringify(value);
|
||||
const stringValue = JSON.stringify(value);
|
||||
if (operator === 'contains') {
|
||||
return variable
|
||||
? `${variable}.contains(${stringValue})`
|
||||
: `contains(${stringValue})`;
|
||||
}
|
||||
if (operator === 'startsWith') {
|
||||
return variable
|
||||
? `${variable}.startsWith(${stringValue})`
|
||||
: `startsWith(${stringValue})`;
|
||||
}
|
||||
if (!variable) return `${operator} ${literalValue}`;
|
||||
return `${variable} ${operator} ${literalValue}`;
|
||||
}
|
||||
|
||||
function normalizeConditionCases(cases: ConditionCase[]): ConditionCase[] {
|
||||
@@ -283,7 +337,14 @@ function WorkflowBuilderInner() {
|
||||
>(null);
|
||||
const workflowSettingsRef = useRef<HTMLDivElement>(null);
|
||||
const [availableModels, setAvailableModels] = useState<Model[]>([]);
|
||||
const [defaultAgentModelId, setDefaultAgentModelId] = useState('');
|
||||
const [availableTools, setAvailableTools] = useState<UserTool[]>([]);
|
||||
const [agentJsonSchemaDrafts, setAgentJsonSchemaDrafts] = useState<
|
||||
Record<string, string>
|
||||
>({});
|
||||
const [agentJsonSchemaErrors, setAgentJsonSchemaErrors] = useState<
|
||||
Record<string, string | null>
|
||||
>({});
|
||||
|
||||
const nodeTypes = useMemo<NodeTypes>(
|
||||
() => ({
|
||||
@@ -404,8 +465,14 @@ function WorkflowBuilderInner() {
|
||||
};
|
||||
|
||||
if (type === 'agent') {
|
||||
const defaultModelId = defaultAgentModelId || availableModels[0]?.id;
|
||||
const defaultModelProvider = availableModels.find(
|
||||
(model) => model.id === defaultModelId,
|
||||
)?.provider;
|
||||
baseNode.data.config = {
|
||||
agent_type: 'classic',
|
||||
model_id: defaultModelId,
|
||||
llm_name: defaultModelProvider || '',
|
||||
system_prompt: 'You are a helpful assistant.',
|
||||
prompt_template: '',
|
||||
stream_to_user: true,
|
||||
@@ -430,7 +497,7 @@ function WorkflowBuilderInner() {
|
||||
|
||||
setNodes((nds) => nds.concat(baseNode));
|
||||
},
|
||||
[reactFlowInstance],
|
||||
[reactFlowInstance, availableModels, defaultAgentModelId],
|
||||
);
|
||||
|
||||
const handleNodeClick = useCallback(
|
||||
@@ -449,6 +516,18 @@ function WorkflowBuilderInner() {
|
||||
(e) => e.source !== selectedNode.id && e.target !== selectedNode.id,
|
||||
),
|
||||
);
|
||||
setAgentJsonSchemaDrafts((prev) => {
|
||||
if (!(selectedNode.id in prev)) return prev;
|
||||
const next = { ...prev };
|
||||
delete next[selectedNode.id];
|
||||
return next;
|
||||
});
|
||||
setAgentJsonSchemaErrors((prev) => {
|
||||
if (!(selectedNode.id in prev)) return prev;
|
||||
const next = { ...prev };
|
||||
delete next[selectedNode.id];
|
||||
return next;
|
||||
});
|
||||
setSelectedNode(null);
|
||||
setShowNodeConfig(false);
|
||||
}, [selectedNode]);
|
||||
@@ -468,6 +547,49 @@ function WorkflowBuilderInner() {
|
||||
[selectedNode],
|
||||
);
|
||||
|
||||
const handleAgentJsonSchemaChange = useCallback(
|
||||
(text: string) => {
|
||||
if (!selectedNode || selectedNode.type !== 'agent') return;
|
||||
|
||||
const nodeId = selectedNode.id;
|
||||
setAgentJsonSchemaDrafts((prev) => ({ ...prev, [nodeId]: text }));
|
||||
|
||||
if (text.trim() === '') {
|
||||
setAgentJsonSchemaErrors((prev) => ({ ...prev, [nodeId]: null }));
|
||||
handleUpdateNodeData({
|
||||
config: {
|
||||
...(selectedNode.data.config || {}),
|
||||
json_schema: undefined,
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(text);
|
||||
const validationError = validateJsonSchemaConfig(parsed);
|
||||
setAgentJsonSchemaErrors((prev) => ({
|
||||
...prev,
|
||||
[nodeId]: validationError,
|
||||
}));
|
||||
if (!validationError) {
|
||||
handleUpdateNodeData({
|
||||
config: {
|
||||
...(selectedNode.data.config || {}),
|
||||
json_schema: parsed,
|
||||
},
|
||||
});
|
||||
}
|
||||
} catch {
|
||||
setAgentJsonSchemaErrors((prev) => ({
|
||||
...prev,
|
||||
[nodeId]: 'must be valid JSON',
|
||||
}));
|
||||
}
|
||||
},
|
||||
[handleUpdateNodeData, selectedNode],
|
||||
);
|
||||
|
||||
const handleUpload = useCallback((files: File[]) => {
|
||||
if (files && files.length > 0) {
|
||||
setImageFile(files[0]);
|
||||
@@ -564,7 +686,17 @@ function WorkflowBuilderInner() {
|
||||
const modelsResponse = await modelService.getModels(null);
|
||||
if (modelsResponse.ok) {
|
||||
const modelsData = await modelsResponse.json();
|
||||
setAvailableModels(modelService.transformModels(modelsData.models));
|
||||
const transformedModels = modelService.transformModels(
|
||||
modelsData.models || [],
|
||||
);
|
||||
setAvailableModels(transformedModels);
|
||||
const preferredDefaultModel =
|
||||
transformedModels.find(
|
||||
(model) => model.id === modelsData.default_model_id,
|
||||
)?.id ||
|
||||
transformedModels[0]?.id ||
|
||||
'';
|
||||
setDefaultAgentModelId(preferredDefaultModel);
|
||||
}
|
||||
|
||||
const toolsResponse = await userService.getUserTools(null);
|
||||
@@ -579,6 +711,51 @@ function WorkflowBuilderInner() {
|
||||
loadModelsAndTools();
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedNode || selectedNode.type !== 'agent') return;
|
||||
if (!defaultAgentModelId) return;
|
||||
if (selectedNode.data.config?.model_id) return;
|
||||
|
||||
handleUpdateNodeData({
|
||||
config: {
|
||||
...(selectedNode.data.config || {}),
|
||||
model_id: defaultAgentModelId,
|
||||
llm_name:
|
||||
availableModels.find((model) => model.id === defaultAgentModelId)
|
||||
?.provider || '',
|
||||
},
|
||||
});
|
||||
}, [
|
||||
selectedNode,
|
||||
defaultAgentModelId,
|
||||
availableModels,
|
||||
handleUpdateNodeData,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedNode || selectedNode.type !== 'agent') return;
|
||||
const nodeId = selectedNode.id;
|
||||
const rawSchema = selectedNode.data.config?.json_schema;
|
||||
|
||||
setAgentJsonSchemaDrafts((prev) => {
|
||||
if (prev[nodeId] !== undefined) return prev;
|
||||
if (rawSchema === undefined || rawSchema === null) {
|
||||
return { ...prev, [nodeId]: '' };
|
||||
}
|
||||
|
||||
try {
|
||||
return { ...prev, [nodeId]: JSON.stringify(rawSchema, null, 2) };
|
||||
} catch {
|
||||
return { ...prev, [nodeId]: String(rawSchema) };
|
||||
}
|
||||
});
|
||||
|
||||
setAgentJsonSchemaErrors((prev) => {
|
||||
if (prev[nodeId] !== undefined) return prev;
|
||||
return { ...prev, [nodeId]: validateJsonSchemaConfig(rawSchema) };
|
||||
});
|
||||
}, [selectedNode]);
|
||||
|
||||
useEffect(() => {
|
||||
const loadAgentDetails = async () => {
|
||||
if (!agentId) return;
|
||||
@@ -655,6 +832,8 @@ function WorkflowBuilderInner() {
|
||||
);
|
||||
setWorkflowName(nextWorkflowName);
|
||||
setWorkflowDescription(nextWorkflowDescription);
|
||||
setAgentJsonSchemaDrafts({});
|
||||
setAgentJsonSchemaErrors({});
|
||||
setNodes(mappedNodes);
|
||||
setEdges(mappedEdges);
|
||||
setSavedWorkflowSignature(
|
||||
@@ -711,6 +890,33 @@ function WorkflowBuilderInner() {
|
||||
`Agent "${node.data?.title || node.id}" must have a model selected`,
|
||||
);
|
||||
}
|
||||
|
||||
const hasSchema =
|
||||
config?.json_schema !== undefined && config?.json_schema !== null;
|
||||
if (hasSchema && config?.model_id) {
|
||||
const selectedModel = availableModels.find(
|
||||
(model) => model.id === config.model_id,
|
||||
);
|
||||
if (selectedModel && !selectedModel.supports_structured_output) {
|
||||
errors.push(
|
||||
`Agent "${node.data?.title || node.id}" selected model does not support structured output`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const schemaValidationError = validateJsonSchemaConfig(
|
||||
config?.json_schema,
|
||||
);
|
||||
const draftSchemaError = agentJsonSchemaErrors[node.id];
|
||||
const effectiveSchemaError =
|
||||
draftSchemaError !== undefined
|
||||
? draftSchemaError
|
||||
: schemaValidationError;
|
||||
if (effectiveSchemaError) {
|
||||
errors.push(
|
||||
`Agent "${node.data?.title || node.id}" JSON schema ${effectiveSchemaError}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
if (startNodes.length === 1) {
|
||||
@@ -743,6 +949,7 @@ function WorkflowBuilderInner() {
|
||||
const conditionNodes = nodes.filter((n) => n.type === 'condition');
|
||||
conditionNodes.forEach((node) => {
|
||||
const conditionTitle = node.data?.title || node.id;
|
||||
const conditionMode = node.data?.config?.mode || 'simple';
|
||||
const cases = (node.data?.config?.cases || []) as ConditionCase[];
|
||||
if (
|
||||
!cases.length ||
|
||||
@@ -831,6 +1038,16 @@ function WorkflowBuilderInner() {
|
||||
`Condition "${conditionTitle}" case "${handle}" has a branch connection but no expression`,
|
||||
);
|
||||
}
|
||||
if (conditionMode === 'simple' && hasExpression) {
|
||||
const parsedCondition = parseSimpleCel(
|
||||
conditionCase.expression || '',
|
||||
);
|
||||
if (!parsedCondition.variable.trim()) {
|
||||
errors.push(
|
||||
`Condition "${conditionTitle}" case "${handle}" must specify a variable in Simple mode`,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
outgoing.forEach((edge) => {
|
||||
@@ -844,7 +1061,7 @@ function WorkflowBuilderInner() {
|
||||
});
|
||||
|
||||
return errors;
|
||||
}, [workflowName, nodes, edges]);
|
||||
}, [workflowName, nodes, edges, agentJsonSchemaErrors, availableModels]);
|
||||
|
||||
const canManageAgent = Boolean(currentAgentId || currentAgent.id);
|
||||
const effectiveAgentId = currentAgentId || currentAgent.id || '';
|
||||
@@ -1077,6 +1294,42 @@ function WorkflowBuilderInner() {
|
||||
],
|
||||
);
|
||||
|
||||
const selectedAgentJsonSchemaText = useMemo(() => {
|
||||
if (!selectedNode || selectedNode.type !== 'agent') return '';
|
||||
|
||||
const draft = agentJsonSchemaDrafts[selectedNode.id];
|
||||
if (draft !== undefined) return draft;
|
||||
|
||||
const schema = selectedNode.data.config?.json_schema;
|
||||
if (schema === undefined || schema === null) return '';
|
||||
|
||||
try {
|
||||
return JSON.stringify(schema, null, 2);
|
||||
} catch {
|
||||
return String(schema);
|
||||
}
|
||||
}, [selectedNode, agentJsonSchemaDrafts]);
|
||||
|
||||
const selectedAgentJsonSchemaError = useMemo(() => {
|
||||
if (!selectedNode || selectedNode.type !== 'agent') return null;
|
||||
|
||||
const cachedError = agentJsonSchemaErrors[selectedNode.id];
|
||||
if (cachedError !== undefined) return cachedError;
|
||||
|
||||
return validateJsonSchemaConfig(selectedNode.data.config?.json_schema);
|
||||
}, [selectedNode, agentJsonSchemaErrors]);
|
||||
|
||||
const selectedAgentModelSupportsStructuredOutput = useMemo(() => {
|
||||
if (!selectedNode || selectedNode.type !== 'agent') return true;
|
||||
const modelId = selectedNode.data.config?.model_id;
|
||||
if (!modelId) return true;
|
||||
|
||||
const selectedModel = availableModels.find((model) => model.id === modelId);
|
||||
if (!selectedModel) return true;
|
||||
|
||||
return selectedModel.supports_structured_output;
|
||||
}, [selectedNode, availableModels]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MobileBlocker />
|
||||
@@ -1577,7 +1830,7 @@ function WorkflowBuilderInner() {
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
selectedNodeId={selectedNode.id}
|
||||
placeholder="Use {{variable}} for dynamic content"
|
||||
placeholder="Use {{ agent.variable }} for dynamic content"
|
||||
/>
|
||||
<div>
|
||||
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
@@ -1589,14 +1842,15 @@ function WorkflowBuilderInner() {
|
||||
selectedNode.data.config
|
||||
?.output_variable || ''
|
||||
}
|
||||
onChange={(e) =>
|
||||
onChange={(e) => {
|
||||
const nextOutputVariable = e.target.value;
|
||||
handleUpdateNodeData({
|
||||
config: {
|
||||
...(selectedNode.data.config || {}),
|
||||
output_variable: e.target.value,
|
||||
output_variable: nextOutputVariable,
|
||||
},
|
||||
})
|
||||
}
|
||||
});
|
||||
}}
|
||||
className="border-light-silver focus:ring-purple-30 w-full rounded-xl border bg-white px-3 py-2 text-sm transition-all outline-none focus:ring-2 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:text-white"
|
||||
placeholder="Variable name for output"
|
||||
/>
|
||||
@@ -1651,6 +1905,48 @@ function WorkflowBuilderInner() {
|
||||
emptyText="No tools available"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="mb-2 block text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
Structured Output (JSON Schema)
|
||||
</label>
|
||||
{!selectedAgentModelSupportsStructuredOutput && (
|
||||
<p className="mb-2 text-xs text-red-600 dark:text-red-400">
|
||||
Selected model does not support structured
|
||||
output.
|
||||
</p>
|
||||
)}
|
||||
<textarea
|
||||
value={selectedAgentJsonSchemaText}
|
||||
onChange={(e) =>
|
||||
handleAgentJsonSchemaChange(
|
||||
e.target.value,
|
||||
)
|
||||
}
|
||||
className="border-light-silver focus:ring-purple-30 w-full rounded-xl border bg-white px-3 py-2 font-mono text-xs transition-all outline-none focus:ring-2 dark:border-[#3A3A3A] dark:bg-[#2C2C2C] dark:text-white"
|
||||
rows={8}
|
||||
placeholder={`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"summary": { "type": "string" }
|
||||
},
|
||||
"required": ["summary"]
|
||||
}`}
|
||||
/>
|
||||
{selectedAgentJsonSchemaText.trim() !==
|
||||
'' && (
|
||||
<p
|
||||
className={`mt-2 text-xs ${
|
||||
selectedAgentJsonSchemaError
|
||||
? 'text-red-600 dark:text-red-400'
|
||||
: 'text-green-600 dark:text-green-400'
|
||||
}`}
|
||||
>
|
||||
{selectedAgentJsonSchemaError
|
||||
? `Invalid JSON schema: ${selectedAgentJsonSchemaError}`
|
||||
: 'Valid JSON schema'}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
|
||||
@@ -87,7 +87,9 @@ function ExecutionDetails({
|
||||
|
||||
const formatValue = (value: unknown): string => {
|
||||
if (typeof value === 'string') return value;
|
||||
return JSON.stringify(value, null, 2);
|
||||
if (value === undefined) return '';
|
||||
const formatted = JSON.stringify(value, null, 2);
|
||||
return formatted ?? String(value);
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -136,6 +138,11 @@ function ExecutionDetails({
|
||||
if (text.length <= maxLength) return text;
|
||||
return text.slice(0, maxLength) + '...';
|
||||
};
|
||||
const hasOutput =
|
||||
step.output !== undefined &&
|
||||
step.output !== null &&
|
||||
formatValue(step.output) !== '';
|
||||
const formattedOutput = hasOutput ? formatValue(step.output) : '';
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -171,15 +178,15 @@ function ExecutionDetails({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{(step.output || step.error || stateVars.length > 0) && (
|
||||
{(hasOutput || step.error || stateVars.length > 0) && (
|
||||
<div className="mt-3 space-y-2 text-sm">
|
||||
{step.output && (
|
||||
{hasOutput && (
|
||||
<div className="rounded-lg bg-white p-2 dark:bg-[#2A2A2A]">
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Output:{' '}
|
||||
</span>
|
||||
<span className="wrap-break-word whitespace-pre-wrap text-gray-900 dark:text-gray-100">
|
||||
{truncateText(step.output, 300)}
|
||||
{truncateText(formattedOutput, 300)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
@@ -254,7 +261,7 @@ function WorkflowMiniMap({
|
||||
return step?.status || 'pending';
|
||||
};
|
||||
|
||||
const getStatusColor = (nodeId: string, nodeType: string) => {
|
||||
const getStatusColor = (nodeId: string) => {
|
||||
const status = getNodeStatus(nodeId);
|
||||
const isActive = nodeId === activeNodeId;
|
||||
|
||||
@@ -275,18 +282,28 @@ function WorkflowMiniMap({
|
||||
};
|
||||
|
||||
const executedOrder = new Map(executionSteps.map((s, i) => [s.nodeId, i]));
|
||||
const sortedNodes = [...nodes].sort((a, b) => {
|
||||
const aIdx = executedOrder.get(a.id);
|
||||
const bIdx = executedOrder.get(b.id);
|
||||
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
|
||||
if (aIdx !== undefined) return -1;
|
||||
if (bIdx !== undefined) return 1;
|
||||
if (a.type === 'start') return -1;
|
||||
if (b.type === 'start') return 1;
|
||||
if (a.type === 'end') return 1;
|
||||
if (b.type === 'end') return -1;
|
||||
return (a.position?.y || 0) - (b.position?.y || 0);
|
||||
});
|
||||
const startNode = nodes.find((node) => node.type === 'start');
|
||||
const visibleNodeIds = new Set(executionSteps.map((step) => step.nodeId));
|
||||
if (activeNodeId) {
|
||||
visibleNodeIds.add(activeNodeId);
|
||||
}
|
||||
if (startNode) {
|
||||
visibleNodeIds.add(startNode.id);
|
||||
}
|
||||
|
||||
const sortedNodes = nodes
|
||||
.filter((node) => visibleNodeIds.has(node.id))
|
||||
.sort((a, b) => {
|
||||
if (a.type === 'start') return -1;
|
||||
if (b.type === 'start') return 1;
|
||||
|
||||
const aIdx = executedOrder.get(a.id);
|
||||
const bIdx = executedOrder.get(b.id);
|
||||
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
|
||||
if (aIdx !== undefined) return -1;
|
||||
if (bIdx !== undefined) return 1;
|
||||
return (a.position?.y || 0) - (b.position?.y || 0);
|
||||
});
|
||||
|
||||
const hasStepData = (nodeId: string) => {
|
||||
const step = executionSteps.find((s) => s.nodeId === nodeId);
|
||||
@@ -306,7 +323,7 @@ function WorkflowMiniMap({
|
||||
disabled={!hasStepData(node.id)}
|
||||
className={cn(
|
||||
'flex h-12 w-full items-center gap-2 rounded-lg border px-3 text-xs transition-all',
|
||||
getStatusColor(node.id, node.type),
|
||||
getStatusColor(node.id),
|
||||
hasStepData(node.id) && 'cursor-pointer hover:opacity-80',
|
||||
)}
|
||||
>
|
||||
@@ -533,6 +550,10 @@ export default function WorkflowPreview({
|
||||
const querySteps = query.executionSteps || [];
|
||||
const hasResponse = !!(query.response || query.error);
|
||||
const isLastQuery = index === queries.length - 1;
|
||||
const isStreamingLastQuery =
|
||||
status === 'loading' && isLastQuery;
|
||||
const shouldShowThought =
|
||||
!isStreamingLastQuery && Boolean(query.thought);
|
||||
const isOpen =
|
||||
openDetailsIndex === index ||
|
||||
(!hasResponse && isLastQuery && querySteps.length > 0);
|
||||
@@ -567,17 +588,19 @@ export default function WorkflowPreview({
|
||||
|
||||
{/* Response bubble */}
|
||||
{(query.response ||
|
||||
query.thought ||
|
||||
shouldShowThought ||
|
||||
query.tool_calls) && (
|
||||
<ConversationBubble
|
||||
className={isLastQuery ? 'mb-32' : 'mb-7'}
|
||||
message={query.response}
|
||||
type="ANSWER"
|
||||
thought={query.thought}
|
||||
thought={
|
||||
shouldShowThought ? query.thought : undefined
|
||||
}
|
||||
sources={query.sources}
|
||||
toolCalls={query.tool_calls}
|
||||
feedback={query.feedback}
|
||||
isStreaming={status === 'loading' && isLastQuery}
|
||||
isStreaming={isStreamingLastQuery}
|
||||
/>
|
||||
)}
|
||||
|
||||
|
||||
@@ -9,10 +9,71 @@ import {
|
||||
} from '@/components/ui/popover';
|
||||
|
||||
interface WorkflowVariable {
|
||||
name: string;
|
||||
label: string;
|
||||
templatePath: string;
|
||||
section: string;
|
||||
}
|
||||
|
||||
const GLOBAL_CONTEXT_VARIABLES: WorkflowVariable[] = [
|
||||
{
|
||||
label: 'source.content',
|
||||
templatePath: 'source.content',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'source.summaries',
|
||||
templatePath: 'source.summaries',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'source.documents',
|
||||
templatePath: 'source.documents',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'source.count',
|
||||
templatePath: 'source.count',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.date',
|
||||
templatePath: 'system.date',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.time',
|
||||
templatePath: 'system.time',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.timestamp',
|
||||
templatePath: 'system.timestamp',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.request_id',
|
||||
templatePath: 'system.request_id',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.user_id',
|
||||
templatePath: 'system.user_id',
|
||||
section: 'Global context',
|
||||
},
|
||||
];
|
||||
|
||||
function toAgentTemplatePath(variableName: string): string {
|
||||
const trimmed = variableName.trim();
|
||||
if (!trimmed) return 'agent';
|
||||
|
||||
if (/^[A-Za-z_][A-Za-z0-9_]*$/.test(trimmed)) {
|
||||
return `agent.${trimmed}`;
|
||||
}
|
||||
|
||||
const escaped = trimmed.replace(/\\/g, '\\\\').replace(/'/g, "\\'");
|
||||
return `agent['${escaped}']`;
|
||||
}
|
||||
|
||||
function getUpstreamNodeIds(nodeId: string, edges: Edge[]): Set<string> {
|
||||
const upstream = new Set<string>();
|
||||
const queue = [nodeId];
|
||||
@@ -36,32 +97,69 @@ function extractUpstreamVariables(
|
||||
selectedNodeId: string,
|
||||
): WorkflowVariable[] {
|
||||
const variables: WorkflowVariable[] = [
|
||||
{ name: 'query', section: 'Workflow input' },
|
||||
{ name: 'chat_history', section: 'Workflow input' },
|
||||
{
|
||||
label: 'agent.query',
|
||||
templatePath: 'agent.query',
|
||||
section: 'Workflow input',
|
||||
},
|
||||
{
|
||||
label: 'agent.chat_history',
|
||||
templatePath: 'agent.chat_history',
|
||||
section: 'Workflow input',
|
||||
},
|
||||
...GLOBAL_CONTEXT_VARIABLES,
|
||||
];
|
||||
const seen = new Set(['query', 'chat_history']);
|
||||
const seen = new Set(variables.map((variable) => variable.templatePath));
|
||||
const upstreamIds = getUpstreamNodeIds(selectedNodeId, edges);
|
||||
|
||||
for (const node of nodes) {
|
||||
if (!upstreamIds.has(node.id)) continue;
|
||||
|
||||
if (node.type === 'agent' && node.data?.config?.output_variable) {
|
||||
const name = node.data.config.output_variable;
|
||||
if (!seen.has(name)) {
|
||||
seen.add(name);
|
||||
if (node.type === 'agent') {
|
||||
const defaultOutputTemplatePath = toAgentTemplatePath(
|
||||
`node_${node.id}_output`,
|
||||
);
|
||||
if (!seen.has(defaultOutputTemplatePath)) {
|
||||
seen.add(defaultOutputTemplatePath);
|
||||
variables.push({
|
||||
name,
|
||||
label: defaultOutputTemplatePath,
|
||||
templatePath: defaultOutputTemplatePath,
|
||||
section: node.data.title || node.data.label || 'Agent',
|
||||
});
|
||||
}
|
||||
|
||||
const outputVariable = String(
|
||||
node.data?.config?.output_variable || '',
|
||||
).trim();
|
||||
if (outputVariable) {
|
||||
const templatePath = toAgentTemplatePath(outputVariable);
|
||||
if (!seen.has(templatePath)) {
|
||||
seen.add(templatePath);
|
||||
variables.push({
|
||||
label: templatePath,
|
||||
templatePath,
|
||||
section: node.data.title || node.data.label || 'Agent',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
if (node.type === 'state' && node.data?.variable) {
|
||||
const name = node.data.variable;
|
||||
if (!seen.has(name)) {
|
||||
seen.add(name);
|
||||
|
||||
if (node.type === 'state') {
|
||||
const operations = node.data?.config?.operations;
|
||||
if (!Array.isArray(operations)) continue;
|
||||
|
||||
for (const operation of operations) {
|
||||
const targetVariable = String(operation?.target_variable || '').trim();
|
||||
if (!targetVariable) continue;
|
||||
|
||||
const templatePath = toAgentTemplatePath(targetVariable);
|
||||
if (seen.has(templatePath)) continue;
|
||||
|
||||
seen.add(templatePath);
|
||||
variables.push({
|
||||
name,
|
||||
section: 'Set State',
|
||||
label: templatePath,
|
||||
templatePath,
|
||||
section: node.data.title || node.data.label || 'Set State',
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -106,14 +204,16 @@ function VariableListWithSearch({
|
||||
onSelect,
|
||||
}: {
|
||||
variables: WorkflowVariable[];
|
||||
onSelect: (name: string) => void;
|
||||
onSelect: (templatePath: string) => void;
|
||||
}) {
|
||||
const [search, setSearch] = useState('');
|
||||
|
||||
const filtered = useMemo(
|
||||
() =>
|
||||
variables.filter((v) =>
|
||||
v.name.toLowerCase().includes(search.toLowerCase()),
|
||||
`${v.label} ${v.templatePath}`
|
||||
.toLowerCase()
|
||||
.includes(search.toLowerCase()),
|
||||
),
|
||||
[variables, search],
|
||||
);
|
||||
@@ -146,17 +246,17 @@ function VariableListWithSearch({
|
||||
</div>
|
||||
{vars.map((v) => (
|
||||
<button
|
||||
key={v.name}
|
||||
key={`${section}-${v.templatePath}`}
|
||||
onMouseDown={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
onSelect(v.name);
|
||||
onSelect(v.templatePath);
|
||||
}}
|
||||
className="flex w-full cursor-pointer items-center gap-2 px-3 py-1.5 text-left text-sm transition-colors hover:bg-gray-50 dark:hover:bg-[#383838]"
|
||||
>
|
||||
<Braces className="text-violets-are-blue h-3.5 w-3.5 shrink-0" />
|
||||
<span className="truncate font-medium text-gray-800 dark:text-gray-200">
|
||||
{v.name}
|
||||
{v.label}
|
||||
</span>
|
||||
</button>
|
||||
))}
|
||||
@@ -206,7 +306,9 @@ export default function PromptTextArea({
|
||||
const filtered = useMemo(
|
||||
() =>
|
||||
variables.filter((v) =>
|
||||
v.name.toLowerCase().includes(filterText.toLowerCase()),
|
||||
`${v.label} ${v.templatePath}`
|
||||
.toLowerCase()
|
||||
.includes(filterText.toLowerCase()),
|
||||
),
|
||||
[variables, filterText],
|
||||
);
|
||||
@@ -217,10 +319,12 @@ export default function PromptTextArea({
|
||||
|
||||
const cursorPos = textarea.selectionStart;
|
||||
const textBeforeCursor = value.slice(0, cursorPos);
|
||||
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
|
||||
const triggerMatch = textBeforeCursor.match(
|
||||
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
|
||||
);
|
||||
|
||||
if (triggerMatch) {
|
||||
setFilterText(triggerMatch[1]);
|
||||
setFilterText(triggerMatch[1].trim());
|
||||
setCursorInsertPos(cursorPos);
|
||||
|
||||
const wrapper = wrapperRef.current;
|
||||
@@ -237,15 +341,17 @@ export default function PromptTextArea({
|
||||
}, [value]);
|
||||
|
||||
const insertVariable = useCallback(
|
||||
(varName: string) => {
|
||||
(templatePath: string) => {
|
||||
if (cursorInsertPos === null) return;
|
||||
|
||||
const textBeforeCursor = value.slice(0, cursorInsertPos);
|
||||
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
|
||||
const triggerMatch = textBeforeCursor.match(
|
||||
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
|
||||
);
|
||||
if (!triggerMatch) return;
|
||||
|
||||
const startPos = cursorInsertPos - triggerMatch[0].length;
|
||||
const insertion = `{{${varName}}}`;
|
||||
const insertion = `{{ ${templatePath} }}`;
|
||||
const newValue =
|
||||
value.slice(0, startPos) + insertion + value.slice(cursorInsertPos);
|
||||
|
||||
@@ -262,10 +368,10 @@ export default function PromptTextArea({
|
||||
);
|
||||
|
||||
const insertVariableFromButton = useCallback(
|
||||
(varName: string) => {
|
||||
(templatePath: string) => {
|
||||
const textarea = textareaRef.current;
|
||||
const cursorPos = textarea?.selectionStart ?? value.length;
|
||||
const insertion = `{{${varName}}}`;
|
||||
const insertion = `{{ ${templatePath} }}`;
|
||||
const newValue =
|
||||
value.slice(0, cursorPos) + insertion + value.slice(cursorPos);
|
||||
|
||||
|
||||
@@ -85,10 +85,10 @@ export const AgentNode = memo(function AgentNode({
|
||||
)}
|
||||
{config.output_variable && (
|
||||
<div
|
||||
className="truncate text-xs text-green-600 dark:text-green-400"
|
||||
title={`Output ➔ ${config.output_variable}`}
|
||||
className="truncate text-xs text-gray-500 dark:text-gray-400"
|
||||
title={`Output: ${config.output_variable}`}
|
||||
>
|
||||
Output ➔ {config.output_variable}
|
||||
Output: {config.output_variable}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -13,7 +13,7 @@ export interface WorkflowExecutionStep {
|
||||
startedAt?: number;
|
||||
completedAt?: number;
|
||||
stateSnapshot?: Record<string, unknown>;
|
||||
output?: string;
|
||||
output?: unknown;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
@@ -321,7 +321,9 @@ export const workflowPreviewSlice = createSlice({
|
||||
}
|
||||
|
||||
const querySteps = state.queries[index].executionSteps!;
|
||||
const existingIndex = querySteps.findIndex((s) => s.nodeId === step.nodeId);
|
||||
const existingIndex = querySteps.findIndex(
|
||||
(s) => s.nodeId === step.nodeId,
|
||||
);
|
||||
|
||||
const updatedStep: WorkflowExecutionStep = {
|
||||
nodeId: step.nodeId,
|
||||
@@ -332,7 +334,10 @@ export const workflowPreviewSlice = createSlice({
|
||||
stateSnapshot: step.stateSnapshot,
|
||||
output: step.output,
|
||||
error: step.error,
|
||||
startedAt: existingIndex !== -1 ? querySteps[existingIndex].startedAt : Date.now(),
|
||||
startedAt:
|
||||
existingIndex !== -1
|
||||
? querySteps[existingIndex].startedAt
|
||||
: Date.now(),
|
||||
completedAt:
|
||||
step.status === 'completed' || step.status === 'failed'
|
||||
? Date.now()
|
||||
@@ -342,7 +347,8 @@ export const workflowPreviewSlice = createSlice({
|
||||
};
|
||||
|
||||
if (existingIndex !== -1) {
|
||||
updatedStep.stateSnapshot = step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
|
||||
updatedStep.stateSnapshot =
|
||||
step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
|
||||
updatedStep.output = step.output ?? querySteps[existingIndex].output;
|
||||
updatedStep.error = step.error ?? querySteps[existingIndex].error;
|
||||
querySteps[existingIndex] = updatedStep;
|
||||
@@ -350,7 +356,9 @@ export const workflowPreviewSlice = createSlice({
|
||||
querySteps.push(updatedStep);
|
||||
}
|
||||
|
||||
const globalIndex = state.executionSteps.findIndex((s) => s.nodeId === step.nodeId);
|
||||
const globalIndex = state.executionSteps.findIndex(
|
||||
(s) => s.nodeId === step.nodeId,
|
||||
);
|
||||
if (globalIndex !== -1) {
|
||||
state.executionSteps[globalIndex] = updatedStep;
|
||||
} else {
|
||||
|
||||
332
tests/agents/test_workflow_engine.py
Normal file
332
tests/agents/test_workflow_engine.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.user.workflows import routes as workflow_routes
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
NodeType,
|
||||
Workflow,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.api.user.workflows.routes import validate_workflow_structure
|
||||
|
||||
|
||||
class StubNodeAgent:
|
||||
def __init__(self, events):
|
||||
self.events = events
|
||||
|
||||
def gen(self, _prompt):
|
||||
yield from self.events
|
||||
|
||||
|
||||
def create_engine() -> WorkflowEngine:
|
||||
graph = WorkflowGraph(workflow=Workflow(name="Engine Test"), nodes=[], edges=[])
|
||||
agent = SimpleNamespace(
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
api_key="test-key",
|
||||
chat_history=[],
|
||||
decoded_token={"sub": "user-1"},
|
||||
)
|
||||
return WorkflowEngine(graph, agent)
|
||||
|
||||
|
||||
def create_agent_node(
|
||||
node_id: str,
|
||||
output_variable: str = "",
|
||||
json_schema: Optional[Dict[str, Any]] = None,
|
||||
) -> WorkflowNode:
|
||||
config = {
|
||||
"agent_type": "classic",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
"tools": [],
|
||||
}
|
||||
if output_variable:
|
||||
config["output_variable"] = output_variable
|
||||
if json_schema is not None:
|
||||
config["json_schema"] = json_schema
|
||||
|
||||
return WorkflowNode(
|
||||
id=node_id,
|
||||
workflow_id="workflow-1",
|
||||
type=NodeType.AGENT,
|
||||
title="Agent",
|
||||
position={"x": 0, "y": 0},
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def test_execute_agent_node_saves_structured_output_as_json(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_1",
|
||||
output_variable="result",
|
||||
json_schema={"type": "object"},
|
||||
)
|
||||
node_events = [
|
||||
{"answer": '{"summary":"ok",', "structured": True},
|
||||
{"answer": '"score":2}', "structured": True},
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
expected_output = {"summary": "ok", "score": 2}
|
||||
assert engine.state["node_agent_1_output"] == expected_output
|
||||
assert engine.state["result"] == expected_output
|
||||
|
||||
|
||||
def test_execute_agent_node_normalizes_wrapped_schema_before_agent_create(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_wrapped",
|
||||
json_schema={"schema": {"type": "object"}},
|
||||
)
|
||||
node_events = [{"answer": '{"summary":"ok"}', "structured": True}]
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
def create_node_agent(**kwargs):
|
||||
captured["json_schema"] = kwargs.get("json_schema")
|
||||
return StubNodeAgent(node_events)
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(create_node_agent),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
assert captured["json_schema"] == {"type": "object"}
|
||||
assert engine.state["node_agent_wrapped_output"] == {"summary": "ok"}
|
||||
|
||||
|
||||
def test_execute_agent_node_falls_back_to_text_when_schema_not_configured(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(node_id="agent_2", output_variable="result")
|
||||
node_events = [{"answer": "plain text answer"}]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
assert engine.state["node_agent_2_output"] == "plain text answer"
|
||||
assert engine.state["result"] == "plain text answer"
|
||||
|
||||
|
||||
def test_validate_workflow_structure_rejects_invalid_agent_json_schema():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {"json_schema": "invalid"},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert any(
|
||||
"Agent node 'Agent' JSON schema must be a valid JSON object" in err
|
||||
for err in errors
|
||||
)
|
||||
|
||||
|
||||
def test_validate_workflow_structure_accepts_valid_agent_json_schema():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {"json_schema": {"type": "object"}},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_workflow_structure_accepts_wrapped_agent_json_schema():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {"json_schema": {"schema": {"type": "object"}}},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_workflow_structure_accepts_output_variable_and_schema_together():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {
|
||||
"output_variable": "answer",
|
||||
"json_schema": {"type": "object"},
|
||||
},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_workflow_structure_rejects_unsupported_structured_output_model(
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
workflow_routes,
|
||||
"get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": False},
|
||||
)
|
||||
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {
|
||||
"model_id": "some-model",
|
||||
"json_schema": {"type": "object"},
|
||||
},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert any(
|
||||
"Agent node 'Agent' selected model does not support structured output"
|
||||
in err
|
||||
for err in errors
|
||||
)
|
||||
|
||||
|
||||
def test_execute_agent_node_raises_when_structured_output_violates_schema(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_3",
|
||||
json_schema={
|
||||
"type": "object",
|
||||
"properties": {"summary": {"type": "string"}},
|
||||
"required": ["summary"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
node_events = [{"answer": '{"score":2}', "structured": True}]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Structured output did not match schema"):
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
|
||||
def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_4",
|
||||
json_schema={"type": "object"},
|
||||
)
|
||||
node_events = [{"answer": "not-json"}]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Structured output was expected but response was not valid JSON",
|
||||
):
|
||||
list(engine._execute_agent_node(node))
|
||||
63
tests/agents/test_workflow_template.py
Normal file
63
tests/agents/test_workflow_template.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from application.agents.workflows.schemas import Workflow, WorkflowGraph
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
|
||||
|
||||
def create_engine() -> WorkflowEngine:
|
||||
graph = WorkflowGraph(workflow=Workflow(name="Template Test"), nodes=[], edges=[])
|
||||
agent = SimpleNamespace(
|
||||
user="user-1",
|
||||
request_id="req-1",
|
||||
retrieved_docs=[
|
||||
{"title": "Doc A", "text": "Summary A"},
|
||||
{"title": "Doc B", "text": "Summary B"},
|
||||
],
|
||||
)
|
||||
return WorkflowEngine(graph, agent)
|
||||
|
||||
|
||||
def test_workflow_template_supports_agent_namespace_and_legacy_variables():
|
||||
engine = create_engine()
|
||||
engine.state = {"query": "Hello", "chat_history": "[]", "ticket_id": 42}
|
||||
|
||||
rendered = engine._format_template(
|
||||
"{{ agent.query }}|{{ agent.ticket_id }}|{{ query }}|{{ ticket_id }}"
|
||||
)
|
||||
|
||||
assert rendered == "Hello|42|Hello|42"
|
||||
|
||||
|
||||
def test_workflow_template_supports_global_namespaces():
|
||||
engine = create_engine()
|
||||
engine.state = {"query": "Hello"}
|
||||
|
||||
rendered = engine._format_template(
|
||||
"{{ source.count }}|{{ source.summaries }}|{{ system.request_id }}"
|
||||
)
|
||||
|
||||
assert rendered.startswith("2|")
|
||||
assert "Doc A" in rendered
|
||||
assert "Summary A" in rendered
|
||||
assert rendered.endswith("|req-1")
|
||||
|
||||
|
||||
def test_workflow_template_handles_namespace_conflicts_with_agent_prefix():
|
||||
engine = create_engine()
|
||||
engine.state = {"source": "user-defined-source"}
|
||||
|
||||
rendered = engine._format_template(
|
||||
"{{ agent.source }}|{{ agent_source }}|{{ source.count }}"
|
||||
)
|
||||
|
||||
assert rendered.startswith("user-defined-source|user-defined-source|")
|
||||
|
||||
|
||||
def test_workflow_template_gracefully_handles_invalid_template_syntax():
|
||||
engine = create_engine()
|
||||
engine.state = {"query": "Hello"}
|
||||
|
||||
invalid_template = "{{ agent.query "
|
||||
rendered = engine._format_template(invalid_template)
|
||||
|
||||
assert rendered == invalid_template
|
||||
Reference in New Issue
Block a user