mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 22:44:10 +00:00
Compare commits
10 Commits
sharepoint
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c7a6a78aa | ||
|
|
a6625ec5de | ||
|
|
1a2104f474 | ||
|
|
444abb8283 | ||
|
|
ee86537f21 | ||
|
|
17a736a927 | ||
|
|
6b5779054d | ||
|
|
14296632ef | ||
|
|
2a3f0e455a | ||
|
|
8aa44c415b |
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
|
||||
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
|
||||
if: matrix.platform == 'linux/arm64'
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
2
.github/workflows/cife.yml
vendored
2
.github/workflows/cife.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
|
||||
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
|
||||
if: matrix.platform == 'linux/arm64'
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
|
||||
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
|
||||
if: matrix.platform == 'linux/arm64'
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -71,6 +71,7 @@ instance/
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/public/_pagefind/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
@@ -7,6 +7,10 @@ from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
@@ -23,6 +27,7 @@ class BaseAgent(ABC):
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
api_key: str,
|
||||
agent_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
@@ -40,6 +45,7 @@ class BaseAgent(ABC):
|
||||
self.llm_name = llm_name
|
||||
self.model_id = model_id
|
||||
self.api_key = api_key
|
||||
self.agent_id = agent_id
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
@@ -54,13 +60,19 @@ class BaseAgent(ABC):
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = json_schema
|
||||
self.json_schema = None
|
||||
if json_schema is not None:
|
||||
try:
|
||||
self.json_schema = normalize_json_schema_payload(json_schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
|
||||
self.limited_token_mode = limited_token_mode
|
||||
self.token_limit = token_limit
|
||||
self.limited_request_mode = limited_request_mode
|
||||
@@ -263,6 +275,11 @@ class BaseAgent(ABC):
|
||||
tool_config=tool_config,
|
||||
user_id=self.user,
|
||||
)
|
||||
resolved_arguments = (
|
||||
{"query_params": query_params, "headers": headers, "body": body}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else parameters
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
logger.debug(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
@@ -292,11 +309,19 @@ class BaseAgent(ABC):
|
||||
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||
if artifact_id:
|
||||
tool_call_data["artifact_id"] = artifact_id
|
||||
result_full = str(result)
|
||||
tool_call_data["resolved_arguments"] = resolved_arguments
|
||||
tool_call_data["result_full"] = result_full
|
||||
tool_call_data["result"] = (
|
||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
||||
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||
)
|
||||
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
||||
stream_tool_call_data = {
|
||||
key: value
|
||||
for key, value in tool_call_data.items()
|
||||
if key not in {"result_full", "resolved_arguments"}
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
@@ -304,7 +329,11 @@ class BaseAgent(ABC):
|
||||
def _get_truncated_tool_calls(self):
|
||||
return [
|
||||
{
|
||||
**tool_call,
|
||||
"tool_name": tool_call.get("tool_name"),
|
||||
"call_id": tool_call.get("call_id"),
|
||||
"action_name": tool_call.get("action_name"),
|
||||
"arguments": tool_call.get("arguments"),
|
||||
"artifact_id": tool_call.get("artifact_id"),
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
@@ -576,6 +605,9 @@ class BaseAgent(ABC):
|
||||
self._validate_context_size(messages)
|
||||
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
if self.attachments:
|
||||
# Usage accounting only; stripped before provider invocation.
|
||||
gen_kwargs["_usage_attachments"] = self.attachments
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
|
||||
@@ -211,8 +211,21 @@ class WorkflowAgent(BaseAgent):
|
||||
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
serialized: Dict[str, Any] = {}
|
||||
for key, value in state.items():
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
serialized[key] = value
|
||||
else:
|
||||
serialized[key] = str(value)
|
||||
serialized[key] = self._serialize_state_value(value)
|
||||
return serialized
|
||||
|
||||
def _serialize_state_value(self, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
str(dict_key): self._serialize_state_value(dict_value)
|
||||
for dict_key, dict_value in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
64
application/agents/workflows/cel_evaluator.py
Normal file
64
application/agents/workflows/cel_evaluator.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import celpy
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class CelEvaluationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _convert_value(value: Any) -> Any:
|
||||
if isinstance(value, bool):
|
||||
return celpy.celtypes.BoolType(value)
|
||||
if isinstance(value, int):
|
||||
return celpy.celtypes.IntType(value)
|
||||
if isinstance(value, float):
|
||||
return celpy.celtypes.DoubleType(value)
|
||||
if isinstance(value, str):
|
||||
return celpy.celtypes.StringType(value)
|
||||
if isinstance(value, list):
|
||||
return celpy.celtypes.ListType([_convert_value(item) for item in value])
|
||||
if isinstance(value, dict):
|
||||
return celpy.celtypes.MapType(
|
||||
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
|
||||
)
|
||||
if value is None:
|
||||
return celpy.celtypes.BoolType(False)
|
||||
return celpy.celtypes.StringType(str(value))
|
||||
|
||||
|
||||
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {k: _convert_value(v) for k, v in state.items()}
|
||||
|
||||
|
||||
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
|
||||
if not expression or not expression.strip():
|
||||
raise CelEvaluationError("Empty expression")
|
||||
try:
|
||||
env = celpy.Environment()
|
||||
ast = env.compile(expression)
|
||||
program = env.program(ast)
|
||||
activation = build_activation(state)
|
||||
result = program.evaluate(activation)
|
||||
except celpy.CELEvalError as exc:
|
||||
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise CelEvaluationError(f"CEL error: {exc}") from exc
|
||||
return cel_to_python(result)
|
||||
|
||||
|
||||
def cel_to_python(value: Any) -> Any:
|
||||
if isinstance(value, celpy.celtypes.BoolType):
|
||||
return bool(value)
|
||||
if isinstance(value, celpy.celtypes.IntType):
|
||||
return int(value)
|
||||
if isinstance(value, celpy.celtypes.DoubleType):
|
||||
return float(value)
|
||||
if isinstance(value, celpy.celtypes.StringType):
|
||||
return str(value)
|
||||
if isinstance(value, celpy.celtypes.ListType):
|
||||
return [cel_to_python(item) for item in value]
|
||||
if isinstance(value, celpy.celtypes.MapType):
|
||||
return {str(k): cel_to_python(v) for k, v in value.items()}
|
||||
return value
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
@@ -12,6 +12,7 @@ class NodeType(str, Enum):
|
||||
AGENT = "agent"
|
||||
NOTE = "note"
|
||||
STATE = "state"
|
||||
CONDITION = "condition"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
@@ -48,6 +49,25 @@ class AgentNodeConfig(BaseModel):
|
||||
json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConditionCase(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
name: Optional[str] = None
|
||||
expression: str = ""
|
||||
source_handle: str = Field(..., alias="sourceHandle")
|
||||
|
||||
|
||||
class ConditionNodeConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
mode: Literal["simple", "advanced"] = "simple"
|
||||
cases: List[ConditionCase] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StateOperation(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
expression: str = ""
|
||||
target_variable: str = ""
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
id: str
|
||||
|
||||
@@ -1,16 +1,30 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
try:
|
||||
import jsonschema
|
||||
except ImportError: # pragma: no cover - optional dependency in some deployments.
|
||||
jsonschema = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.agents.base import BaseAgent
|
||||
@@ -18,6 +32,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
StateValue = Any
|
||||
WorkflowState = Dict[str, StateValue]
|
||||
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
@@ -28,6 +43,9 @@ class WorkflowEngine:
|
||||
self.agent = agent
|
||||
self.state: WorkflowState = {}
|
||||
self.execution_log: List[Dict[str, Any]] = []
|
||||
self._condition_result: Optional[str] = None
|
||||
self._template_engine = TemplateEngine()
|
||||
self._namespace_manager = NamespaceManager()
|
||||
|
||||
def execute(
|
||||
self, initial_inputs: WorkflowState, query: str
|
||||
@@ -98,6 +116,10 @@ class WorkflowEngine:
|
||||
if node.type == NodeType.END:
|
||||
break
|
||||
current_node_id = self._get_next_node_id(current_node_id)
|
||||
if current_node_id is None and node.type != NodeType.END:
|
||||
logger.warning(
|
||||
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
|
||||
)
|
||||
steps += 1
|
||||
if steps >= self.MAX_EXECUTION_STEPS:
|
||||
logger.warning(
|
||||
@@ -121,10 +143,20 @@ class WorkflowEngine:
|
||||
}
|
||||
|
||||
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
|
||||
node = self.graph.get_node_by_id(current_node_id)
|
||||
edges = self.graph.get_outgoing_edges(current_node_id)
|
||||
if edges:
|
||||
return edges[0].target_id
|
||||
return None
|
||||
if not edges:
|
||||
return None
|
||||
|
||||
if node and node.type == NodeType.CONDITION and self._condition_result:
|
||||
target_handle = self._condition_result
|
||||
self._condition_result = None
|
||||
for edge in edges:
|
||||
if edge.source_handle == target_handle:
|
||||
return edge.target_id
|
||||
return None
|
||||
|
||||
return edges[0].target_id
|
||||
|
||||
def _execute_node(
|
||||
self, node: WorkflowNode
|
||||
@@ -136,6 +168,7 @@ class WorkflowEngine:
|
||||
NodeType.NOTE: self._execute_note_node,
|
||||
NodeType.AGENT: self._execute_agent_node,
|
||||
NodeType.STATE: self._execute_state_node,
|
||||
NodeType.CONDITION: self._execute_condition_node,
|
||||
NodeType.END: self._execute_end_node,
|
||||
}
|
||||
|
||||
@@ -156,35 +189,62 @@ class WorkflowEngine:
|
||||
def _execute_agent_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
from application.core.model_utils import get_api_key_for_provider
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_model_capabilities,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
node_config = AgentNodeConfig(**node.config)
|
||||
node_config = AgentNodeConfig(**node.config.get("config", node.config))
|
||||
|
||||
if node_config.prompt_template:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
else:
|
||||
formatted_prompt = self.state.get("query", "")
|
||||
node_llm_name = node_config.llm_name or self.agent.llm_name
|
||||
node_json_schema = self._normalize_node_json_schema(
|
||||
node_config.json_schema, node.title
|
||||
)
|
||||
node_model_id = node_config.model_id or self.agent.model_id
|
||||
node_llm_name = (
|
||||
node_config.llm_name
|
||||
or get_provider_from_model_id(node_model_id or "")
|
||||
or self.agent.llm_name
|
||||
)
|
||||
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
|
||||
|
||||
if node_json_schema and node_model_id:
|
||||
model_capabilities = get_model_capabilities(node_model_id)
|
||||
if model_capabilities and not model_capabilities.get(
|
||||
"supports_structured_output", False
|
||||
):
|
||||
raise ValueError(
|
||||
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
|
||||
)
|
||||
|
||||
node_agent = WorkflowNodeAgentFactory.create(
|
||||
agent_type=node_config.agent_type,
|
||||
endpoint=self.agent.endpoint,
|
||||
llm_name=node_llm_name,
|
||||
model_id=node_config.model_id or self.agent.model_id,
|
||||
model_id=node_model_id,
|
||||
api_key=node_api_key,
|
||||
tool_ids=node_config.tools,
|
||||
prompt=node_config.system_prompt,
|
||||
chat_history=self.agent.chat_history,
|
||||
decoded_token=self.agent.decoded_token,
|
||||
json_schema=node_config.json_schema,
|
||||
json_schema=node_json_schema,
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
full_response_parts: List[str] = []
|
||||
structured_response_parts: List[str] = []
|
||||
has_structured_response = False
|
||||
first_chunk = True
|
||||
for event in node_agent.gen(formatted_prompt):
|
||||
if "answer" in event:
|
||||
full_response += event["answer"]
|
||||
chunk = str(event["answer"])
|
||||
full_response_parts.append(chunk)
|
||||
if event.get("structured"):
|
||||
has_structured_response = True
|
||||
structured_response_parts.append(chunk)
|
||||
if node_config.stream_to_user:
|
||||
if first_chunk and hasattr(self, "_has_streamed"):
|
||||
yield {"answer": "\n\n"}
|
||||
@@ -194,72 +254,189 @@ class WorkflowEngine:
|
||||
if node_config.stream_to_user:
|
||||
self._has_streamed = True
|
||||
|
||||
output_key = node_config.output_variable or f"node_{node.id}_output"
|
||||
self.state[output_key] = full_response
|
||||
full_response = "".join(full_response_parts).strip()
|
||||
output_value: Any = full_response
|
||||
if has_structured_response:
|
||||
structured_response = "".join(structured_response_parts).strip()
|
||||
response_to_parse = structured_response or full_response
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
response_to_parse
|
||||
)
|
||||
output_value = parsed_structured if parsed_success else response_to_parse
|
||||
if node_json_schema:
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
elif node_json_schema:
|
||||
parsed_success, parsed_structured = self._parse_structured_output(
|
||||
full_response
|
||||
)
|
||||
if not parsed_success:
|
||||
raise ValueError(
|
||||
"Structured output was expected but response was not valid JSON"
|
||||
)
|
||||
output_value = parsed_structured
|
||||
self._validate_structured_output(node_json_schema, output_value)
|
||||
|
||||
default_output_key = f"node_{node.id}_output"
|
||||
self.state[default_output_key] = output_value
|
||||
|
||||
if node_config.output_variable:
|
||||
self.state[node_config.output_variable] = output_value
|
||||
|
||||
def _execute_state_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
operations = config.get("operations", [])
|
||||
config = node.config.get("config", node.config)
|
||||
for op in config.get("operations", []):
|
||||
expression = op.get("expression", "")
|
||||
target_variable = op.get("target_variable", "")
|
||||
if expression and target_variable:
|
||||
self.state[target_variable] = evaluate_cel(expression, self.state)
|
||||
yield from ()
|
||||
|
||||
if operations:
|
||||
for op in operations:
|
||||
key = op.get("key")
|
||||
operation = op.get("operation", "set")
|
||||
value = op.get("value")
|
||||
def _execute_condition_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = ConditionNodeConfig(**node.config.get("config", node.config))
|
||||
matched_handle = None
|
||||
|
||||
if not key:
|
||||
continue
|
||||
if operation == "set":
|
||||
formatted_value = (
|
||||
self._format_template(str(value))
|
||||
if isinstance(value, str)
|
||||
else value
|
||||
)
|
||||
self.state[key] = formatted_value
|
||||
elif operation == "increment":
|
||||
current = self.state.get(key, 0)
|
||||
try:
|
||||
self.state[key] = int(current) + int(value or 1)
|
||||
except (ValueError, TypeError):
|
||||
self.state[key] = 1
|
||||
elif operation == "append":
|
||||
if key not in self.state:
|
||||
self.state[key] = []
|
||||
if isinstance(self.state[key], list):
|
||||
self.state[key].append(value)
|
||||
else:
|
||||
updates = config.get("updates", {})
|
||||
if not updates:
|
||||
var_name = config.get("variable")
|
||||
var_value = config.get("value")
|
||||
if var_name and isinstance(var_name, str):
|
||||
updates = {var_name: var_value or ""}
|
||||
if isinstance(updates, dict):
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, str):
|
||||
self.state[key] = self._format_template(value)
|
||||
else:
|
||||
self.state[key] = value
|
||||
for case in config.cases:
|
||||
if not case.expression.strip():
|
||||
continue
|
||||
try:
|
||||
if evaluate_cel(case.expression, self.state):
|
||||
matched_handle = case.source_handle
|
||||
break
|
||||
except CelEvaluationError:
|
||||
continue
|
||||
|
||||
self._condition_result = matched_handle or "else"
|
||||
yield from ()
|
||||
|
||||
def _execute_end_node(
|
||||
self, node: WorkflowNode
|
||||
) -> Generator[Dict[str, str], None, None]:
|
||||
config = node.config
|
||||
config = node.config.get("config", node.config)
|
||||
output_template = str(config.get("output_template", ""))
|
||||
if output_template:
|
||||
formatted_output = self._format_template(output_template)
|
||||
yield {"answer": formatted_output}
|
||||
|
||||
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
|
||||
normalized_response = raw_response.strip()
|
||||
if not normalized_response:
|
||||
return False, None
|
||||
|
||||
try:
|
||||
return True, json.loads(normalized_response)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Workflow agent returned structured output that was not valid JSON"
|
||||
)
|
||||
return False, None
|
||||
|
||||
def _normalize_node_json_schema(
|
||||
self, schema: Optional[Dict[str, Any]], node_title: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if schema is None:
|
||||
return None
|
||||
try:
|
||||
return normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(
|
||||
f'Invalid JSON schema for node "{node_title}": {exc}'
|
||||
) from exc
|
||||
|
||||
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
|
||||
if jsonschema is None:
|
||||
logger.warning(
|
||||
"jsonschema package is not available, skipping structured output validation"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
normalized_schema = normalize_json_schema_payload(schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc}") from exc
|
||||
|
||||
try:
|
||||
jsonschema.validate(instance=output_value, schema=normalized_schema)
|
||||
except jsonschema.exceptions.ValidationError as exc:
|
||||
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
|
||||
except jsonschema.exceptions.SchemaError as exc:
|
||||
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
|
||||
|
||||
def _format_template(self, template: str) -> str:
|
||||
formatted = template
|
||||
context = self._build_template_context()
|
||||
try:
|
||||
return self._template_engine.render(template, context)
|
||||
except TemplateRenderError as e:
|
||||
logger.warning(
|
||||
"Workflow template rendering failed, using raw template: %s", str(e)
|
||||
)
|
||||
return template
|
||||
|
||||
def _build_template_context(self) -> Dict[str, Any]:
|
||||
docs, docs_together = self._get_source_template_data()
|
||||
passthrough_data = (
|
||||
self.state.get("passthrough")
|
||||
if isinstance(self.state.get("passthrough"), dict)
|
||||
else None
|
||||
)
|
||||
tools_data = (
|
||||
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
|
||||
)
|
||||
|
||||
context = self._namespace_manager.build_context(
|
||||
user_id=getattr(self.agent, "user", None),
|
||||
request_id=getattr(self.agent, "request_id", None),
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
agent_context: Dict[str, Any] = {}
|
||||
for key, value in self.state.items():
|
||||
placeholder = f"{{{{{key}}}}}"
|
||||
if placeholder in formatted and value is not None:
|
||||
formatted = formatted.replace(placeholder, str(value))
|
||||
return formatted
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
normalized_key = key.strip()
|
||||
if not normalized_key:
|
||||
continue
|
||||
agent_context[normalized_key] = value
|
||||
|
||||
context["agent"] = agent_context
|
||||
|
||||
# Keep legacy top-level variables working while namespaced variables are adopted.
|
||||
for key, value in agent_context.items():
|
||||
if key in TEMPLATE_RESERVED_NAMESPACES:
|
||||
context[f"agent_{key}"] = value
|
||||
continue
|
||||
if key not in context:
|
||||
context[key] = value
|
||||
|
||||
return context
|
||||
|
||||
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
|
||||
docs = getattr(self.agent, "retrieved_docs", None)
|
||||
if not isinstance(docs, list) or len(docs) == 0:
|
||||
return None, None
|
||||
|
||||
docs_together_parts: List[str] = []
|
||||
for doc in docs:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
text = doc.get("text")
|
||||
if not isinstance(text, str):
|
||||
continue
|
||||
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if isinstance(filename, str) and filename.strip():
|
||||
docs_together_parts.append(f"{filename}\n{text}")
|
||||
else:
|
||||
docs_together_parts.append(text)
|
||||
|
||||
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
|
||||
return docs, docs_together
|
||||
|
||||
def get_execution_summary(self) -> List[NodeExecutionLog]:
|
||||
return [
|
||||
|
||||
@@ -42,6 +42,7 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
@@ -100,6 +101,9 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
@@ -46,6 +46,27 @@ class BaseAnswerResource:
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _prepare_tool_calls_for_logging(
|
||||
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
prepared = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
prepared.append({"result": str(tool_call)[:max_chars]})
|
||||
continue
|
||||
|
||||
item = dict(tool_call)
|
||||
for key in ("result", "result_full"):
|
||||
value = item.get(key)
|
||||
if isinstance(value, str) and len(value) > max_chars:
|
||||
item[key] = value[:max_chars]
|
||||
prepared.append(item)
|
||||
return prepared
|
||||
|
||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||
"""Check if there is a usage limit and if it is exceeded
|
||||
|
||||
@@ -246,6 +267,7 @@ class BaseAnswerResource:
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -292,14 +314,20 @@ class BaseAnswerResource:
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
)
|
||||
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls_for_logging,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
@@ -330,6 +358,7 @@ class BaseAnswerResource:
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
|
||||
@@ -42,6 +42,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
@@ -107,7 +108,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
index=data.get("index"),
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
attachment_ids=data.get("attachments", []),
|
||||
agent_id=data.get("agent_id"),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
|
||||
@@ -134,6 +134,7 @@ class CompressionOrchestrator:
|
||||
user_api_key=None,
|
||||
decoded_token=decoded_token,
|
||||
model_id=compression_model,
|
||||
agent_id=conversation.get("agent_id"),
|
||||
)
|
||||
|
||||
# Create compression service with DB update capability
|
||||
|
||||
@@ -90,6 +90,7 @@ class StreamProcessor:
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.agent_id = self.data.get("agent_id")
|
||||
self.model_id: Optional[str] = None
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
@@ -355,10 +356,13 @@ class StreamProcessor:
|
||||
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
||||
agent_id, self.initial_user_id
|
||||
)
|
||||
self.agent_id = str(agent_id) if agent_id else None
|
||||
|
||||
api_key = self.data.get("api_key")
|
||||
if api_key:
|
||||
data_key = self._get_data_from_api_key(api_key)
|
||||
if data_key.get("_id"):
|
||||
self.agent_id = str(data_key.get("_id"))
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
@@ -387,6 +391,8 @@ class StreamProcessor:
|
||||
self.retriever_config["chunks"] = 2
|
||||
elif self.agent_key:
|
||||
data_key = self._get_data_from_api_key(self.agent_key)
|
||||
if data_key.get("_id"):
|
||||
self.agent_id = str(data_key.get("_id"))
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
@@ -459,6 +465,7 @@ class StreamProcessor:
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
model_id=self.model_id,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
agent_id=self.agent_id,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
@@ -754,6 +761,7 @@ class StreamProcessor:
|
||||
"llm_name": provider or settings.LLM_PROVIDER,
|
||||
"model_id": self.model_id,
|
||||
"api_key": system_api_key,
|
||||
"agent_id": self.agent_id,
|
||||
"user_api_key": self.agent_config["user_api_key"],
|
||||
"prompt": rendered_prompt,
|
||||
"chat_history": self.history,
|
||||
|
||||
@@ -23,6 +23,10 @@ from application.api.user.base import (
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.utils import (
|
||||
check_required_fields,
|
||||
@@ -479,41 +483,15 @@ class CreateAgent(Resource):
|
||||
data["models"] = []
|
||||
print(f"Received data: {data}")
|
||||
|
||||
# Validate JSON schema if provided
|
||||
|
||||
if data.get("json_schema"):
|
||||
# Validate and normalize JSON schema if provided
|
||||
if "json_schema" in data:
|
||||
try:
|
||||
# Basic validation - ensure it's a valid JSON structure
|
||||
|
||||
json_schema = data.get("json_schema")
|
||||
if not isinstance(json_schema, dict):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid JSON object",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
# Validate that it has either a 'schema' property or is itself a schema
|
||||
|
||||
if "schema" not in json_schema and "type" not in json_schema:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Invalid JSON schema: {e}")
|
||||
data["json_schema"] = normalize_json_schema_payload(
|
||||
data.get("json_schema")
|
||||
)
|
||||
except JsonSchemaValidationError as exc:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": False, "message": "Invalid JSON schema format"}
|
||||
),
|
||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
||||
400,
|
||||
)
|
||||
if data.get("status") not in ["draft", "published"]:
|
||||
@@ -732,6 +710,8 @@ class UpdateAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
if data.get("json_schema") == "":
|
||||
data["json_schema"] = None
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error parsing request data: {err}", exc_info=True
|
||||
@@ -892,17 +872,15 @@ class UpdateAgent(Resource):
|
||||
elif field == "json_schema":
|
||||
json_schema = data.get("json_schema")
|
||||
if json_schema is not None:
|
||||
if not isinstance(json_schema, dict):
|
||||
try:
|
||||
update_fields[field] = normalize_json_schema_payload(
|
||||
json_schema
|
||||
)
|
||||
except JsonSchemaValidationError as exc:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "JSON schema must be a valid object",
|
||||
}
|
||||
),
|
||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
||||
400,
|
||||
)
|
||||
update_fields[field] = json_schema
|
||||
else:
|
||||
update_fields[field] = None
|
||||
elif field == "limited_token_mode":
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
@@ -11,6 +11,11 @@ from application.api.user.base import (
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
from application.api.user.utils import (
|
||||
check_resource_ownership,
|
||||
error_response,
|
||||
@@ -85,6 +90,50 @@ def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> L
|
||||
return docs
|
||||
|
||||
|
||||
def validate_json_schema_payload(
|
||||
json_schema: Any,
|
||||
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""Validate and normalize optional JSON schema payload for structured output."""
|
||||
if json_schema is None:
|
||||
return None, None
|
||||
try:
|
||||
return normalize_json_schema_payload(json_schema), None
|
||||
except JsonSchemaValidationError as exc:
|
||||
return None, str(exc)
|
||||
|
||||
|
||||
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
|
||||
"""Normalize agent-node JSON schema payloads before persistence."""
|
||||
normalized_nodes: List[Dict] = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, dict):
|
||||
normalized_nodes.append(node)
|
||||
continue
|
||||
|
||||
normalized_node = dict(node)
|
||||
if normalized_node.get("type") != "agent":
|
||||
normalized_nodes.append(normalized_node)
|
||||
continue
|
||||
|
||||
raw_config = normalized_node.get("data")
|
||||
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
|
||||
normalized_nodes.append(normalized_node)
|
||||
continue
|
||||
|
||||
normalized_config = dict(raw_config)
|
||||
try:
|
||||
normalized_config["json_schema"] = normalize_json_schema_payload(
|
||||
raw_config.get("json_schema")
|
||||
)
|
||||
except JsonSchemaValidationError:
|
||||
# Validation runs before normalization; keep original on unexpected shape.
|
||||
normalized_config["json_schema"] = raw_config.get("json_schema")
|
||||
normalized_node["data"] = normalized_config
|
||||
normalized_nodes.append(normalized_node)
|
||||
|
||||
return normalized_nodes
|
||||
|
||||
|
||||
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
|
||||
"""Validate workflow graph structure."""
|
||||
errors = []
|
||||
@@ -102,6 +151,9 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
errors.append("Workflow must have at least one end node")
|
||||
|
||||
node_ids = {n.get("id") for n in nodes}
|
||||
node_map = {n.get("id"): n for n in nodes}
|
||||
end_ids = {n.get("id") for n in end_nodes}
|
||||
|
||||
for edge in edges:
|
||||
source_id = edge.get("source")
|
||||
target_id = edge.get("target")
|
||||
@@ -115,6 +167,126 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
if not any(e.get("source") == start_id for e in edges):
|
||||
errors.append("Start node must have at least one outgoing edge")
|
||||
|
||||
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
|
||||
for cnode in condition_nodes:
|
||||
cnode_id = cnode.get("id")
|
||||
cnode_title = cnode.get("title", cnode_id)
|
||||
outgoing = [e for e in edges if e.get("source") == cnode_id]
|
||||
if len(outgoing) < 2:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
|
||||
)
|
||||
node_data = cnode.get("data", {}) or {}
|
||||
cases = node_data.get("cases", [])
|
||||
if not isinstance(cases, list):
|
||||
cases = []
|
||||
if not cases or not any(
|
||||
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
|
||||
):
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' must have at least one case with an expression"
|
||||
)
|
||||
|
||||
case_handles: Set[str] = set()
|
||||
duplicate_case_handles: Set[str] = set()
|
||||
for case in cases:
|
||||
if not isinstance(case, dict):
|
||||
continue
|
||||
raw_handle = case.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
if not handle:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has a case without a branch handle"
|
||||
)
|
||||
continue
|
||||
if handle in case_handles:
|
||||
duplicate_case_handles.add(handle)
|
||||
case_handles.add(handle)
|
||||
|
||||
for handle in duplicate_case_handles:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
|
||||
)
|
||||
|
||||
outgoing_by_handle: Dict[str, List[Dict]] = {}
|
||||
for out_edge in outgoing:
|
||||
raw_handle = out_edge.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
outgoing_by_handle.setdefault(handle, []).append(out_edge)
|
||||
|
||||
for handle, handle_edges in outgoing_by_handle.items():
|
||||
if not handle:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
|
||||
)
|
||||
continue
|
||||
if handle != "else" and handle not in case_handles:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
|
||||
)
|
||||
if len(handle_edges) > 1:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
|
||||
)
|
||||
|
||||
if "else" not in outgoing_by_handle:
|
||||
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
|
||||
|
||||
for case in cases:
|
||||
if not isinstance(case, dict):
|
||||
continue
|
||||
raw_handle = case.get("sourceHandle", "")
|
||||
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||
if not handle:
|
||||
continue
|
||||
|
||||
raw_expression = case.get("expression", "")
|
||||
has_expression = isinstance(raw_expression, str) and bool(
|
||||
raw_expression.strip()
|
||||
)
|
||||
has_outgoing = bool(outgoing_by_handle.get(handle))
|
||||
if has_expression and not has_outgoing:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
|
||||
)
|
||||
if not has_expression and has_outgoing:
|
||||
errors.append(
|
||||
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
|
||||
)
|
||||
|
||||
for handle, handle_edges in outgoing_by_handle.items():
|
||||
if not handle:
|
||||
continue
|
||||
for out_edge in handle_edges:
|
||||
target = out_edge.get("target")
|
||||
if target and not _can_reach_end(target, edges, node_map, end_ids):
|
||||
errors.append(
|
||||
f"Branch '{handle}' of condition '{cnode_title}' "
|
||||
f"must eventually reach an end node"
|
||||
)
|
||||
|
||||
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
|
||||
for agent_node in agent_nodes:
|
||||
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
|
||||
raw_config = agent_node.get("data", {}) or {}
|
||||
if not isinstance(raw_config, dict):
|
||||
errors.append(f"Agent node '{agent_title}' has invalid configuration")
|
||||
continue
|
||||
normalized_schema, schema_error = validate_json_schema_payload(
|
||||
raw_config.get("json_schema")
|
||||
)
|
||||
has_json_schema = normalized_schema is not None
|
||||
|
||||
model_id = raw_config.get("model_id")
|
||||
if has_json_schema and isinstance(model_id, str) and model_id.strip():
|
||||
capabilities = get_model_capabilities(model_id.strip())
|
||||
if capabilities and not capabilities.get("supports_structured_output", False):
|
||||
errors.append(
|
||||
f"Agent node '{agent_title}' selected model does not support structured output"
|
||||
)
|
||||
if schema_error:
|
||||
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
|
||||
|
||||
for node in nodes:
|
||||
if not node.get("id"):
|
||||
errors.append("All nodes must have an id")
|
||||
@@ -124,6 +296,20 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
|
||||
return errors
|
||||
|
||||
|
||||
def _can_reach_end(
|
||||
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
|
||||
) -> bool:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if node_id in end_ids:
|
||||
return True
|
||||
if node_id in visited or node_id not in node_map:
|
||||
return False
|
||||
visited.add(node_id)
|
||||
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
|
||||
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
|
||||
|
||||
|
||||
def create_workflow_nodes(
|
||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
@@ -186,6 +372,7 @@ class WorkflowList(Resource):
|
||||
return error_response(
|
||||
"Workflow validation failed", errors=validation_errors
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
workflow_doc = {
|
||||
@@ -276,6 +463,7 @@ class WorkflowDetail(Resource):
|
||||
return error_response(
|
||||
"Workflow validation failed", errors=validation_errors
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
|
||||
34
application/core/json_schema_utils.py
Normal file
34
application/core/json_schema_utils.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class JsonSchemaValidationError(ValueError):
|
||||
"""Raised when a JSON schema payload is invalid."""
|
||||
|
||||
|
||||
def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Normalize accepted JSON schema payload shapes to a plain schema object.
|
||||
|
||||
Accepted inputs:
|
||||
- None
|
||||
- A raw schema object with a top-level "type"
|
||||
- A wrapped payload with a top-level "schema" object
|
||||
"""
|
||||
if json_schema is None:
|
||||
return None
|
||||
|
||||
if not isinstance(json_schema, dict):
|
||||
raise JsonSchemaValidationError("must be a valid JSON object")
|
||||
|
||||
wrapped_schema = json_schema.get("schema")
|
||||
if wrapped_schema is not None:
|
||||
if not isinstance(wrapped_schema, dict):
|
||||
raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
|
||||
return wrapped_schema
|
||||
|
||||
if "type" not in json_schema:
|
||||
raise JsonSchemaValidationError(
|
||||
'must include either a "type" or "schema" field'
|
||||
)
|
||||
|
||||
return json_schema
|
||||
@@ -13,10 +13,12 @@ class BaseLLM(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoded_token=None,
|
||||
agent_id=None,
|
||||
model_id=None,
|
||||
base_url=None,
|
||||
):
|
||||
self.decoded_token = decoded_token
|
||||
self.agent_id = str(agent_id) if agent_id else None
|
||||
self.model_id = model_id
|
||||
self.base_url = base_url
|
||||
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
@@ -33,9 +35,10 @@ class BaseLLM(ABC):
|
||||
self._fallback_llm = LLMCreator.create_llm(
|
||||
settings.FALLBACK_LLM_PROVIDER,
|
||||
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
|
||||
user_api_key=None,
|
||||
user_api_key=getattr(self, "user_api_key", None),
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=settings.FALLBACK_LLM_NAME,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
logger.info(
|
||||
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
|
||||
|
||||
@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
|
||||
def __init__(
|
||||
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(decoded_token=decoded_token, *args, **kwargs)
|
||||
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
|
||||
self.user_api_key = user_api_key
|
||||
|
||||
|
||||
@@ -567,6 +567,7 @@ class LLMHandler(ABC):
|
||||
getattr(agent, "user_api_key", None),
|
||||
getattr(agent, "decoded_token", None),
|
||||
model_id=compression_model,
|
||||
agent_id=getattr(agent, "agent_id", None),
|
||||
)
|
||||
|
||||
# Create service without DB persistence capability
|
||||
|
||||
@@ -31,7 +31,15 @@ class LLMCreator:
|
||||
|
||||
@classmethod
|
||||
def create_llm(
|
||||
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
|
||||
cls,
|
||||
type,
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token,
|
||||
model_id=None,
|
||||
agent_id=None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
@@ -49,6 +57,7 @@ class LLMCreator:
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
base_url=base_url,
|
||||
*args,
|
||||
**kwargs,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
anthropic==0.75.0
|
||||
boto3==1.42.17
|
||||
beautifulsoup4==4.14.3
|
||||
cel-python==0.5.0
|
||||
celery==5.6.0
|
||||
cryptography==46.0.3
|
||||
dataclasses-json==0.6.7
|
||||
|
||||
@@ -18,6 +18,7 @@ class ClassicRAG(BaseRetriever):
|
||||
doc_token_limit=50000,
|
||||
model_id="docsgpt-local",
|
||||
user_api_key=None,
|
||||
agent_id=None,
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
decoded_token=None,
|
||||
@@ -43,6 +44,7 @@ class ClassicRAG(BaseRetriever):
|
||||
self.model_id = model_id
|
||||
self.doc_token_limit = doc_token_limit
|
||||
self.user_api_key = user_api_key
|
||||
self.agent_id = agent_id
|
||||
self.llm_name = llm_name
|
||||
self.api_key = api_key
|
||||
self.llm = LLMCreator.create_llm(
|
||||
@@ -50,6 +52,7 @@ class ClassicRAG(BaseRetriever):
|
||||
api_key=self.api_key,
|
||||
user_api_key=self.user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=self.agent_id,
|
||||
)
|
||||
|
||||
if "active_docs" in source and source["active_docs"] is not None:
|
||||
|
||||
@@ -1,22 +1,104 @@
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
usage_collection = db["token_usage"]
|
||||
|
||||
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
def _serialize_for_token_count(value):
|
||||
"""Normalize payloads into token-countable primitives."""
|
||||
if isinstance(value, str):
|
||||
# Avoid counting large binary payloads in data URLs as text tokens.
|
||||
if value.startswith("data:") and ";base64," in value:
|
||||
return ""
|
||||
return value
|
||||
|
||||
if value is None:
|
||||
return ""
|
||||
|
||||
if isinstance(value, list):
|
||||
return [_serialize_for_token_count(item) for item in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
serialized = {}
|
||||
for key, raw in value.items():
|
||||
key_lower = str(key).lower()
|
||||
|
||||
# Skip raw binary-like fields; keep textual tool-call fields.
|
||||
if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
|
||||
continue
|
||||
if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
|
||||
continue
|
||||
|
||||
serialized[key] = _serialize_for_token_count(raw)
|
||||
return serialized
|
||||
|
||||
if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
|
||||
return _serialize_for_token_count(value.model_dump())
|
||||
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
|
||||
return _serialize_for_token_count(value.to_dict())
|
||||
if hasattr(value, "__dict__"):
|
||||
return _serialize_for_token_count(vars(value))
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
def _count_tokens(value):
|
||||
serialized = _serialize_for_token_count(value)
|
||||
if isinstance(serialized, str):
|
||||
return num_tokens_from_string(serialized)
|
||||
return num_tokens_from_object_or_list(serialized)
|
||||
|
||||
|
||||
def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
|
||||
prompt_tokens = 0
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
prompt_tokens += _count_tokens(message)
|
||||
continue
|
||||
|
||||
prompt_tokens += _count_tokens(message.get("content"))
|
||||
|
||||
# Include tool-related message fields for providers that use OpenAI-native format.
|
||||
prompt_tokens += _count_tokens(message.get("tool_calls"))
|
||||
prompt_tokens += _count_tokens(message.get("tool_call_id"))
|
||||
prompt_tokens += _count_tokens(message.get("function_call"))
|
||||
prompt_tokens += _count_tokens(message.get("function_response"))
|
||||
|
||||
# Count tool schema payload passed to the model.
|
||||
prompt_tokens += _count_tokens(tools)
|
||||
|
||||
# Count structured-output/schema payloads when provided.
|
||||
prompt_tokens += _count_tokens(kwargs.get("response_format"))
|
||||
prompt_tokens += _count_tokens(kwargs.get("response_schema"))
|
||||
|
||||
# Optional usage-only attachment context (not forwarded to provider).
|
||||
prompt_tokens += _count_tokens(usage_attachments)
|
||||
|
||||
return prompt_tokens
|
||||
|
||||
|
||||
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
if "pytest" in sys.modules:
|
||||
return
|
||||
if decoded_token:
|
||||
user_id = decoded_token["sub"]
|
||||
else:
|
||||
user_id = None
|
||||
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
|
||||
normalized_agent_id = str(agent_id) if agent_id else None
|
||||
|
||||
if not user_id and not user_api_key and not normalized_agent_id:
|
||||
logger.warning(
|
||||
"Skipping token usage insert: missing user_id, api_key, and agent_id"
|
||||
)
|
||||
return
|
||||
|
||||
usage_data = {
|
||||
"user_id": user_id,
|
||||
"api_key": user_api_key,
|
||||
@@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage):
|
||||
"generated_tokens": token_usage["generated_tokens"],
|
||||
"timestamp": datetime.now(),
|
||||
}
|
||||
if normalized_agent_id:
|
||||
usage_data["agent_id"] = normalized_agent_id
|
||||
usage_collection.insert_one(usage_data)
|
||||
|
||||
|
||||
def gen_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
if message["content"]:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
||||
message["content"]
|
||||
)
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||
messages,
|
||||
tools=tools,
|
||||
usage_attachments=usage_attachments,
|
||||
**kwargs,
|
||||
)
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
if isinstance(result, str):
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
|
||||
else:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
|
||||
result
|
||||
)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
call_usage["generated_tokens"] += _count_tokens(result)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
@@ -49,17 +138,28 @@ def gen_token_usage(func):
|
||||
|
||||
def stream_token_usage(func):
|
||||
def wrapper(self, model, messages, stream, tools, **kwargs):
|
||||
for message in messages:
|
||||
self.token_usage["prompt_tokens"] += num_tokens_from_string(
|
||||
message["content"]
|
||||
)
|
||||
usage_attachments = kwargs.pop("_usage_attachments", None)
|
||||
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
call_usage["prompt_tokens"] += _count_prompt_tokens(
|
||||
messages,
|
||||
tools=tools,
|
||||
usage_attachments=usage_attachments,
|
||||
**kwargs,
|
||||
)
|
||||
batch = []
|
||||
result = func(self, model, messages, stream, tools, **kwargs)
|
||||
for r in result:
|
||||
batch.append(r)
|
||||
yield r
|
||||
for line in batch:
|
||||
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
|
||||
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
|
||||
call_usage["generated_tokens"] += _count_tokens(line)
|
||||
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
|
||||
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
|
||||
update_token_usage(
|
||||
self.decoded_token,
|
||||
self.user_api_key,
|
||||
call_usage,
|
||||
getattr(self, "agent_id", None),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -322,6 +322,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
chunks = int(agent_config.get("chunks", 2))
|
||||
prompt_id = agent_config.get("prompt_id", "default")
|
||||
user_api_key = agent_config["key"]
|
||||
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
|
||||
agent_type = agent_config.get("agent_type", "classic")
|
||||
decoded_token = {"sub": agent_config.get("user")}
|
||||
json_schema = agent_config.get("json_schema")
|
||||
@@ -352,6 +353,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
doc_token_limit=doc_token_limit,
|
||||
model_id=model_id,
|
||||
user_api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
@@ -370,6 +372,7 @@ def run_agent_logic(agent_config, input_data):
|
||||
llm_name=provider or settings.LLM_PROVIDER,
|
||||
model_id=model_id,
|
||||
api_key=system_api_key,
|
||||
agent_id=agent_id,
|
||||
user_api_key=user_api_key,
|
||||
prompt=prompt,
|
||||
chat_history=[],
|
||||
|
||||
25
docs/app/[[...mdxPath]]/page.jsx
Normal file
25
docs/app/[[...mdxPath]]/page.jsx
Normal file
@@ -0,0 +1,25 @@
|
||||
import { generateStaticParamsFor, importPage } from 'nextra/pages';
|
||||
|
||||
import { useMDXComponents } from '../../mdx-components';
|
||||
|
||||
export const generateStaticParams = generateStaticParamsFor('mdxPath');
|
||||
|
||||
export async function generateMetadata(props) {
|
||||
const params = await props.params;
|
||||
const { metadata } = await importPage(params?.mdxPath);
|
||||
return metadata;
|
||||
}
|
||||
|
||||
const Wrapper = useMDXComponents().wrapper;
|
||||
|
||||
export default async function Page(props) {
|
||||
const params = await props.params;
|
||||
const result = await importPage(params?.mdxPath);
|
||||
const { default: MDXContent, metadata, sourceCode, toc } = result;
|
||||
|
||||
return (
|
||||
<Wrapper metadata={metadata} sourceCode={sourceCode} toc={toc}>
|
||||
<MDXContent {...props} params={params} />
|
||||
</Wrapper>
|
||||
);
|
||||
}
|
||||
86
docs/app/layout.jsx
Normal file
86
docs/app/layout.jsx
Normal file
@@ -0,0 +1,86 @@
|
||||
import Image from 'next/image';
|
||||
import { Analytics } from '@vercel/analytics/react';
|
||||
import { Banner, Head } from 'nextra/components';
|
||||
import { getPageMap } from 'nextra/page-map';
|
||||
import { Footer, Layout, Navbar } from 'nextra-theme-docs';
|
||||
import 'nextra-theme-docs/style.css';
|
||||
|
||||
import CuteLogo from '../public/cute-docsgpt.png';
|
||||
import themeConfig from '../theme.config';
|
||||
|
||||
const github = 'https://github.com/arc53/DocsGPT';
|
||||
|
||||
export const metadata = {
|
||||
title: {
|
||||
default: 'DocsGPT Documentation',
|
||||
template: '%s - DocsGPT Documentation',
|
||||
},
|
||||
description:
|
||||
'Use DocsGPT to chat with your data. DocsGPT is a GPT-powered chatbot that can answer questions about your data.',
|
||||
};
|
||||
|
||||
const navbar = (
|
||||
<Navbar
|
||||
logo={
|
||||
<div style={{ alignItems: 'center', display: 'flex', gap: '8px' }}>
|
||||
<Image src={CuteLogo} alt="DocsGPT logo" width={28} height={28} />
|
||||
<span style={{ fontWeight: 'bold', fontSize: 18 }}>DocsGPT Docs</span>
|
||||
</div>
|
||||
}
|
||||
projectLink={github}
|
||||
chatLink="https://discord.com/invite/n5BX8dh8rU"
|
||||
/>
|
||||
);
|
||||
|
||||
const footer = (
|
||||
<Footer>
|
||||
<span>MIT {new Date().getFullYear()} © </span>
|
||||
<a href="https://www.docsgpt.cloud/" target="_blank" rel="noreferrer">
|
||||
DocsGPT
|
||||
</a>
|
||||
{' | '}
|
||||
<a href="https://github.com/arc53/DocsGPT" target="_blank" rel="noreferrer">
|
||||
GitHub
|
||||
</a>
|
||||
{' | '}
|
||||
<a href="https://blog.docsgpt.cloud/" target="_blank" rel="noreferrer">
|
||||
Blog
|
||||
</a>
|
||||
</Footer>
|
||||
);
|
||||
|
||||
export default async function RootLayout({ children }) {
|
||||
return (
|
||||
<html lang="en" dir="ltr" suppressHydrationWarning>
|
||||
<Head>
|
||||
<link
|
||||
rel="apple-touch-icon"
|
||||
sizes="180x180"
|
||||
href="/favicons/apple-touch-icon.png"
|
||||
/>
|
||||
<link rel="icon" type="image/png" sizes="32x32" href="/favicons/favicon-32x32.png" />
|
||||
<link rel="icon" type="image/png" sizes="16x16" href="/favicons/favicon-16x16.png" />
|
||||
<link rel="manifest" href="/favicons/site.webmanifest" />
|
||||
<meta httpEquiv="Content-Language" content="en" />
|
||||
</Head>
|
||||
<body>
|
||||
<Layout
|
||||
banner={
|
||||
<Banner storageKey="docs-launch">
|
||||
<div className="flex justify-center items-center gap-2">
|
||||
Welcome to the new DocsGPT docs!
|
||||
</div>
|
||||
</Banner>
|
||||
}
|
||||
navbar={navbar}
|
||||
footer={footer}
|
||||
pageMap={await getPageMap()}
|
||||
{...themeConfig}
|
||||
>
|
||||
{children}
|
||||
</Layout>
|
||||
<Analytics />
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
'use client';
|
||||
|
||||
import Image from 'next/image';
|
||||
|
||||
const iconMap = {
|
||||
@@ -117,4 +119,4 @@ export function DeploymentCards({ items }) {
|
||||
`}</style>
|
||||
</>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
'use client';
|
||||
|
||||
import Image from 'next/image';
|
||||
|
||||
const iconMap = {
|
||||
@@ -114,4 +116,4 @@ export function ToolCards({ items }) {
|
||||
`}</style>
|
||||
</>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
export default {
|
||||
"basics": {
|
||||
"title": "🤖 Agent Basics",
|
||||
"href": "/Agents/basics"
|
||||
@@ -10,5 +10,9 @@
|
||||
"webhooks": {
|
||||
"title": "🪝 Agent Webhooks",
|
||||
"href": "/Agents/webhooks"
|
||||
},
|
||||
"nodes": {
|
||||
"title": "🧩 Workflow Nodes",
|
||||
"href": "/Agents/nodes"
|
||||
}
|
||||
}
|
||||
342
docs/content/Agents/api.mdx
Normal file
342
docs/content/Agents/api.mdx
Normal file
@@ -0,0 +1,342 @@
|
||||
---
|
||||
title: Interacting with Agents via API
|
||||
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
|
||||
---
|
||||
|
||||
import { Callout, Tabs } from 'nextra/components';
|
||||
|
||||
# Interacting with Agents via API
|
||||
|
||||
DocsGPT Agents can be accessed programmatically through API endpoints. This page covers:
|
||||
|
||||
- Non-streaming answers (`/api/answer`)
|
||||
- Streaming answers over SSE (`/stream`)
|
||||
- File/image attachments (`/api/store_attachment` + `/api/task_status` + `/stream`)
|
||||
|
||||
When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.
|
||||
|
||||
## Base URL
|
||||
|
||||
<Callout type="info">
|
||||
For DocsGPT Cloud, use `https://gptcloud.arc53.com` as the base URL.
|
||||
</Callout>
|
||||
|
||||
- Local: `http://localhost:7091`
|
||||
- Cloud: `https://gptcloud.arc53.com`
|
||||
|
||||
## How Request Resolution Works
|
||||
|
||||
DocsGPT resolves your request in this order:
|
||||
|
||||
1. If `api_key` is provided, DocsGPT loads the mapped agent and executes with that config.
|
||||
2. If `agent_id` is provided (typically with JWT auth), DocsGPT loads that agent if allowed.
|
||||
3. If neither is provided, DocsGPT uses request-level fields (`prompt_id`, `active_docs`, `retriever`, etc.).
|
||||
|
||||
Authentication:
|
||||
|
||||
- Agent API-key flow: include `api_key` in JSON/form payload.
|
||||
- JWT flow (if auth enabled): include `Authorization: Bearer <token>`.
|
||||
|
||||
## Endpoints
|
||||
|
||||
- `POST /api/answer` (non-streaming)
|
||||
- `POST /stream` (SSE streaming)
|
||||
- `POST /api/store_attachment` (multipart upload)
|
||||
- `GET /api/task_status?task_id=...` (Celery task polling)
|
||||
|
||||
## Request Parameters
|
||||
|
||||
Common request body fields:
|
||||
|
||||
| Field | Type | Required | Applies to | Notes |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| `question` | `string` | Yes | `/api/answer`, `/stream` | User query. |
|
||||
| `api_key` | `string` | Usually | `/api/answer`, `/stream` | Recommended for agent API use. Loads agent config from key. |
|
||||
| `conversation_id` | `string` | No | `/api/answer`, `/stream` | Continue an existing conversation. |
|
||||
| `history` | `string` (JSON-encoded array) | No | `/api/answer`, `/stream` | Used for new conversations. Format: `[{\"prompt\":\"...\",\"response\":\"...\"}]`. |
|
||||
| `model_id` | `string` | No | `/api/answer`, `/stream` | Override model for this request. |
|
||||
| `save_conversation` | `boolean` | No | `/api/answer`, `/stream` | Default `true`. If `false`, no conversation is persisted. |
|
||||
| `passthrough` | `object` | No | `/api/answer`, `/stream` | Dynamic values injected into prompt templates. |
|
||||
| `prompt_id` | `string` | No | `/api/answer`, `/stream` | Ignored when `api_key` already defines prompt. |
|
||||
| `active_docs` | `string` or `string[]` | No | `/api/answer`, `/stream` | Overrides active docs when not using key-owned source config. |
|
||||
| `retriever` | `string` | No | `/api/answer`, `/stream` | Retriever type (for example `classic`). |
|
||||
| `chunks` | `number` | No | `/api/answer`, `/stream` | Retrieval chunk count, default `2`. |
|
||||
| `isNoneDoc` | `boolean` | No | `/api/answer`, `/stream` | Skip document retrieval. |
|
||||
| `agent_id` | `string` | No | `/api/answer`, `/stream` | Alternative to `api_key` when using authenticated user context. |
|
||||
|
||||
Streaming-only fields:
|
||||
|
||||
| Field | Type | Required | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| `attachments` | `string[]` | No | List of attachment IDs from `/api/task_status` success result. |
|
||||
| `index` | `number` | No | Update an existing query index. If provided, `conversation_id` is required. |
|
||||
|
||||
## Non-Streaming API (`/api/answer`)
|
||||
|
||||
`/api/answer` waits for completion and returns one JSON response.
|
||||
|
||||
<Callout type="info">
|
||||
`attachments` are currently handled through `/stream`. For file/image-attached queries, use the streaming endpoint.
|
||||
</Callout>
|
||||
|
||||
Response fields:
|
||||
|
||||
- `conversation_id`
|
||||
- `answer`
|
||||
- `sources`
|
||||
- `tool_calls`
|
||||
- `thought`
|
||||
- Optional structured output metadata (`structured`, `schema`) when enabled
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/api/answer \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"question":"your question here","api_key":"your_agent_api_key"}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_URL = "http://localhost:7091/api/answer"
|
||||
API_KEY = "your_agent_api_key"
|
||||
QUESTION = "your question here"
|
||||
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
json={"question": QUESTION, "api_key": API_KEY}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(response.json())
|
||||
else:
|
||||
print(f"Error: {response.status_code}")
|
||||
print(response.text)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const apiUrl = 'http://localhost:7091/api/answer';
|
||||
const apiKey = 'your_agent_api_key';
|
||||
const question = 'your question here';
|
||||
|
||||
async function getAnswer() {
|
||||
try {
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ question, api_key: apiKey }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch answer:", error);
|
||||
}
|
||||
}
|
||||
|
||||
getAnswer();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
|
||||
---
|
||||
|
||||
## Streaming API (`/stream`)
|
||||
|
||||
`/stream` returns a Server-Sent Events (SSE) stream so you can render output token-by-token.
|
||||
|
||||
### SSE Event Types
|
||||
|
||||
Each `data:` frame is JSON with `type`:
|
||||
|
||||
- `answer`: incremental answer chunk
|
||||
- `source`: source list/chunks
|
||||
- `tool_calls`: tool invocation results/metadata
|
||||
- `thought`: reasoning/thought chunk (agent dependent)
|
||||
- `structured_answer`: final structured payload (when schema mode is active)
|
||||
- `id`: final conversation ID
|
||||
- `error`: error message
|
||||
- `end`: stream is complete
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/stream \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Accept: text/event-stream" \
|
||||
-d '{"question":"your question here","api_key":"your_agent_api_key"}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
import json
|
||||
|
||||
API_URL = "http://localhost:7091/stream"
|
||||
payload = {
|
||||
"question": "your question here",
|
||||
"api_key": "your_agent_api_key"
|
||||
}
|
||||
|
||||
with requests.post(API_URL, json=payload, stream=True) as r:
|
||||
for line in r.iter_lines():
|
||||
if line:
|
||||
decoded_line = line.decode('utf-8')
|
||||
if decoded_line.startswith('data: '):
|
||||
try:
|
||||
data = json.loads(decoded_line[6:])
|
||||
print(data)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const apiUrl = 'http://localhost:7091/stream';
|
||||
const apiKey = 'your_agent_api_key';
|
||||
const question = 'your question here';
|
||||
|
||||
async function getStream() {
|
||||
try {
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'text/event-stream'
|
||||
},
|
||||
body: JSON.stringify({ question, api_key: apiKey }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
// Note: This parsing method assumes each chunk contains whole lines.
|
||||
// For a more robust production implementation, buffer the chunks
|
||||
// and process them line by line.
|
||||
const lines = chunk.split('\n');
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
try {
|
||||
const data = JSON.parse(line.substring(6));
|
||||
console.log(data);
|
||||
} catch (e) {
|
||||
console.error("Failed to parse JSON from SSE event:", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch stream:", error);
|
||||
}
|
||||
}
|
||||
|
||||
getStream();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
|
||||
---
|
||||
|
||||
## Attachments API (Including Images)
|
||||
|
||||
To attach an image (or other file) to a query:
|
||||
|
||||
1. Upload file(s) to `/api/store_attachment` (multipart/form-data).
|
||||
2. Poll `/api/task_status` until `status=SUCCESS`.
|
||||
3. Read `result.attachment_id` from task result.
|
||||
4. Send that ID in `/stream` as `attachments: ["..."]`.
|
||||
|
||||
<Callout type="warning">
|
||||
Attachments are processed asynchronously. Do not call `/stream` with an attachment until its task has finished with `SUCCESS`.
|
||||
</Callout>
|
||||
|
||||
### Step 1: Upload Attachment
|
||||
|
||||
`POST /api/store_attachment`
|
||||
|
||||
- Content type: `multipart/form-data`
|
||||
- Form fields:
|
||||
- `file` (required, can be repeated for multi-file upload)
|
||||
- `api_key` (optional if JWT is present; useful for API-key-only flows)
|
||||
|
||||
Example upload (single image):
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/api/store_attachment \
|
||||
-F "file=@/absolute/path/to/image.png" \
|
||||
-F "api_key=your_agent_api_key"
|
||||
```
|
||||
|
||||
Possible response (single-file upload):
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"task_id": "34f1cb56-7c7f-4d5f-a973-4ea7e65f7a10",
|
||||
"message": "File uploaded successfully. Processing started."
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Poll Task Status
|
||||
|
||||
```bash
|
||||
curl "http://localhost:7091/api/task_status?task_id=34f1cb56-7c7f-4d5f-a973-4ea7e65f7a10"
|
||||
```
|
||||
|
||||
When complete:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "SUCCESS",
|
||||
"result": {
|
||||
"attachment_id": "67b4f8f2618dc9f19384a9e1",
|
||||
"filename": "image.png",
|
||||
"mime_type": "image/png"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 3: Attach to `/stream` Request
|
||||
|
||||
Use the `attachment_id` in `attachments`.
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/stream \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Accept: text/event-stream" \
|
||||
-d '{
|
||||
"question": "Describe this image",
|
||||
"api_key": "your_agent_api_key",
|
||||
"attachments": ["67b4f8f2618dc9f19384a9e1"]
|
||||
}'
|
||||
```
|
||||
|
||||
### Image/Attachment Behavior Notes
|
||||
|
||||
- Typical image MIME types supported for native vision flows: `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif`.
|
||||
- If the selected model/provider does not support a file type natively, DocsGPT falls back to parsed text content.
|
||||
- For providers that support images but not native PDF file attachments, DocsGPT can convert PDF pages to images (synthetic PDF support).
|
||||
- Attachments are user-scoped. Upload and query must be done under the same user context (same API key owner or same JWT user).
|
||||
65
docs/content/Agents/nodes.mdx
Normal file
65
docs/content/Agents/nodes.mdx
Normal file
@@ -0,0 +1,65 @@
|
||||
# Workflow Nodes
|
||||
|
||||
DocsGPT workflows are composed of **Nodes** that are connected to form a processing graph. These nodes interact with a **Shared State**—a global dictionary of variables that persists throughout the execution of the workflow.
|
||||
|
||||
## The Shared State
|
||||
|
||||
Every workflow run maintains a state object (a JSON-like dictionary).
|
||||
- **Initial State**: Contains the user's input query (`{{query}}`) and chat history (`{{chat_history}}`).
|
||||
- **Accessing Variables**: You can access any variable in the state using the double-curly braces syntax: `{{variable_name}}`.
|
||||
- **Modifying State**: Nodes read from this state and write their outputs back to it.
|
||||
|
||||
---
|
||||
|
||||
## AI Agent Node
|
||||
|
||||
The **AI Agent Node** is the core processing unit. It uses a Large Language Model (LLM) to generate text, answer questions, or perform tasks using tools.
|
||||
|
||||
### Inputs (Template Variables)
|
||||
|
||||
The primary input is the **Prompt Template**. This field supports variable substitution.
|
||||
|
||||
- **Prompt Template**: The text sent to the model.
|
||||
- *Example*: `"Summarize the following text: {{user_input_text}}"`
|
||||
- If left empty, it defaults to the initial user query (`{{query}}`).
|
||||
- **System Prompt**: Instructions that define the agent's persona and constraints.
|
||||
- **Tools**: A list of tools the agent can use (e.g., search, calculator).
|
||||
- **LLM Settings**: Specific provider, model name, and parameters.
|
||||
|
||||
### Outputs (Emissions)
|
||||
|
||||
When the agent completes its task, it stores the result in the shared state.
|
||||
|
||||
- **Output Variable**: The name of the variable where the result will be saved.
|
||||
- *Default*: If not specified, it is saved as `node_{node_id}_output`.
|
||||
- *Custom*: You can set this to something meaningful, like `summary` or `translated_text`.
|
||||
- **Streaming**: If "Stream to user" is enabled, the output is sent to the user in real-time as it is generated, in addition to being saved to the state.
|
||||
|
||||
---
|
||||
|
||||
## Set State Node
|
||||
|
||||
The **Set State Node** allows you to manipulate variables within the shared state directly without calling an LLM. This is useful for initialization, formatting, or control flow logic.
|
||||
|
||||
### Operations
|
||||
|
||||
You can define multiple operations in a single node. Each operation targets a specific **Key** (variable name).
|
||||
|
||||
1. **Set**: Assigns a specific value to a variable.
|
||||
- *Value*: Can be a static string or a template using variables.
|
||||
- *Example*: Set `current_step` to `1`.
|
||||
- *Example*: Set `formatted_response` to `Analysis: {{analysis_result}}`.
|
||||
|
||||
2. **Increment**: Increases the value of a numeric variable.
|
||||
- *Value*: The amount to add (default is 1).
|
||||
- *Example*: Increment `retry_count` by `1`.
|
||||
|
||||
3. **Append**: Adds a value to a list variable.
|
||||
- *Value*: The item to add to the list.
|
||||
- *Example*: Append `{{last_result}}` to `history_list`.
|
||||
|
||||
### Usage Examples
|
||||
|
||||
- **Loop Counters**: Use a *Set State* node to initialize a counter (`i = 0`) before a loop, and another to increment it inside the loop.
|
||||
- **Accumulators**: Use *Append* to collect results from multiple parallel branches into a single list.
|
||||
- **Renaming**: Copy the output of a previous node to a more generic name (e.g., set `context` to `{{search_results}}`) so subsequent nodes can use a standard variable name.
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
export default {
|
||||
"DocsGPT-Settings": {
|
||||
"title": "⚙️ App Configuration",
|
||||
"href": "/Deploying/DocsGPT-Settings"
|
||||
@@ -29,4 +29,4 @@
|
||||
"href": "/Deploying/Railway",
|
||||
"display": "hidden"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
export default {
|
||||
"api-key-guide": {
|
||||
"title": "🔑 Getting API key",
|
||||
"href": "/Extensions/api-key-guide"
|
||||
@@ -19,4 +19,4 @@
|
||||
"title": "🗣️ Chatwoot Extension",
|
||||
"href": "/Extensions/Chatwoot-extension"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
export default {
|
||||
"google-drive-connector": {
|
||||
"title": "🔗 Google Drive",
|
||||
"href": "/Guides/Integrations/google-drive-connector"
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
export default {
|
||||
"Customising-prompts": {
|
||||
"title": "️💻 Customising Prompts",
|
||||
"href": "/Guides/Customising-prompts"
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
export default {
|
||||
"cloud-providers": {
|
||||
"title": "☁️ Cloud Providers",
|
||||
"href": "/Models/cloud-providers"
|
||||
@@ -11,4 +11,4 @@
|
||||
"title": "📝 Embeddings",
|
||||
"href": "/Models/embeddings"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
export default {
|
||||
"basics": {
|
||||
"title": "🔧 Tools Basics",
|
||||
"href": "/Tools/basics"
|
||||
@@ -11,4 +11,4 @@
|
||||
"title": "🛠️ Creating a Custom Tool",
|
||||
"href": "/Tools/creating-a-tool"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
{
|
||||
export default {
|
||||
"index": "Home",
|
||||
"quickstart": "Quickstart",
|
||||
"Deploying": "Deploying",
|
||||
@@ -9,12 +8,11 @@
|
||||
"Extensions": "Extensions",
|
||||
"https://gptcloud.arc53.com/": {
|
||||
"title": "API",
|
||||
"href": "https://gptcloud.arc53.com/",
|
||||
"newWindow": true
|
||||
"href": "https://gptcloud.arc53.com/"
|
||||
},
|
||||
"Guides": "Guides",
|
||||
"changelog": {
|
||||
"title": "Changelog",
|
||||
"display": "hidden"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
title: 'Home'
|
||||
description: Documentation of DocsGPT - quickstart, deployment guides, model configuration, and widget integration documentation.
|
||||
---
|
||||
import { Cards, Card } from 'nextra/components'
|
||||
import { Cards } from 'nextra/components'
|
||||
import Image from 'next/image'
|
||||
|
||||
export const allGuides = {
|
||||
@@ -85,7 +85,7 @@ Try it yourself: [https://www.docsgpt.cloud/](https://www.docsgpt.cloud/)
|
||||
<Cards
|
||||
num={3}
|
||||
children={Object.keys(allGuides).map((key, i) => (
|
||||
<Card
|
||||
<Cards.Card
|
||||
key={i}
|
||||
title={allGuides[key].title}
|
||||
href={allGuides[key].href}
|
||||
8
docs/mdx-components.jsx
Normal file
8
docs/mdx-components.jsx
Normal file
@@ -0,0 +1,8 @@
|
||||
import { useMDXComponents as getThemeComponents } from 'nextra-theme-docs';
|
||||
|
||||
export function useMDXComponents(components) {
|
||||
return {
|
||||
...getThemeComponents(),
|
||||
...components,
|
||||
};
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
const withNextra = require('nextra')({
|
||||
theme: 'nextra-theme-docs',
|
||||
themeConfig: './theme.config.jsx'
|
||||
})
|
||||
const nextra = require('nextra').default;
|
||||
|
||||
module.exports = withNextra()
|
||||
|
||||
// If you have other Next.js configurations, you can pass them as the parameter:
|
||||
// module.exports = withNextra({ /* other next.js config */ })
|
||||
const withNextra = nextra({
|
||||
defaultShowCopyCode: true,
|
||||
});
|
||||
|
||||
module.exports = withNextra({
|
||||
reactStrictMode: true,
|
||||
});
|
||||
|
||||
5580
docs/package-lock.json
generated
5580
docs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
"scripts": {
|
||||
"dev": "next dev",
|
||||
"build": "next build",
|
||||
"postbuild": "pagefind --site .next/server/app --output-path public/_pagefind",
|
||||
"start": "next start"
|
||||
},
|
||||
"license": "MIT",
|
||||
@@ -9,9 +10,13 @@
|
||||
"@vercel/analytics": "^1.1.1",
|
||||
"docsgpt-react": "^0.5.1",
|
||||
"next": "^15.5.9",
|
||||
"nextra": "^2.13.2",
|
||||
"nextra-theme-docs": "^2.13.2",
|
||||
"nextra": "^4.6.1",
|
||||
"nextra-theme-docs": "^4.6.1",
|
||||
"react": "^18.2.0",
|
||||
"react-dom": "^18.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"pagefind": "^1.3.0",
|
||||
"typescript": "^5.9.3"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
---
|
||||
title: Interacting with Agents via API
|
||||
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
|
||||
---
|
||||
|
||||
import { Callout, Tabs } from 'nextra/components';
|
||||
|
||||
# Interacting with Agents via API
|
||||
|
||||
DocsGPT Agents can be accessed programmatically through a dedicated API, allowing you to integrate their specialized capabilities into your own applications, scripts, and workflows. This guide covers the two primary methods for interacting with an agent: the streaming API for real-time responses and the non-streaming API for a single, consolidated answer.
|
||||
|
||||
When you use an API key generated for a specific agent, you do not need to pass `prompt`, `tools` etc. The agent's configuration (including its prompt, selected tools, and knowledge sources) is already associated with its unique API key.
|
||||
|
||||
### API Endpoints
|
||||
|
||||
- **Non-Streaming:** `http://localhost:7091/api/answer`
|
||||
- **Streaming:** `http://localhost:7091/stream`
|
||||
|
||||
<Callout type="info">
|
||||
For DocsGPT Cloud, use `https://gptcloud.arc53.com/` as the base URL.
|
||||
</Callout>
|
||||
|
||||
For more technical details, you can explore the API swagger documentation available for the cloud version or your local instance.
|
||||
|
||||
---
|
||||
|
||||
## Non-Streaming API (`/api/answer`)
|
||||
|
||||
This is a standard synchronous endpoint. It waits for the agent to fully process the request and returns a single JSON object with the complete answer. This is the simplest method and is ideal for backend processes where a real-time feed is not required.
|
||||
|
||||
### Request
|
||||
|
||||
- **Endpoint:** `/api/answer`
|
||||
- **Method:** `POST`
|
||||
- **Payload:**
|
||||
- `question` (string, required): The user's query or input for the agent.
|
||||
- `api_key` (string, required): The unique API key for the agent you wish to interact with.
|
||||
- `history` (string, optional): A JSON string representing the conversation history, e.g., `[{\"prompt\": \"first question\", \"answer\": \"first answer\"}]`.
|
||||
|
||||
### Response
|
||||
|
||||
A single JSON object containing:
|
||||
- `answer`: The complete, final answer from the agent.
|
||||
- `sources`: A list of sources the agent consulted.
|
||||
- `conversation_id`: The unique ID for the interaction.
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/api/answer \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"question": "your question here",
|
||||
"api_key": "your_agent_api_key"
|
||||
}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_URL = "http://localhost:7091/api/answer"
|
||||
API_KEY = "your_agent_api_key"
|
||||
QUESTION = "your question here"
|
||||
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
json={"question": QUESTION, "api_key": API_KEY}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
print(response.json())
|
||||
else:
|
||||
print(f"Error: {response.status_code}")
|
||||
print(response.text)
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const apiUrl = 'http://localhost:7091/api/answer';
|
||||
const apiKey = 'your_agent_api_key';
|
||||
const question = 'your question here';
|
||||
|
||||
async function getAnswer() {
|
||||
try {
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ question, api_key: apiKey }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log(data);
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch answer:", error);
|
||||
}
|
||||
}
|
||||
|
||||
getAnswer();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
|
||||
---
|
||||
|
||||
## Streaming API (`/stream`)
|
||||
|
||||
The `/stream` endpoint uses Server-Sent Events (SSE) to push data in real-time. This is ideal for applications where you want to display the response as it's being generated, such as in a live chatbot interface.
|
||||
|
||||
### Request
|
||||
|
||||
- **Endpoint:** `/stream`
|
||||
- **Method:** `POST`
|
||||
- **Payload:** Same as the non-streaming API.
|
||||
|
||||
### Response (SSE Stream)
|
||||
|
||||
The stream consists of multiple `data:` events, each containing a JSON object. Your client should listen for these events and process them based on their `type`.
|
||||
|
||||
**Event Types:**
|
||||
- `answer`: A chunk of the agent's final answer.
|
||||
- `source`: A document or source used by the agent.
|
||||
- `thought`: A reasoning step from the agent (for ReAct agents).
|
||||
- `id`: The unique `conversation_id` for the interaction.
|
||||
- `error`: An error message.
|
||||
- `end`: A final message indicating the stream has concluded.
|
||||
|
||||
### Examples
|
||||
|
||||
<Tabs items={['cURL', 'Python', 'JavaScript']}>
|
||||
<Tabs.Tab>
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/stream \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Accept: text/event-stream" \
|
||||
-d '{
|
||||
"question": "your question here",
|
||||
"api_key": "your_agent_api_key"
|
||||
}'
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```python
|
||||
import requests
|
||||
import json
|
||||
|
||||
API_URL = "http://localhost:7091/stream"
|
||||
payload = {
|
||||
"question": "your question here",
|
||||
"api_key": "your_agent_api_key"
|
||||
}
|
||||
|
||||
with requests.post(API_URL, json=payload, stream=True) as r:
|
||||
for line in r.iter_lines():
|
||||
if line:
|
||||
decoded_line = line.decode('utf-8')
|
||||
if decoded_line.startswith('data: '):
|
||||
try:
|
||||
data = json.loads(decoded_line[6:])
|
||||
print(data)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
```
|
||||
</Tabs.Tab>
|
||||
<Tabs.Tab>
|
||||
```javascript
|
||||
const apiUrl = 'http://localhost:7091/stream';
|
||||
const apiKey = 'your_agent_api_key';
|
||||
const question = 'your question here';
|
||||
|
||||
async function getStream() {
|
||||
try {
|
||||
const response = await fetch(apiUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'text/event-stream'
|
||||
},
|
||||
// Corrected line: 'apiKey' is changed to 'api_key'
|
||||
body: JSON.stringify({ question, api_key: apiKey }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! Status: ${response.status}`);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
// Note: This parsing method assumes each chunk contains whole lines.
|
||||
// For a more robust production implementation, buffer the chunks
|
||||
// and process them line by line.
|
||||
const lines = chunk.split('\n');
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
try {
|
||||
const data = JSON.parse(line.substring(6));
|
||||
console.log(data);
|
||||
} catch (e) {
|
||||
console.error("Failed to parse JSON from SSE event:", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch stream:", error);
|
||||
}
|
||||
}
|
||||
|
||||
getStream();
|
||||
```
|
||||
</Tabs.Tab>
|
||||
</Tabs>
|
||||
@@ -1,10 +0,0 @@
|
||||
import { DocsGPTWidget } from "docsgpt-react";
|
||||
|
||||
export default function MyApp({ Component, pageProps }) {
|
||||
return (
|
||||
<>
|
||||
<Component {...pageProps} />
|
||||
<DocsGPTWidget showSources={true} apiKey="5d8270cb-735f-484e-9dc9-5b407f24e652" theme="dark" size="medium" />
|
||||
</>
|
||||
)
|
||||
}
|
||||
63
docs/public/llms.txt
Normal file
63
docs/public/llms.txt
Normal file
@@ -0,0 +1,63 @@
|
||||
# DocsGPT
|
||||
|
||||
> DocsGPT is an open-source platform for building AI agents and assistants with document retrieval, tools, and multi-model support.
|
||||
|
||||
This file is a curated map of DocsGPT documentation for LLM and agent use.
|
||||
Prioritize Core, Deploying, and Agents for implementation tasks.
|
||||
|
||||
## Core
|
||||
|
||||
- [Docs Home](https://docs.docsgpt.cloud/): Main documentation landing page.
|
||||
- [Quickstart](https://docs.docsgpt.cloud/quickstart): Fastest path to run DocsGPT locally.
|
||||
- [Architecture](https://docs.docsgpt.cloud/Guides/Architecture): High-level system architecture.
|
||||
- [Development Environment](https://docs.docsgpt.cloud/Deploying/Development-Environment): Backend and frontend local setup.
|
||||
- [DocsGPT Settings](https://docs.docsgpt.cloud/Deploying/DocsGPT-Settings): Environment variables and core app configuration.
|
||||
|
||||
## Deploying
|
||||
|
||||
- [Docker Deployment](https://docs.docsgpt.cloud/Deploying/Docker-Deploying): Run DocsGPT with Docker and Docker Compose.
|
||||
- [Kubernetes Deployment](https://docs.docsgpt.cloud/Deploying/Kubernetes-Deploying): Deploy DocsGPT on Kubernetes clusters.
|
||||
- [Hosting DocsGPT](https://docs.docsgpt.cloud/Deploying/Hosting-the-app): Hosting overview with cloud options.
|
||||
|
||||
## Agents
|
||||
|
||||
- [Agent Basics](https://docs.docsgpt.cloud/Agents/basics): Core concepts for building and managing agents.
|
||||
- [Workflow Nodes](https://docs.docsgpt.cloud/Agents/nodes): Node types and behavior in agent workflows.
|
||||
- [Agent API](https://docs.docsgpt.cloud/Agents/api): Programmatic agent interaction (streaming and non-streaming).
|
||||
- [Agent Webhooks](https://docs.docsgpt.cloud/Agents/webhooks): Trigger and automate agents with webhooks.
|
||||
|
||||
## Tools
|
||||
|
||||
- [Tools Basics](https://docs.docsgpt.cloud/Tools/basics): How tools extend agent capabilities.
|
||||
- [Generic API Tool](https://docs.docsgpt.cloud/Tools/api-tool): Configure API calls without custom code.
|
||||
- [Creating a Custom Tool](https://docs.docsgpt.cloud/Tools/creating-a-tool): Build custom Python tools for DocsGPT.
|
||||
|
||||
## Models
|
||||
|
||||
- [Cloud LLM Providers](https://docs.docsgpt.cloud/Models/cloud-providers): Configure hosted model providers.
|
||||
- [Local Inference](https://docs.docsgpt.cloud/Models/local-inference): Connect DocsGPT to local inference backends.
|
||||
- [Embeddings](https://docs.docsgpt.cloud/Models/embeddings): Select and configure embedding models.
|
||||
|
||||
## Extensions
|
||||
|
||||
- [API Keys for Integrations](https://docs.docsgpt.cloud/Extensions/api-key-guide): Generate and use DocsGPT API keys.
|
||||
- [Chat Widget](https://docs.docsgpt.cloud/Extensions/chat-widget): Embed the DocsGPT chat widget.
|
||||
- [Search Widget](https://docs.docsgpt.cloud/Extensions/search-widget): Embed the DocsGPT search widget.
|
||||
- [Chrome Extension](https://docs.docsgpt.cloud/Extensions/Chrome-extension): Install and use the browser extension.
|
||||
- [Chatwoot Extension](https://docs.docsgpt.cloud/Extensions/Chatwoot-extension): Integrate DocsGPT with Chatwoot.
|
||||
|
||||
## Integrations
|
||||
|
||||
- [Google Drive Connector](https://docs.docsgpt.cloud/Guides/Integrations/google-drive-connector): Ingest and sync files from Google Drive.
|
||||
|
||||
## Optional
|
||||
|
||||
- [Customizing Prompts](https://docs.docsgpt.cloud/Guides/Customising-prompts): Template-based prompt customization.
|
||||
- [How to Train on Other Documentation](https://docs.docsgpt.cloud/Guides/How-to-train-on-other-documentation): Add additional documentation sources.
|
||||
- [Context Compression](https://docs.docsgpt.cloud/Guides/compression): Reduce context while preserving key information.
|
||||
- [OCR for Sources and Attachments](https://docs.docsgpt.cloud/Guides/ocr): OCR behavior for ingestion and chat uploads.
|
||||
- [How to Use Different LLMs](https://docs.docsgpt.cloud/Guides/How-to-use-different-LLM): Additional model-selection guidance.
|
||||
- [Avoiding Hallucinations](https://docs.docsgpt.cloud/Guides/My-AI-answers-questions-using-external-knowledge): Improve answer grounding with external knowledge.
|
||||
- [Amazon Lightsail Deployment](https://docs.docsgpt.cloud/Deploying/Amazon-Lightsail): Deploy DocsGPT on AWS Lightsail.
|
||||
- [Railway Deployment](https://docs.docsgpt.cloud/Deploying/Railway): Deploy DocsGPT on Railway.
|
||||
- [Changelog](https://docs.docsgpt.cloud/changelog): Project release history.
|
||||
@@ -1,161 +1,20 @@
|
||||
import Image from 'next/image'
|
||||
import { Analytics } from '@vercel/analytics/react';
|
||||
|
||||
const github = 'https://github.com/arc53/DocsGPT';
|
||||
|
||||
|
||||
|
||||
|
||||
import { useConfig, useTheme } from 'nextra-theme-docs';
|
||||
import CuteLogo from './public/cute-docsgpt.png';
|
||||
const Logo = ({ height, width }) => {
|
||||
const { theme } = useTheme();
|
||||
return (
|
||||
<div style={{ alignItems: 'center', display: 'flex', gap: '8px' }}>
|
||||
<Image src={CuteLogo} alt="DocsGPT logo" width={width} height={height} />
|
||||
|
||||
<span style={{ fontWeight: 'bold', fontSize: 18 }}>DocsGPT Docs</span>
|
||||
|
||||
|
||||
</div>
|
||||
);
|
||||
};
|
||||
const isDevelopment = process.env.NODE_ENV === 'development';
|
||||
|
||||
const config = {
|
||||
docsRepositoryBase: `${github}/blob/main/docs`,
|
||||
chat: {
|
||||
link: 'https://discord.com/invite/n5BX8dh8rU',
|
||||
darkMode: true,
|
||||
search: isDevelopment ? null : undefined,
|
||||
nextThemes: {
|
||||
defaultTheme: 'dark',
|
||||
},
|
||||
banner: {
|
||||
key: 'docs-launch',
|
||||
text: (
|
||||
<div className="flex justify-center items-center gap-2">
|
||||
Welcome to the new DocsGPT 🦖 docs! 👋
|
||||
</div>
|
||||
),
|
||||
sidebar: {
|
||||
defaultMenuCollapseLevel: 1,
|
||||
},
|
||||
toc: {
|
||||
float: true,
|
||||
},
|
||||
project: {
|
||||
link: github,
|
||||
},
|
||||
darkMode: true,
|
||||
nextThemes: {
|
||||
defaultTheme: 'dark',
|
||||
},
|
||||
primaryHue: {
|
||||
dark: 207,
|
||||
light: 212,
|
||||
},
|
||||
footer: {
|
||||
text: (
|
||||
<div>
|
||||
<span>MIT {new Date().getFullYear()} © </span>
|
||||
<a href="https://www.docsgpt.cloud/" target="_blank">
|
||||
DocsGPT
|
||||
</a>
|
||||
{' | '}
|
||||
<a href="https://github.com/arc53/DocsGPT" target="_blank">
|
||||
GitHub
|
||||
</a>
|
||||
{' | '}
|
||||
<a href="https://blog.docsgpt.cloud/" target="_blank">
|
||||
Blog
|
||||
</a>
|
||||
</div>
|
||||
),
|
||||
},
|
||||
editLink: {
|
||||
content: 'Edit this page on GitHub',
|
||||
},
|
||||
logo() {
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<Logo width={28} height={28} />
|
||||
</div>
|
||||
);
|
||||
},
|
||||
useNextSeoProps() {
|
||||
return {
|
||||
titleTemplate: `%s - DocsGPT Documentation`,
|
||||
};
|
||||
},
|
||||
|
||||
head() {
|
||||
const { frontMatter } = useConfig();
|
||||
const { theme } = useTheme();
|
||||
const title = frontMatter?.title || 'Chat with your data with DocsGPT';
|
||||
const description =
|
||||
frontMatter?.description ||
|
||||
'Use DocsGPT to chat with your data. DocsGPT is a GPT powered chatbot that can answer questions about your data.'
|
||||
const image = '/cute-docsgpt.png';
|
||||
|
||||
const composedTitle = `${title} – DocsGPT Documentation`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<link
|
||||
rel="apple-touch-icon"
|
||||
sizes="180x180"
|
||||
href={`/favicons/apple-touch-icon.png`}
|
||||
/>
|
||||
<link
|
||||
rel="icon"
|
||||
type="image/png"
|
||||
sizes="32x32"
|
||||
href={`/favicons/favicon-32x32.png`}
|
||||
/>
|
||||
<link
|
||||
rel="icon"
|
||||
type="image/png"
|
||||
sizes="16x16"
|
||||
href={`/favicons/favicon-16x16.png`}
|
||||
/>
|
||||
<meta name="theme-color" content="#ffffff" />
|
||||
<meta name="msapplication-TileColor" content="#00a300" />
|
||||
<link rel="manifest" href={`/favicons/site.webmanifest`} />
|
||||
<meta httpEquiv="Content-Language" content="en" />
|
||||
<meta name="title" content={composedTitle} />
|
||||
<meta name="description" content={description} />
|
||||
|
||||
<meta name="twitter:card" content="summary_large_image" />
|
||||
<meta name="twitter:site" content="@ATushynski" />
|
||||
<meta name="twitter:image" content={image} />
|
||||
|
||||
<meta property="og:description" content={description} />
|
||||
<meta property="og:title" content={composedTitle} />
|
||||
<meta property="og:image" content={image} />
|
||||
<meta property="og:type" content="website" />
|
||||
<meta
|
||||
name="apple-mobile-web-app-title"
|
||||
content="DocsGPT Documentation"
|
||||
/>
|
||||
|
||||
</>
|
||||
);
|
||||
},
|
||||
sidebar: {
|
||||
defaultMenuCollapseLevel: 1,
|
||||
titleComponent: ({ title, type }) =>
|
||||
type === 'separator' ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Logo height={10} width={10} />
|
||||
{title}
|
||||
<Analytics />
|
||||
</div>
|
||||
|
||||
) : (
|
||||
<>{title}
|
||||
<Analytics />
|
||||
</>
|
||||
|
||||
),
|
||||
},
|
||||
|
||||
gitTimestamp: ({ timestamp }) => (
|
||||
<>Last updated on {timestamp.toLocaleDateString()}</>
|
||||
),
|
||||
editLink: 'Edit this page on GitHub',
|
||||
};
|
||||
|
||||
export default config;
|
||||
export default config;
|
||||
|
||||
@@ -3,4 +3,5 @@ VITE_BASE_URL=http://localhost:5173
|
||||
VITE_API_HOST=http://127.0.0.1:7091
|
||||
VITE_API_STREAMING=true
|
||||
VITE_NOTIFICATION_TEXT="What's new in 0.15.0 — Changelog"
|
||||
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-15-masters-long-term-memory-and-tooling/"
|
||||
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-15-masters-long-term-memory-and-tooling/"
|
||||
VITE_GOOGLE_CLIENT_ID=896376503572-u46l78n8ctgtdr4dlei4u06jv6rbpqc5.apps.googleusercontent.com
|
||||
|
||||
@@ -439,10 +439,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
|
||||
const data = await response.json();
|
||||
const transformed = modelService.transformModels(data.models || []);
|
||||
setAvailableModels(transformed);
|
||||
|
||||
if (mode === 'new' && transformed.length > 0) {
|
||||
const preferredDefaultModelId =
|
||||
transformed.find((model) => model.id === data.default_model_id)?.id ||
|
||||
transformed[0].id;
|
||||
|
||||
if (preferredDefaultModelId) {
|
||||
setSelectedModelIds((prevSelectedModelIds) =>
|
||||
prevSelectedModelIds.size > 0
|
||||
? prevSelectedModelIds
|
||||
: new Set([preferredDefaultModelId]),
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
getTools();
|
||||
getModels();
|
||||
}, [token]);
|
||||
}, [token, mode]);
|
||||
|
||||
// Validate folder_id from URL against user's folders
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1,4 +1,20 @@
|
||||
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state';
|
||||
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state' | 'condition';
|
||||
|
||||
export interface ConditionCase {
|
||||
name?: string;
|
||||
expression: string;
|
||||
sourceHandle: string;
|
||||
}
|
||||
|
||||
export interface ConditionNodeConfig {
|
||||
mode: 'simple' | 'advanced';
|
||||
cases: ConditionCase[];
|
||||
}
|
||||
|
||||
export interface StateOperationConfig {
|
||||
expression: string;
|
||||
target_variable: string;
|
||||
}
|
||||
|
||||
export interface WorkflowEdge {
|
||||
id: string;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,7 @@ import {
|
||||
Circle,
|
||||
Database,
|
||||
Flag,
|
||||
GitBranch,
|
||||
Loader2,
|
||||
MessageSquare,
|
||||
Play,
|
||||
@@ -53,6 +54,7 @@ const NODE_ICONS: Record<string, React.ReactNode> = {
|
||||
end: <Flag className="h-3 w-3" />,
|
||||
note: <StickyNote className="h-3 w-3" />,
|
||||
state: <Database className="h-3 w-3" />,
|
||||
condition: <GitBranch className="h-3 w-3" />,
|
||||
};
|
||||
|
||||
const NODE_COLORS: Record<string, string> = {
|
||||
@@ -61,6 +63,7 @@ const NODE_COLORS: Record<string, string> = {
|
||||
end: 'text-gray-600 dark:text-gray-400',
|
||||
note: 'text-yellow-600 dark:text-yellow-400',
|
||||
state: 'text-blue-600 dark:text-blue-400',
|
||||
condition: 'text-orange-600 dark:text-orange-400',
|
||||
};
|
||||
|
||||
function ExecutionDetails({
|
||||
@@ -84,7 +87,9 @@ function ExecutionDetails({
|
||||
|
||||
const formatValue = (value: unknown): string => {
|
||||
if (typeof value === 'string') return value;
|
||||
return JSON.stringify(value, null, 2);
|
||||
if (value === undefined) return '';
|
||||
const formatted = JSON.stringify(value, null, 2);
|
||||
return formatted ?? String(value);
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -133,6 +138,11 @@ function ExecutionDetails({
|
||||
if (text.length <= maxLength) return text;
|
||||
return text.slice(0, maxLength) + '...';
|
||||
};
|
||||
const hasOutput =
|
||||
step.output !== undefined &&
|
||||
step.output !== null &&
|
||||
formatValue(step.output) !== '';
|
||||
const formattedOutput = hasOutput ? formatValue(step.output) : '';
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -168,15 +178,15 @@ function ExecutionDetails({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{(step.output || step.error || stateVars.length > 0) && (
|
||||
{(hasOutput || step.error || stateVars.length > 0) && (
|
||||
<div className="mt-3 space-y-2 text-sm">
|
||||
{step.output && (
|
||||
{hasOutput && (
|
||||
<div className="rounded-lg bg-white p-2 dark:bg-[#2A2A2A]">
|
||||
<span className="font-medium text-gray-600 dark:text-gray-400">
|
||||
Output:{' '}
|
||||
</span>
|
||||
<span className="wrap-break-word whitespace-pre-wrap text-gray-900 dark:text-gray-100">
|
||||
{truncateText(step.output, 300)}
|
||||
{truncateText(formattedOutput, 300)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
@@ -251,7 +261,7 @@ function WorkflowMiniMap({
|
||||
return step?.status || 'pending';
|
||||
};
|
||||
|
||||
const getStatusColor = (nodeId: string, nodeType: string) => {
|
||||
const getStatusColor = (nodeId: string) => {
|
||||
const status = getNodeStatus(nodeId);
|
||||
const isActive = nodeId === activeNodeId;
|
||||
|
||||
@@ -267,26 +277,33 @@ function WorkflowMiniMap({
|
||||
case 'failed':
|
||||
return 'bg-red-100 dark:bg-red-900/30 border-red-300 dark:border-red-700';
|
||||
default:
|
||||
if (nodeType === 'start') {
|
||||
return 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800';
|
||||
}
|
||||
if (nodeType === 'agent') {
|
||||
return 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800';
|
||||
}
|
||||
if (nodeType === 'end') {
|
||||
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
|
||||
}
|
||||
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
|
||||
}
|
||||
};
|
||||
|
||||
const sortedNodes = [...nodes].sort((a, b) => {
|
||||
if (a.type === 'start') return -1;
|
||||
if (b.type === 'start') return 1;
|
||||
if (a.type === 'end') return 1;
|
||||
if (b.type === 'end') return -1;
|
||||
return (a.position?.y || 0) - (b.position?.y || 0);
|
||||
});
|
||||
const executedOrder = new Map(executionSteps.map((s, i) => [s.nodeId, i]));
|
||||
const startNode = nodes.find((node) => node.type === 'start');
|
||||
const visibleNodeIds = new Set(executionSteps.map((step) => step.nodeId));
|
||||
if (activeNodeId) {
|
||||
visibleNodeIds.add(activeNodeId);
|
||||
}
|
||||
if (startNode) {
|
||||
visibleNodeIds.add(startNode.id);
|
||||
}
|
||||
|
||||
const sortedNodes = nodes
|
||||
.filter((node) => visibleNodeIds.has(node.id))
|
||||
.sort((a, b) => {
|
||||
if (a.type === 'start') return -1;
|
||||
if (b.type === 'start') return 1;
|
||||
|
||||
const aIdx = executedOrder.get(a.id);
|
||||
const bIdx = executedOrder.get(b.id);
|
||||
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
|
||||
if (aIdx !== undefined) return -1;
|
||||
if (bIdx !== undefined) return 1;
|
||||
return (a.position?.y || 0) - (b.position?.y || 0);
|
||||
});
|
||||
|
||||
const hasStepData = (nodeId: string) => {
|
||||
const step = executionSteps.find((s) => s.nodeId === nodeId);
|
||||
@@ -306,7 +323,7 @@ function WorkflowMiniMap({
|
||||
disabled={!hasStepData(node.id)}
|
||||
className={cn(
|
||||
'flex h-12 w-full items-center gap-2 rounded-lg border px-3 text-xs transition-all',
|
||||
getStatusColor(node.id, node.type),
|
||||
getStatusColor(node.id),
|
||||
hasStepData(node.id) && 'cursor-pointer hover:opacity-80',
|
||||
)}
|
||||
>
|
||||
@@ -533,6 +550,10 @@ export default function WorkflowPreview({
|
||||
const querySteps = query.executionSteps || [];
|
||||
const hasResponse = !!(query.response || query.error);
|
||||
const isLastQuery = index === queries.length - 1;
|
||||
const isStreamingLastQuery =
|
||||
status === 'loading' && isLastQuery;
|
||||
const shouldShowThought =
|
||||
!isStreamingLastQuery && Boolean(query.thought);
|
||||
const isOpen =
|
||||
openDetailsIndex === index ||
|
||||
(!hasResponse && isLastQuery && querySteps.length > 0);
|
||||
@@ -567,17 +588,19 @@ export default function WorkflowPreview({
|
||||
|
||||
{/* Response bubble */}
|
||||
{(query.response ||
|
||||
query.thought ||
|
||||
shouldShowThought ||
|
||||
query.tool_calls) && (
|
||||
<ConversationBubble
|
||||
className={isLastQuery ? 'mb-32' : 'mb-7'}
|
||||
message={query.response}
|
||||
type="ANSWER"
|
||||
thought={query.thought}
|
||||
thought={
|
||||
shouldShowThought ? query.thought : undefined
|
||||
}
|
||||
sources={query.sources}
|
||||
toolCalls={query.tool_calls}
|
||||
feedback={query.feedback}
|
||||
isStreaming={status === 'loading' && isLastQuery}
|
||||
isStreaming={isStreamingLastQuery}
|
||||
/>
|
||||
)}
|
||||
|
||||
|
||||
@@ -9,10 +9,71 @@ import {
|
||||
} from '@/components/ui/popover';
|
||||
|
||||
interface WorkflowVariable {
|
||||
name: string;
|
||||
label: string;
|
||||
templatePath: string;
|
||||
section: string;
|
||||
}
|
||||
|
||||
const GLOBAL_CONTEXT_VARIABLES: WorkflowVariable[] = [
|
||||
{
|
||||
label: 'source.content',
|
||||
templatePath: 'source.content',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'source.summaries',
|
||||
templatePath: 'source.summaries',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'source.documents',
|
||||
templatePath: 'source.documents',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'source.count',
|
||||
templatePath: 'source.count',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.date',
|
||||
templatePath: 'system.date',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.time',
|
||||
templatePath: 'system.time',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.timestamp',
|
||||
templatePath: 'system.timestamp',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.request_id',
|
||||
templatePath: 'system.request_id',
|
||||
section: 'Global context',
|
||||
},
|
||||
{
|
||||
label: 'system.user_id',
|
||||
templatePath: 'system.user_id',
|
||||
section: 'Global context',
|
||||
},
|
||||
];
|
||||
|
||||
function toAgentTemplatePath(variableName: string): string {
|
||||
const trimmed = variableName.trim();
|
||||
if (!trimmed) return 'agent';
|
||||
|
||||
if (/^[A-Za-z_][A-Za-z0-9_]*$/.test(trimmed)) {
|
||||
return `agent.${trimmed}`;
|
||||
}
|
||||
|
||||
const escaped = trimmed.replace(/\\/g, '\\\\').replace(/'/g, "\\'");
|
||||
return `agent['${escaped}']`;
|
||||
}
|
||||
|
||||
function getUpstreamNodeIds(nodeId: string, edges: Edge[]): Set<string> {
|
||||
const upstream = new Set<string>();
|
||||
const queue = [nodeId];
|
||||
@@ -36,32 +97,69 @@ function extractUpstreamVariables(
|
||||
selectedNodeId: string,
|
||||
): WorkflowVariable[] {
|
||||
const variables: WorkflowVariable[] = [
|
||||
{ name: 'query', section: 'Workflow input' },
|
||||
{ name: 'chat_history', section: 'Workflow input' },
|
||||
{
|
||||
label: 'agent.query',
|
||||
templatePath: 'agent.query',
|
||||
section: 'Workflow input',
|
||||
},
|
||||
{
|
||||
label: 'agent.chat_history',
|
||||
templatePath: 'agent.chat_history',
|
||||
section: 'Workflow input',
|
||||
},
|
||||
...GLOBAL_CONTEXT_VARIABLES,
|
||||
];
|
||||
const seen = new Set(['query', 'chat_history']);
|
||||
const seen = new Set(variables.map((variable) => variable.templatePath));
|
||||
const upstreamIds = getUpstreamNodeIds(selectedNodeId, edges);
|
||||
|
||||
for (const node of nodes) {
|
||||
if (!upstreamIds.has(node.id)) continue;
|
||||
|
||||
if (node.type === 'agent' && node.data?.config?.output_variable) {
|
||||
const name = node.data.config.output_variable;
|
||||
if (!seen.has(name)) {
|
||||
seen.add(name);
|
||||
if (node.type === 'agent') {
|
||||
const defaultOutputTemplatePath = toAgentTemplatePath(
|
||||
`node_${node.id}_output`,
|
||||
);
|
||||
if (!seen.has(defaultOutputTemplatePath)) {
|
||||
seen.add(defaultOutputTemplatePath);
|
||||
variables.push({
|
||||
name,
|
||||
label: defaultOutputTemplatePath,
|
||||
templatePath: defaultOutputTemplatePath,
|
||||
section: node.data.title || node.data.label || 'Agent',
|
||||
});
|
||||
}
|
||||
|
||||
const outputVariable = String(
|
||||
node.data?.config?.output_variable || '',
|
||||
).trim();
|
||||
if (outputVariable) {
|
||||
const templatePath = toAgentTemplatePath(outputVariable);
|
||||
if (!seen.has(templatePath)) {
|
||||
seen.add(templatePath);
|
||||
variables.push({
|
||||
label: templatePath,
|
||||
templatePath,
|
||||
section: node.data.title || node.data.label || 'Agent',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
if (node.type === 'state' && node.data?.variable) {
|
||||
const name = node.data.variable;
|
||||
if (!seen.has(name)) {
|
||||
seen.add(name);
|
||||
|
||||
if (node.type === 'state') {
|
||||
const operations = node.data?.config?.operations;
|
||||
if (!Array.isArray(operations)) continue;
|
||||
|
||||
for (const operation of operations) {
|
||||
const targetVariable = String(operation?.target_variable || '').trim();
|
||||
if (!targetVariable) continue;
|
||||
|
||||
const templatePath = toAgentTemplatePath(targetVariable);
|
||||
if (seen.has(templatePath)) continue;
|
||||
|
||||
seen.add(templatePath);
|
||||
variables.push({
|
||||
name,
|
||||
section: 'Set State',
|
||||
label: templatePath,
|
||||
templatePath,
|
||||
section: node.data.title || node.data.label || 'Set State',
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -106,14 +204,16 @@ function VariableListWithSearch({
|
||||
onSelect,
|
||||
}: {
|
||||
variables: WorkflowVariable[];
|
||||
onSelect: (name: string) => void;
|
||||
onSelect: (templatePath: string) => void;
|
||||
}) {
|
||||
const [search, setSearch] = useState('');
|
||||
|
||||
const filtered = useMemo(
|
||||
() =>
|
||||
variables.filter((v) =>
|
||||
v.name.toLowerCase().includes(search.toLowerCase()),
|
||||
`${v.label} ${v.templatePath}`
|
||||
.toLowerCase()
|
||||
.includes(search.toLowerCase()),
|
||||
),
|
||||
[variables, search],
|
||||
);
|
||||
@@ -146,17 +246,17 @@ function VariableListWithSearch({
|
||||
</div>
|
||||
{vars.map((v) => (
|
||||
<button
|
||||
key={v.name}
|
||||
key={`${section}-${v.templatePath}`}
|
||||
onMouseDown={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
onSelect(v.name);
|
||||
onSelect(v.templatePath);
|
||||
}}
|
||||
className="flex w-full cursor-pointer items-center gap-2 px-3 py-1.5 text-left text-sm transition-colors hover:bg-gray-50 dark:hover:bg-[#383838]"
|
||||
>
|
||||
<Braces className="text-violets-are-blue h-3.5 w-3.5 shrink-0" />
|
||||
<span className="truncate font-medium text-gray-800 dark:text-gray-200">
|
||||
{v.name}
|
||||
{v.label}
|
||||
</span>
|
||||
</button>
|
||||
))}
|
||||
@@ -206,7 +306,9 @@ export default function PromptTextArea({
|
||||
const filtered = useMemo(
|
||||
() =>
|
||||
variables.filter((v) =>
|
||||
v.name.toLowerCase().includes(filterText.toLowerCase()),
|
||||
`${v.label} ${v.templatePath}`
|
||||
.toLowerCase()
|
||||
.includes(filterText.toLowerCase()),
|
||||
),
|
||||
[variables, filterText],
|
||||
);
|
||||
@@ -217,10 +319,12 @@ export default function PromptTextArea({
|
||||
|
||||
const cursorPos = textarea.selectionStart;
|
||||
const textBeforeCursor = value.slice(0, cursorPos);
|
||||
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
|
||||
const triggerMatch = textBeforeCursor.match(
|
||||
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
|
||||
);
|
||||
|
||||
if (triggerMatch) {
|
||||
setFilterText(triggerMatch[1]);
|
||||
setFilterText(triggerMatch[1].trim());
|
||||
setCursorInsertPos(cursorPos);
|
||||
|
||||
const wrapper = wrapperRef.current;
|
||||
@@ -237,15 +341,17 @@ export default function PromptTextArea({
|
||||
}, [value]);
|
||||
|
||||
const insertVariable = useCallback(
|
||||
(varName: string) => {
|
||||
(templatePath: string) => {
|
||||
if (cursorInsertPos === null) return;
|
||||
|
||||
const textBeforeCursor = value.slice(0, cursorInsertPos);
|
||||
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
|
||||
const triggerMatch = textBeforeCursor.match(
|
||||
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
|
||||
);
|
||||
if (!triggerMatch) return;
|
||||
|
||||
const startPos = cursorInsertPos - triggerMatch[0].length;
|
||||
const insertion = `{{${varName}}}`;
|
||||
const insertion = `{{ ${templatePath} }}`;
|
||||
const newValue =
|
||||
value.slice(0, startPos) + insertion + value.slice(cursorInsertPos);
|
||||
|
||||
@@ -262,10 +368,10 @@ export default function PromptTextArea({
|
||||
);
|
||||
|
||||
const insertVariableFromButton = useCallback(
|
||||
(varName: string) => {
|
||||
(templatePath: string) => {
|
||||
const textarea = textareaRef.current;
|
||||
const cursorPos = textarea?.selectionStart ?? value.length;
|
||||
const insertion = `{{${varName}}}`;
|
||||
const insertion = `{{ ${templatePath} }}`;
|
||||
const newValue =
|
||||
value.slice(0, cursorPos) + insertion + value.slice(cursorPos);
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ interface BaseNodeProps {
|
||||
title: string;
|
||||
children?: ReactNode;
|
||||
selected?: boolean;
|
||||
type?: 'start' | 'end' | 'default' | 'state' | 'agent';
|
||||
type?: 'start' | 'end' | 'default' | 'state' | 'agent' | 'condition';
|
||||
icon?: ReactNode;
|
||||
handles?: {
|
||||
source?: boolean;
|
||||
@@ -40,6 +40,9 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
|
||||
} else if (type === 'state') {
|
||||
iconBg = 'bg-gray-100 dark:bg-gray-800';
|
||||
iconColor = 'text-gray-600 dark:text-gray-400';
|
||||
} else if (type === 'condition') {
|
||||
iconBg = 'bg-orange-100 dark:bg-orange-900/30';
|
||||
iconColor = 'text-orange-600 dark:text-orange-400';
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
118
frontend/src/agents/workflow/nodes/ConditionNode.tsx
Normal file
118
frontend/src/agents/workflow/nodes/ConditionNode.tsx
Normal file
@@ -0,0 +1,118 @@
|
||||
import { GitBranch } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
import { Handle, NodeProps, Position } from 'reactflow';
|
||||
|
||||
import { ConditionCase } from '../../types/workflow';
|
||||
|
||||
type ConditionNodeData = {
|
||||
label?: string;
|
||||
title?: string;
|
||||
config?: {
|
||||
mode?: 'simple' | 'advanced';
|
||||
cases?: ConditionCase[];
|
||||
};
|
||||
};
|
||||
|
||||
const ROW_HEIGHT = 18;
|
||||
const HEADER_HEIGHT = 52;
|
||||
const PADDING_BOTTOM = 8;
|
||||
|
||||
function getNodeHeight(caseCount: number): number {
|
||||
return (
|
||||
HEADER_HEIGHT + Math.max(caseCount + 1, 2) * ROW_HEIGHT + PADDING_BOTTOM
|
||||
);
|
||||
}
|
||||
|
||||
function getHandleTop(index: number, total: number): string {
|
||||
const offset = HEADER_HEIGHT;
|
||||
return `${offset + ROW_HEIGHT * index + ROW_HEIGHT / 2}px`;
|
||||
}
|
||||
|
||||
const ConditionNode = ({ data, selected }: NodeProps<ConditionNodeData>) => {
|
||||
const title = data.title || data.label || 'If / Else';
|
||||
const cases = data.config?.cases || [];
|
||||
const totalOutputs = cases.length + 1;
|
||||
const height = getNodeHeight(cases.length);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`relative rounded-2xl border bg-white shadow-md transition-all dark:bg-[#2C2C2C] ${
|
||||
selected
|
||||
? 'border-violets-are-blue dark:ring-violets-are-blue scale-105 ring-2 ring-purple-300'
|
||||
: 'border-gray-200 hover:shadow-lg dark:border-[#3A3A3A]'
|
||||
}`}
|
||||
style={{ minWidth: 180, maxWidth: 220, height }}
|
||||
>
|
||||
<Handle
|
||||
type="target"
|
||||
position={Position.Left}
|
||||
isConnectable
|
||||
className="hover:bg-violets-are-blue! top-1/2! -left-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
|
||||
/>
|
||||
|
||||
<div className="flex items-center gap-3 px-3 py-2">
|
||||
<div className="flex h-9 w-9 shrink-0 items-center justify-center rounded-full bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400">
|
||||
<GitBranch size={14} />
|
||||
</div>
|
||||
<div className="min-w-0 flex-1 pr-2">
|
||||
<div
|
||||
className="truncate text-sm font-semibold text-gray-900 dark:text-white"
|
||||
title={title}
|
||||
>
|
||||
{title}
|
||||
</div>
|
||||
<div className="text-[10px] text-gray-500 uppercase">
|
||||
{data.config?.mode || 'simple'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col px-3">
|
||||
{cases.map((c, i) => (
|
||||
<div
|
||||
key={c.sourceHandle}
|
||||
className="flex items-center gap-1"
|
||||
style={{ height: ROW_HEIGHT }}
|
||||
>
|
||||
<span className="shrink-0 text-xs font-medium text-orange-600 dark:text-orange-400">
|
||||
{i === 0 ? 'If' : 'Else if'}
|
||||
</span>
|
||||
{c.name && (
|
||||
<span
|
||||
className="truncate text-xs text-gray-600 dark:text-gray-400"
|
||||
title={c.name}
|
||||
>
|
||||
{c.name}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
<div className="flex items-center gap-1" style={{ height: ROW_HEIGHT }}>
|
||||
<span className="text-xs font-medium text-gray-500">Else</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{cases.map((c, i) => (
|
||||
<Handle
|
||||
key={c.sourceHandle}
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
id={c.sourceHandle}
|
||||
isConnectable
|
||||
style={{ top: getHandleTop(i, totalOutputs) }}
|
||||
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-orange-400! transition-colors dark:border-[#2C2C2C]!"
|
||||
/>
|
||||
))}
|
||||
<Handle
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
id="else"
|
||||
isConnectable
|
||||
style={{ top: getHandleTop(cases.length, totalOutputs) }}
|
||||
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ConditionNode);
|
||||
@@ -2,6 +2,7 @@ import { Database } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
import { NodeProps } from 'reactflow';
|
||||
|
||||
import { StateOperationConfig } from '../../types/workflow';
|
||||
import { BaseNode } from './BaseNode';
|
||||
|
||||
type SetStateNodeData = {
|
||||
@@ -9,10 +10,16 @@ type SetStateNodeData = {
|
||||
title?: string;
|
||||
variable?: string;
|
||||
value?: string;
|
||||
config?: {
|
||||
operations?: StateOperationConfig[];
|
||||
};
|
||||
};
|
||||
|
||||
const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
|
||||
const title = data.title || data.label || 'Set State';
|
||||
const operations = data.config?.operations || [];
|
||||
const hasLegacy = !operations.length && data.variable;
|
||||
|
||||
return (
|
||||
<BaseNode
|
||||
title={title}
|
||||
@@ -22,22 +29,31 @@ const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
|
||||
handles={{ source: true, target: true }}
|
||||
>
|
||||
<div className="flex flex-col gap-1">
|
||||
{data.variable && (
|
||||
{operations.length > 0 ? (
|
||||
<div
|
||||
className="truncate text-[10px] text-gray-500 uppercase"
|
||||
title={`Variable: ${data.variable}`}
|
||||
className="truncate text-[10px] text-gray-500"
|
||||
title={`${operations.length} operation(s)`}
|
||||
>
|
||||
{data.variable}
|
||||
{operations.length} variable{operations.length !== 1 ? 's' : ''}
|
||||
</div>
|
||||
)}
|
||||
{data.value && (
|
||||
<div
|
||||
className="truncate text-xs text-blue-600 dark:text-blue-400"
|
||||
title={`Value: ${data.value}`}
|
||||
>
|
||||
{data.value}
|
||||
</div>
|
||||
)}
|
||||
) : hasLegacy ? (
|
||||
<>
|
||||
<div
|
||||
className="truncate text-[10px] text-gray-500 uppercase"
|
||||
title={`Variable: ${data.variable}`}
|
||||
>
|
||||
{data.variable}
|
||||
</div>
|
||||
{data.value && (
|
||||
<div
|
||||
className="truncate text-xs text-blue-600 dark:text-blue-400"
|
||||
title={`Value: ${data.value}`}
|
||||
>
|
||||
{data.value}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : null}
|
||||
</div>
|
||||
</BaseNode>
|
||||
);
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import React, { memo } from 'react';
|
||||
import { Bot, Flag, Play, StickyNote } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { BaseNode } from './BaseNode';
|
||||
import ConditionNode from './ConditionNode';
|
||||
import SetStateNode from './SetStateNode';
|
||||
import { Play, Bot, StickyNote, Flag } from 'lucide-react';
|
||||
|
||||
export const StartNode = memo(function StartNode({
|
||||
selected,
|
||||
@@ -83,10 +85,10 @@ export const AgentNode = memo(function AgentNode({
|
||||
)}
|
||||
{config.output_variable && (
|
||||
<div
|
||||
className="truncate text-xs text-green-600 dark:text-green-400"
|
||||
title={`Output ➔ ${config.output_variable}`}
|
||||
className="truncate text-xs text-gray-500 dark:text-gray-400"
|
||||
title={`Output: ${config.output_variable}`}
|
||||
>
|
||||
Output ➔ {config.output_variable}
|
||||
Output: {config.output_variable}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -142,3 +144,4 @@ export const NoteNode = memo(function NoteNode({
|
||||
});
|
||||
|
||||
export { SetStateNode };
|
||||
export { ConditionNode };
|
||||
|
||||
@@ -13,7 +13,7 @@ export interface WorkflowExecutionStep {
|
||||
startedAt?: number;
|
||||
completedAt?: number;
|
||||
stateSnapshot?: Record<string, unknown>;
|
||||
output?: string;
|
||||
output?: unknown;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
@@ -321,7 +321,9 @@ export const workflowPreviewSlice = createSlice({
|
||||
}
|
||||
|
||||
const querySteps = state.queries[index].executionSteps!;
|
||||
const existingIndex = querySteps.findIndex((s) => s.nodeId === step.nodeId);
|
||||
const existingIndex = querySteps.findIndex(
|
||||
(s) => s.nodeId === step.nodeId,
|
||||
);
|
||||
|
||||
const updatedStep: WorkflowExecutionStep = {
|
||||
nodeId: step.nodeId,
|
||||
@@ -332,7 +334,10 @@ export const workflowPreviewSlice = createSlice({
|
||||
stateSnapshot: step.stateSnapshot,
|
||||
output: step.output,
|
||||
error: step.error,
|
||||
startedAt: existingIndex !== -1 ? querySteps[existingIndex].startedAt : Date.now(),
|
||||
startedAt:
|
||||
existingIndex !== -1
|
||||
? querySteps[existingIndex].startedAt
|
||||
: Date.now(),
|
||||
completedAt:
|
||||
step.status === 'completed' || step.status === 'failed'
|
||||
? Date.now()
|
||||
@@ -342,7 +347,8 @@ export const workflowPreviewSlice = createSlice({
|
||||
};
|
||||
|
||||
if (existingIndex !== -1) {
|
||||
updatedStep.stateSnapshot = step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
|
||||
updatedStep.stateSnapshot =
|
||||
step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
|
||||
updatedStep.output = step.output ?? querySteps[existingIndex].output;
|
||||
updatedStep.error = step.error ?? querySteps[existingIndex].error;
|
||||
querySteps[existingIndex] = updatedStep;
|
||||
@@ -350,7 +356,9 @@ export const workflowPreviewSlice = createSlice({
|
||||
querySteps.push(updatedStep);
|
||||
}
|
||||
|
||||
const globalIndex = state.executionSteps.findIndex((s) => s.nodeId === step.nodeId);
|
||||
const globalIndex = state.executionSteps.findIndex(
|
||||
(s) => s.nodeId === step.nodeId,
|
||||
);
|
||||
if (globalIndex !== -1) {
|
||||
state.executionSteps[globalIndex] = updatedStep;
|
||||
} else {
|
||||
|
||||
@@ -43,7 +43,7 @@ export default function WrapperModal({
|
||||
|
||||
const modalContent = (
|
||||
<div
|
||||
className="fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center"
|
||||
className="fixed top-0 left-0 z-[100] flex h-screen w-screen items-center justify-center"
|
||||
onClick={(e: React.MouseEvent) => e.stopPropagation()}
|
||||
onMouseDown={(e: React.MouseEvent) => e.stopPropagation()}
|
||||
>
|
||||
|
||||
343
setup.ps1
343
setup.ps1
@@ -286,6 +286,301 @@ function Prompt-OllamaOptions {
|
||||
$script:ollama_choice = Read-Host "Choose option (1-2, or b)"
|
||||
}
|
||||
|
||||
# ========================
|
||||
# Advanced Settings Functions
|
||||
# ========================
|
||||
|
||||
# Vector Store configuration
|
||||
function Configure-VectorStore {
|
||||
Write-Host ""
|
||||
Write-ColorText "Vector Store Configuration" -ForegroundColor "White" -Bold
|
||||
Write-ColorText "Choose your vector store:" -ForegroundColor "White"
|
||||
Write-ColorText "1) FAISS (default, local)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "2) Elasticsearch" -ForegroundColor "Yellow"
|
||||
Write-ColorText "3) Qdrant" -ForegroundColor "Yellow"
|
||||
Write-ColorText "4) Milvus" -ForegroundColor "Yellow"
|
||||
Write-ColorText "5) LanceDB" -ForegroundColor "Yellow"
|
||||
Write-ColorText "6) PGVector" -ForegroundColor "Yellow"
|
||||
Write-ColorText "b) Back" -ForegroundColor "Yellow"
|
||||
Write-Host ""
|
||||
$vs_choice = Read-Host "Choose option (1-6, or b)"
|
||||
|
||||
switch ($vs_choice) {
|
||||
"1" {
|
||||
"VECTOR_STORE=faiss" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Vector store set to FAISS." -ForegroundColor "Green"
|
||||
}
|
||||
"2" {
|
||||
"VECTOR_STORE=elasticsearch" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$elastic_url = Read-Host "Enter Elasticsearch URL (e.g. http://localhost:9200)"
|
||||
if ($elastic_url) { "ELASTIC_URL=$elastic_url" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$elastic_cloud_id = Read-Host "Enter Elasticsearch Cloud ID (leave empty if using URL)"
|
||||
if ($elastic_cloud_id) { "ELASTIC_CLOUD_ID=$elastic_cloud_id" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$elastic_user = Read-Host "Enter Elasticsearch username (leave empty if none)"
|
||||
if ($elastic_user) { "ELASTIC_USERNAME=$elastic_user" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$elastic_pass = Read-Host "Enter Elasticsearch password (leave empty if none)"
|
||||
if ($elastic_pass) { "ELASTIC_PASSWORD=$elastic_pass" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$elastic_index = Read-Host "Enter Elasticsearch index name (default: docsgpt)"
|
||||
if ([string]::IsNullOrEmpty($elastic_index)) { $elastic_index = "docsgpt" }
|
||||
"ELASTIC_INDEX=$elastic_index" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Vector store set to Elasticsearch." -ForegroundColor "Green"
|
||||
}
|
||||
"3" {
|
||||
"VECTOR_STORE=qdrant" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$qdrant_url = Read-Host "Enter Qdrant URL (e.g. http://localhost:6333)"
|
||||
if ($qdrant_url) { "QDRANT_URL=$qdrant_url" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$qdrant_key = Read-Host "Enter Qdrant API key (leave empty if none)"
|
||||
if ($qdrant_key) { "QDRANT_API_KEY=$qdrant_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$qdrant_collection = Read-Host "Enter Qdrant collection name (default: docsgpt)"
|
||||
if ([string]::IsNullOrEmpty($qdrant_collection)) { $qdrant_collection = "docsgpt" }
|
||||
"QDRANT_COLLECTION_NAME=$qdrant_collection" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Vector store set to Qdrant." -ForegroundColor "Green"
|
||||
}
|
||||
"4" {
|
||||
"VECTOR_STORE=milvus" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$milvus_uri = Read-Host "Enter Milvus URI (default: ./milvus_local.db)"
|
||||
if ([string]::IsNullOrEmpty($milvus_uri)) { $milvus_uri = "./milvus_local.db" }
|
||||
"MILVUS_URI=$milvus_uri" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$milvus_token = Read-Host "Enter Milvus token (leave empty if none)"
|
||||
if ($milvus_token) { "MILVUS_TOKEN=$milvus_token" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$milvus_collection = Read-Host "Enter Milvus collection name (default: docsgpt)"
|
||||
if ([string]::IsNullOrEmpty($milvus_collection)) { $milvus_collection = "docsgpt" }
|
||||
"MILVUS_COLLECTION_NAME=$milvus_collection" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Vector store set to Milvus." -ForegroundColor "Green"
|
||||
}
|
||||
"5" {
|
||||
"VECTOR_STORE=lancedb" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$lancedb_path = Read-Host "Enter LanceDB path (default: ./data/lancedb)"
|
||||
if ([string]::IsNullOrEmpty($lancedb_path)) { $lancedb_path = "./data/lancedb" }
|
||||
"LANCEDB_PATH=$lancedb_path" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$lancedb_table = Read-Host "Enter LanceDB table name (default: docsgpts)"
|
||||
if ([string]::IsNullOrEmpty($lancedb_table)) { $lancedb_table = "docsgpts" }
|
||||
"LANCEDB_TABLE_NAME=$lancedb_table" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Vector store set to LanceDB." -ForegroundColor "Green"
|
||||
}
|
||||
"6" {
|
||||
"VECTOR_STORE=pgvector" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$pgvector_conn = Read-Host "Enter PGVector connection string (e.g. postgresql://user:pass@host:5432/db)"
|
||||
if ($pgvector_conn) { "PGVECTOR_CONNECTION_STRING=$pgvector_conn" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
Write-ColorText "Vector store set to PGVector." -ForegroundColor "Green"
|
||||
}
|
||||
{$_ -eq "b" -or $_ -eq "B"} { return }
|
||||
default {
|
||||
Write-Host ""
|
||||
Write-ColorText "Invalid choice." -ForegroundColor "Red"
|
||||
Start-Sleep -Seconds 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Embeddings configuration
|
||||
function Configure-Embeddings {
|
||||
Write-Host ""
|
||||
Write-ColorText "Embeddings Configuration" -ForegroundColor "White" -Bold
|
||||
Write-ColorText "Choose your embeddings provider:" -ForegroundColor "White"
|
||||
Write-ColorText "1) HuggingFace (default, local)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "2) OpenAI Embeddings" -ForegroundColor "Yellow"
|
||||
Write-ColorText "3) Custom Remote Embeddings (OpenAI-compatible API)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "b) Back" -ForegroundColor "Yellow"
|
||||
Write-Host ""
|
||||
$emb_choice = Read-Host "Choose option (1-3, or b)"
|
||||
|
||||
switch ($emb_choice) {
|
||||
"1" {
|
||||
"EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Embeddings set to HuggingFace (local)." -ForegroundColor "Green"
|
||||
}
|
||||
"2" {
|
||||
"EMBEDDINGS_NAME=openai_text-embedding-ada-002" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$emb_key = Read-Host "Enter Embeddings API key (leave empty to reuse LLM API_KEY)"
|
||||
if ($emb_key) { "EMBEDDINGS_KEY=$emb_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
Write-ColorText "Embeddings set to OpenAI." -ForegroundColor "Green"
|
||||
}
|
||||
"3" {
|
||||
$emb_name = Read-Host "Enter embeddings model name"
|
||||
if ($emb_name) { "EMBEDDINGS_NAME=$emb_name" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$emb_url = Read-Host "Enter remote embeddings API base URL"
|
||||
if ($emb_url) { "EMBEDDINGS_BASE_URL=$emb_url" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$emb_key = Read-Host "Enter embeddings API key (leave empty if none)"
|
||||
if ($emb_key) { "EMBEDDINGS_KEY=$emb_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
Write-ColorText "Custom remote embeddings configured." -ForegroundColor "Green"
|
||||
}
|
||||
{$_ -eq "b" -or $_ -eq "B"} { return }
|
||||
default {
|
||||
Write-Host ""
|
||||
Write-ColorText "Invalid choice." -ForegroundColor "Red"
|
||||
Start-Sleep -Seconds 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Authentication configuration
|
||||
function Configure-Auth {
|
||||
Write-Host ""
|
||||
Write-ColorText "Authentication Configuration" -ForegroundColor "White" -Bold
|
||||
Write-ColorText "Choose authentication type:" -ForegroundColor "White"
|
||||
Write-ColorText "1) None (default, no authentication)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "2) Simple JWT" -ForegroundColor "Yellow"
|
||||
Write-ColorText "3) Session JWT" -ForegroundColor "Yellow"
|
||||
Write-ColorText "b) Back" -ForegroundColor "Yellow"
|
||||
Write-Host ""
|
||||
$auth_choice = Read-Host "Choose option (1-3, or b)"
|
||||
|
||||
switch ($auth_choice) {
|
||||
"1" {
|
||||
Write-ColorText "Authentication disabled (default)." -ForegroundColor "Green"
|
||||
}
|
||||
"2" {
|
||||
"AUTH_TYPE=simple_jwt" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$jwt_key = Read-Host "Enter JWT secret key (leave empty to auto-generate)"
|
||||
if ([string]::IsNullOrEmpty($jwt_key)) {
|
||||
$bytes = New-Object byte[] 32
|
||||
[System.Security.Cryptography.RandomNumberGenerator]::Fill($bytes)
|
||||
$jwt_key = [System.BitConverter]::ToString($bytes).Replace("-", "").ToLower()
|
||||
Write-ColorText "Auto-generated JWT secret key." -ForegroundColor "Yellow"
|
||||
}
|
||||
"JWT_SECRET_KEY=$jwt_key" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Authentication set to Simple JWT." -ForegroundColor "Green"
|
||||
}
|
||||
"3" {
|
||||
"AUTH_TYPE=session_jwt" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$jwt_key = Read-Host "Enter JWT secret key (leave empty to auto-generate)"
|
||||
if ([string]::IsNullOrEmpty($jwt_key)) {
|
||||
$bytes = New-Object byte[] 32
|
||||
[System.Security.Cryptography.RandomNumberGenerator]::Fill($bytes)
|
||||
$jwt_key = [System.BitConverter]::ToString($bytes).Replace("-", "").ToLower()
|
||||
Write-ColorText "Auto-generated JWT secret key." -ForegroundColor "Yellow"
|
||||
}
|
||||
"JWT_SECRET_KEY=$jwt_key" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Authentication set to Session JWT." -ForegroundColor "Green"
|
||||
}
|
||||
{$_ -eq "b" -or $_ -eq "B"} { return }
|
||||
default {
|
||||
Write-Host ""
|
||||
Write-ColorText "Invalid choice." -ForegroundColor "Red"
|
||||
Start-Sleep -Seconds 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Integrations configuration
|
||||
function Configure-Integrations {
|
||||
Write-Host ""
|
||||
Write-ColorText "Integrations Configuration" -ForegroundColor "White" -Bold
|
||||
Write-ColorText "1) Google Drive" -ForegroundColor "Yellow"
|
||||
Write-ColorText "2) GitHub" -ForegroundColor "Yellow"
|
||||
Write-ColorText "b) Back" -ForegroundColor "Yellow"
|
||||
Write-Host ""
|
||||
$int_choice = Read-Host "Choose option (1-2, or b)"
|
||||
|
||||
switch ($int_choice) {
|
||||
"1" {
|
||||
$google_id = Read-Host "Enter Google OAuth Client ID"
|
||||
if ($google_id) { "GOOGLE_CLIENT_ID=$google_id" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
$google_secret = Read-Host "Enter Google OAuth Client Secret"
|
||||
if ($google_secret) { "GOOGLE_CLIENT_SECRET=$google_secret" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
Write-ColorText "Google Drive integration configured." -ForegroundColor "Green"
|
||||
}
|
||||
"2" {
|
||||
$github_token = Read-Host "Enter GitHub Personal Access Token (with repo read access)"
|
||||
if ($github_token) { "GITHUB_ACCESS_TOKEN=$github_token" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
Write-ColorText "GitHub integration configured." -ForegroundColor "Green"
|
||||
}
|
||||
{$_ -eq "b" -or $_ -eq "B"} { return }
|
||||
default {
|
||||
Write-Host ""
|
||||
Write-ColorText "Invalid choice." -ForegroundColor "Red"
|
||||
Start-Sleep -Seconds 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Document Processing configuration
|
||||
function Configure-DocProcessing {
|
||||
Write-Host ""
|
||||
Write-ColorText "Document Processing Configuration" -ForegroundColor "White" -Bold
|
||||
$pdf_image = Read-Host "Parse PDF pages as images for better table/chart extraction? (y/N)"
|
||||
if ($pdf_image -eq "y" -or $pdf_image -eq "Y") {
|
||||
"PARSE_PDF_AS_IMAGE=true" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "PDF-as-image parsing enabled." -ForegroundColor "Green"
|
||||
}
|
||||
|
||||
$ocr_enabled = Read-Host "Enable OCR for document processing (Docling)? (y/N)"
|
||||
if ($ocr_enabled -eq "y" -or $ocr_enabled -eq "Y") {
|
||||
"DOCLING_OCR_ENABLED=true" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "Docling OCR enabled." -ForegroundColor "Green"
|
||||
}
|
||||
}
|
||||
|
||||
# Text-to-Speech configuration
|
||||
function Configure-TTS {
|
||||
Write-Host ""
|
||||
Write-ColorText "Text-to-Speech Configuration" -ForegroundColor "White" -Bold
|
||||
Write-ColorText "Choose TTS provider:" -ForegroundColor "White"
|
||||
Write-ColorText "1) Google TTS (default, free)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "2) ElevenLabs" -ForegroundColor "Yellow"
|
||||
Write-ColorText "b) Back" -ForegroundColor "Yellow"
|
||||
Write-Host ""
|
||||
$tts_choice = Read-Host "Choose option (1-2, or b)"
|
||||
|
||||
switch ($tts_choice) {
|
||||
"1" {
|
||||
"TTS_PROVIDER=google_tts" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
Write-ColorText "TTS set to Google TTS." -ForegroundColor "Green"
|
||||
}
|
||||
"2" {
|
||||
"TTS_PROVIDER=elevenlabs" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
$elevenlabs_key = Read-Host "Enter ElevenLabs API key"
|
||||
if ($elevenlabs_key) { "ELEVENLABS_API_KEY=$elevenlabs_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
Write-ColorText "TTS set to ElevenLabs." -ForegroundColor "Green"
|
||||
}
|
||||
{$_ -eq "b" -or $_ -eq "B"} { return }
|
||||
default {
|
||||
Write-Host ""
|
||||
Write-ColorText "Invalid choice." -ForegroundColor "Red"
|
||||
Start-Sleep -Seconds 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Main advanced settings menu
|
||||
function Prompt-AdvancedSettings {
|
||||
Write-Host ""
|
||||
$configure_advanced = Read-Host "Would you like to configure advanced settings? (y/N)"
|
||||
if ($configure_advanced -ne "y" -and $configure_advanced -ne "Y") {
|
||||
return
|
||||
}
|
||||
|
||||
while ($true) {
|
||||
Write-Host ""
|
||||
Write-ColorText "Advanced Settings" -ForegroundColor "White" -Bold
|
||||
Write-ColorText "1) Vector Store (default: faiss)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "2) Embeddings (default: HuggingFace local)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "3) Authentication (default: none)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "4) Integrations (Google Drive, GitHub)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "5) Document Processing (PDF as image, OCR)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "6) Text-to-Speech (default: Google TTS)" -ForegroundColor "Yellow"
|
||||
Write-ColorText "s) Save and Continue with Docker setup" -ForegroundColor "Yellow"
|
||||
Write-Host ""
|
||||
$adv_choice = Read-Host "Choose option (1-6, or s)"
|
||||
|
||||
switch ($adv_choice) {
|
||||
"1" { Configure-VectorStore }
|
||||
"2" { Configure-Embeddings }
|
||||
"3" { Configure-Auth }
|
||||
"4" { Configure-Integrations }
|
||||
"5" { Configure-DocProcessing }
|
||||
"6" { Configure-TTS }
|
||||
{$_ -eq "s" -or $_ -eq "S"} { break }
|
||||
default {
|
||||
Write-Host ""
|
||||
Write-ColorText "Invalid choice." -ForegroundColor "Red"
|
||||
Start-Sleep -Seconds 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 1) Use DocsGPT Public API Endpoint (simple and free)
|
||||
function Use-DocsPublicAPIEndpoint {
|
||||
Write-Host ""
|
||||
@@ -297,6 +592,8 @@ function Use-DocsPublicAPIEndpoint {
|
||||
|
||||
Write-ColorText ".env file configured for DocsGPT Public API." -ForegroundColor "Green"
|
||||
|
||||
Prompt-AdvancedSettings
|
||||
|
||||
# Start Docker if needed
|
||||
$dockerRunning = Check-AndStartDocker
|
||||
if (-not $dockerRunning) {
|
||||
@@ -369,6 +666,7 @@ function Serve-LocalOllama {
|
||||
break
|
||||
}
|
||||
elseif ($confirm_gpu -eq "b" -or $confirm_gpu -eq "B") {
|
||||
$script:ollama_choice = "b"
|
||||
Clear-Host
|
||||
return
|
||||
}
|
||||
@@ -406,6 +704,8 @@ function Serve-LocalOllama {
|
||||
Write-ColorText ".env file configured for Ollama ($($docker_compose_file_suffix.ToUpper()))." -ForegroundColor "Green"
|
||||
Write-ColorText "Note: MODEL_NAME is set to '$model_name'. You can change it later in the .env file." -ForegroundColor "Yellow"
|
||||
|
||||
Prompt-AdvancedSettings
|
||||
|
||||
# Start Docker if needed
|
||||
$dockerRunning = Check-AndStartDocker
|
||||
if (-not $dockerRunning) {
|
||||
@@ -569,6 +869,8 @@ function Connect-LocalInferenceEngine {
|
||||
Write-ColorText ".env file configured for $engine_name with OpenAI API format." -ForegroundColor "Green"
|
||||
Write-ColorText "Note: MODEL_NAME is set to '$model_name'. You can change it later in the .env file." -ForegroundColor "Yellow"
|
||||
|
||||
Prompt-AdvancedSettings
|
||||
|
||||
# Start Docker if needed
|
||||
$dockerRunning = Check-AndStartDocker
|
||||
if (-not $dockerRunning) {
|
||||
@@ -665,6 +967,12 @@ function Connect-CloudAPIProvider {
|
||||
$script:llm_name = "azure_openai"
|
||||
$script:model_name = "gpt-4o"
|
||||
Get-APIKey
|
||||
Write-Host ""
|
||||
Write-ColorText "Azure OpenAI requires additional configuration:" -ForegroundColor "White" -Bold
|
||||
$script:azure_api_base = Read-Host "Enter Azure OpenAI API base URL (e.g. https://your-resource.openai.azure.com/)"
|
||||
$script:azure_api_version = Read-Host "Enter Azure OpenAI API version (e.g. 2024-02-15-preview)"
|
||||
$script:azure_deployment = Read-Host "Enter Azure deployment name for chat"
|
||||
$script:azure_emb_deployment = Read-Host "Enter Azure deployment name for embeddings (leave empty to skip)"
|
||||
break
|
||||
}
|
||||
"7" { # Novita
|
||||
@@ -696,9 +1004,19 @@ function Connect-CloudAPIProvider {
|
||||
"LLM_PROVIDER=$llm_name" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
"LLM_NAME=$model_name" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
"VITE_API_STREAMING=true" | Add-Content -Path $ENV_FILE -Encoding utf8
|
||||
|
||||
|
||||
# Azure OpenAI additional settings
|
||||
if ($llm_name -eq "azure_openai") {
|
||||
if ($azure_api_base) { "OPENAI_API_BASE=$azure_api_base" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
if ($azure_api_version) { "OPENAI_API_VERSION=$azure_api_version" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
if ($azure_deployment) { "AZURE_DEPLOYMENT_NAME=$azure_deployment" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
if ($azure_emb_deployment) { "AZURE_EMBEDDINGS_DEPLOYMENT_NAME=$azure_emb_deployment" | Add-Content -Path $ENV_FILE -Encoding utf8 }
|
||||
}
|
||||
|
||||
Write-ColorText ".env file configured for $provider_name." -ForegroundColor "Green"
|
||||
|
||||
Prompt-AdvancedSettings
|
||||
|
||||
# Start Docker if needed
|
||||
$dockerRunning = Check-AndStartDocker
|
||||
if (-not $dockerRunning) {
|
||||
@@ -709,15 +1027,15 @@ function Connect-CloudAPIProvider {
|
||||
try {
|
||||
Write-Host ""
|
||||
Write-ColorText "Starting Docker Compose..." -ForegroundColor "White"
|
||||
|
||||
|
||||
# Run Docker compose commands
|
||||
& docker compose --env-file "$ENV_FILE" -f "$COMPOSE_FILE" pull
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
throw "Docker compose pull failed with exit code $LASTEXITCODE"
|
||||
}
|
||||
|
||||
|
||||
& docker compose --env-file "$ENV_FILE" -f "$COMPOSE_FILE" up -d
|
||||
|
||||
|
||||
Write-Host ""
|
||||
Write-ColorText "DocsGPT is now configured to use $provider_name on http://localhost:5173" -ForegroundColor "Green"
|
||||
Write-ColorText "You can stop the application by running: docker compose -f `"$COMPOSE_FILE`" down" -ForegroundColor "Yellow"
|
||||
@@ -734,6 +1052,23 @@ function Connect-CloudAPIProvider {
|
||||
# Main script execution
|
||||
Animate-Dino
|
||||
|
||||
# Check if .env file exists and is not empty
|
||||
if ((Test-Path $ENV_FILE) -and ((Get-Item $ENV_FILE).Length -gt 0)) {
|
||||
Write-Host ""
|
||||
Write-ColorText "Warning: An existing .env file was found with the following settings:" -ForegroundColor "Yellow" -Bold
|
||||
$envLines = Get-Content $ENV_FILE
|
||||
$envLines | Select-Object -First 3 | ForEach-Object { Write-Host " $_" }
|
||||
if ($envLines.Count -gt 3) {
|
||||
Write-Host " ... and $($envLines.Count - 3) more lines"
|
||||
}
|
||||
Write-Host ""
|
||||
$confirm_overwrite = Read-Host "Running setup will overwrite this file. Continue? (y/N)"
|
||||
if ($confirm_overwrite -ne "y" -and $confirm_overwrite -ne "Y") {
|
||||
Write-ColorText "Setup cancelled. Your .env file was not modified." -ForegroundColor "Green"
|
||||
exit 0
|
||||
}
|
||||
}
|
||||
|
||||
while ($true) {
|
||||
Clear-Host
|
||||
Prompt-MainMenu
|
||||
|
||||
324
setup.sh
324
setup.sh
@@ -9,9 +9,10 @@ NC='\033[0m'
|
||||
BOLD='\033[1m'
|
||||
|
||||
# Base Compose file (relative to script location)
|
||||
COMPOSE_FILE="$(dirname "$(readlink -f "$0")")/deployment/docker-compose-hub.yaml"
|
||||
COMPOSE_FILE_LOCAL="$(dirname "$(readlink -f "$0")")/deployment/docker-compose.yaml"
|
||||
ENV_FILE="$(dirname "$(readlink -f "$0")")/.env"
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd -P)"
|
||||
COMPOSE_FILE="${SCRIPT_DIR}/deployment/docker-compose-hub.yaml"
|
||||
COMPOSE_FILE_LOCAL="${SCRIPT_DIR}/deployment/docker-compose.yaml"
|
||||
ENV_FILE="${SCRIPT_DIR}/.env"
|
||||
|
||||
# Animation function
|
||||
animate_dino() {
|
||||
@@ -155,7 +156,7 @@ prompt_cloud_api_provider_options() {
|
||||
echo -e "${YELLOW}7) Novita${NC}"
|
||||
echo -e "${YELLOW}b) Back to Main Menu${NC}"
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-6, or b): ${NC}")" provider_choice
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-7, or b): ${NC}")" provider_choice
|
||||
}
|
||||
|
||||
# Function to prompt for Ollama CPU/GPU options
|
||||
@@ -170,6 +171,264 @@ prompt_ollama_options() {
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-2, or b): ${NC}")" ollama_choice
|
||||
}
|
||||
|
||||
# ========================
|
||||
# Advanced Settings Functions
|
||||
# ========================
|
||||
|
||||
# Vector Store configuration
|
||||
configure_vector_store() {
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Vector Store Configuration${NC}"
|
||||
echo -e "${DEFAULT_FG}Choose your vector store:${NC}"
|
||||
echo -e "${YELLOW}1) FAISS (default, local)${NC}"
|
||||
echo -e "${YELLOW}2) Elasticsearch${NC}"
|
||||
echo -e "${YELLOW}3) Qdrant${NC}"
|
||||
echo -e "${YELLOW}4) Milvus${NC}"
|
||||
echo -e "${YELLOW}5) LanceDB${NC}"
|
||||
echo -e "${YELLOW}6) PGVector${NC}"
|
||||
echo -e "${YELLOW}b) Back${NC}"
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-6, or b): ${NC}")" vs_choice
|
||||
|
||||
case "$vs_choice" in
|
||||
1)
|
||||
echo "VECTOR_STORE=faiss" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Vector store set to FAISS.${NC}"
|
||||
;;
|
||||
2)
|
||||
echo "VECTOR_STORE=elasticsearch" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch URL (e.g. http://localhost:9200): ${NC}")" elastic_url
|
||||
[ -n "$elastic_url" ] && echo "ELASTIC_URL=$elastic_url" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch Cloud ID (leave empty if using URL): ${NC}")" elastic_cloud_id
|
||||
[ -n "$elastic_cloud_id" ] && echo "ELASTIC_CLOUD_ID=$elastic_cloud_id" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch username (leave empty if none): ${NC}")" elastic_user
|
||||
[ -n "$elastic_user" ] && echo "ELASTIC_USERNAME=$elastic_user" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch password (leave empty if none): ${NC}")" elastic_pass
|
||||
[ -n "$elastic_pass" ] && echo "ELASTIC_PASSWORD=$elastic_pass" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch index name (default: docsgpt): ${NC}")" elastic_index
|
||||
echo "ELASTIC_INDEX=${elastic_index:-docsgpt}" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Vector store set to Elasticsearch.${NC}"
|
||||
;;
|
||||
3)
|
||||
echo "VECTOR_STORE=qdrant" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Qdrant URL (e.g. http://localhost:6333): ${NC}")" qdrant_url
|
||||
[ -n "$qdrant_url" ] && echo "QDRANT_URL=$qdrant_url" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Qdrant API key (leave empty if none): ${NC}")" qdrant_key
|
||||
[ -n "$qdrant_key" ] && echo "QDRANT_API_KEY=$qdrant_key" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Qdrant collection name (default: docsgpt): ${NC}")" qdrant_collection
|
||||
echo "QDRANT_COLLECTION_NAME=${qdrant_collection:-docsgpt}" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Vector store set to Qdrant.${NC}"
|
||||
;;
|
||||
4)
|
||||
echo "VECTOR_STORE=milvus" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Milvus URI (default: ./milvus_local.db): ${NC}")" milvus_uri
|
||||
echo "MILVUS_URI=${milvus_uri:-./milvus_local.db}" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Milvus token (leave empty if none): ${NC}")" milvus_token
|
||||
[ -n "$milvus_token" ] && echo "MILVUS_TOKEN=$milvus_token" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Milvus collection name (default: docsgpt): ${NC}")" milvus_collection
|
||||
echo "MILVUS_COLLECTION_NAME=${milvus_collection:-docsgpt}" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Vector store set to Milvus.${NC}"
|
||||
;;
|
||||
5)
|
||||
echo "VECTOR_STORE=lancedb" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter LanceDB path (default: ./data/lancedb): ${NC}")" lancedb_path
|
||||
echo "LANCEDB_PATH=${lancedb_path:-./data/lancedb}" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter LanceDB table name (default: docsgpts): ${NC}")" lancedb_table
|
||||
echo "LANCEDB_TABLE_NAME=${lancedb_table:-docsgpts}" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Vector store set to LanceDB.${NC}"
|
||||
;;
|
||||
6)
|
||||
echo "VECTOR_STORE=pgvector" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter PGVector connection string (e.g. postgresql://user:pass@host:5432/db): ${NC}")" pgvector_conn
|
||||
[ -n "$pgvector_conn" ] && echo "PGVECTOR_CONNECTION_STRING=$pgvector_conn" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Vector store set to PGVector.${NC}"
|
||||
;;
|
||||
b|B) return ;;
|
||||
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Embeddings configuration
|
||||
configure_embeddings() {
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Embeddings Configuration${NC}"
|
||||
echo -e "${DEFAULT_FG}Choose your embeddings provider:${NC}"
|
||||
echo -e "${YELLOW}1) HuggingFace (default, local)${NC}"
|
||||
echo -e "${YELLOW}2) OpenAI Embeddings${NC}"
|
||||
echo -e "${YELLOW}3) Custom Remote Embeddings (OpenAI-compatible API)${NC}"
|
||||
echo -e "${YELLOW}b) Back${NC}"
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-3, or b): ${NC}")" emb_choice
|
||||
|
||||
case "$emb_choice" in
|
||||
1)
|
||||
echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Embeddings set to HuggingFace (local).${NC}"
|
||||
;;
|
||||
2)
|
||||
echo "EMBEDDINGS_NAME=openai_text-embedding-ada-002" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Embeddings API key (leave empty to reuse LLM API_KEY): ${NC}")" emb_key
|
||||
[ -n "$emb_key" ] && echo "EMBEDDINGS_KEY=$emb_key" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Embeddings set to OpenAI.${NC}"
|
||||
;;
|
||||
3)
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter embeddings model name: ${NC}")" emb_name
|
||||
[ -n "$emb_name" ] && echo "EMBEDDINGS_NAME=$emb_name" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter remote embeddings API base URL: ${NC}")" emb_url
|
||||
[ -n "$emb_url" ] && echo "EMBEDDINGS_BASE_URL=$emb_url" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter embeddings API key (leave empty if none): ${NC}")" emb_key
|
||||
[ -n "$emb_key" ] && echo "EMBEDDINGS_KEY=$emb_key" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Custom remote embeddings configured.${NC}"
|
||||
;;
|
||||
b|B) return ;;
|
||||
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Authentication configuration
|
||||
configure_auth() {
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Authentication Configuration${NC}"
|
||||
echo -e "${DEFAULT_FG}Choose authentication type:${NC}"
|
||||
echo -e "${YELLOW}1) None (default, no authentication)${NC}"
|
||||
echo -e "${YELLOW}2) Simple JWT${NC}"
|
||||
echo -e "${YELLOW}3) Session JWT${NC}"
|
||||
echo -e "${YELLOW}b) Back${NC}"
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-3, or b): ${NC}")" auth_choice
|
||||
|
||||
case "$auth_choice" in
|
||||
1)
|
||||
echo -e "${GREEN}Authentication disabled (default).${NC}"
|
||||
;;
|
||||
2)
|
||||
echo "AUTH_TYPE=simple_jwt" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter JWT secret key (leave empty to auto-generate): ${NC}")" jwt_key
|
||||
if [ -n "$jwt_key" ]; then
|
||||
echo "JWT_SECRET_KEY=$jwt_key" >> "$ENV_FILE"
|
||||
else
|
||||
generated_key=$(openssl rand -hex 32 2>/dev/null || head -c 64 /dev/urandom | od -An -tx1 | tr -d ' \n')
|
||||
echo "JWT_SECRET_KEY=$generated_key" >> "$ENV_FILE"
|
||||
echo -e "${YELLOW}Auto-generated JWT secret key.${NC}"
|
||||
fi
|
||||
echo -e "${GREEN}Authentication set to Simple JWT.${NC}"
|
||||
;;
|
||||
3)
|
||||
echo "AUTH_TYPE=session_jwt" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter JWT secret key (leave empty to auto-generate): ${NC}")" jwt_key
|
||||
if [ -n "$jwt_key" ]; then
|
||||
echo "JWT_SECRET_KEY=$jwt_key" >> "$ENV_FILE"
|
||||
else
|
||||
generated_key=$(openssl rand -hex 32 2>/dev/null || head -c 64 /dev/urandom | od -An -tx1 | tr -d ' \n')
|
||||
echo "JWT_SECRET_KEY=$generated_key" >> "$ENV_FILE"
|
||||
echo -e "${YELLOW}Auto-generated JWT secret key.${NC}"
|
||||
fi
|
||||
echo -e "${GREEN}Authentication set to Session JWT.${NC}"
|
||||
;;
|
||||
b|B) return ;;
|
||||
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Integrations configuration
|
||||
configure_integrations() {
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Integrations Configuration${NC}"
|
||||
echo -e "${YELLOW}1) Google Drive${NC}"
|
||||
echo -e "${YELLOW}2) GitHub${NC}"
|
||||
echo -e "${YELLOW}b) Back${NC}"
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-2, or b): ${NC}")" int_choice
|
||||
|
||||
case "$int_choice" in
|
||||
1)
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Google OAuth Client ID: ${NC}")" google_id
|
||||
[ -n "$google_id" ] && echo "GOOGLE_CLIENT_ID=$google_id" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Google OAuth Client Secret: ${NC}")" google_secret
|
||||
[ -n "$google_secret" ] && echo "GOOGLE_CLIENT_SECRET=$google_secret" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Google Drive integration configured.${NC}"
|
||||
;;
|
||||
2)
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter GitHub Personal Access Token (with repo read access): ${NC}")" github_token
|
||||
[ -n "$github_token" ] && echo "GITHUB_ACCESS_TOKEN=$github_token" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}GitHub integration configured.${NC}"
|
||||
;;
|
||||
b|B) return ;;
|
||||
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Document Processing configuration
|
||||
configure_doc_processing() {
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Document Processing Configuration${NC}"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Parse PDF pages as images for better table/chart extraction? (y/N): ${NC}")" pdf_image
|
||||
if [[ "$pdf_image" =~ ^[yY]$ ]]; then
|
||||
echo "PARSE_PDF_AS_IMAGE=true" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}PDF-as-image parsing enabled.${NC}"
|
||||
fi
|
||||
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enable OCR for document processing (Docling)? (y/N): ${NC}")" ocr_enabled
|
||||
if [[ "$ocr_enabled" =~ ^[yY]$ ]]; then
|
||||
echo "DOCLING_OCR_ENABLED=true" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}Docling OCR enabled.${NC}"
|
||||
fi
|
||||
}
|
||||
|
||||
# Text-to-Speech configuration
|
||||
configure_tts() {
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Text-to-Speech Configuration${NC}"
|
||||
echo -e "${DEFAULT_FG}Choose TTS provider:${NC}"
|
||||
echo -e "${YELLOW}1) Google TTS (default, free)${NC}"
|
||||
echo -e "${YELLOW}2) ElevenLabs${NC}"
|
||||
echo -e "${YELLOW}b) Back${NC}"
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-2, or b): ${NC}")" tts_choice
|
||||
|
||||
case "$tts_choice" in
|
||||
1)
|
||||
echo "TTS_PROVIDER=google_tts" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}TTS set to Google TTS.${NC}"
|
||||
;;
|
||||
2)
|
||||
echo "TTS_PROVIDER=elevenlabs" >> "$ENV_FILE"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter ElevenLabs API key: ${NC}")" elevenlabs_key
|
||||
[ -n "$elevenlabs_key" ] && echo "ELEVENLABS_API_KEY=$elevenlabs_key" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}TTS set to ElevenLabs.${NC}"
|
||||
;;
|
||||
b|B) return ;;
|
||||
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Main advanced settings menu
|
||||
prompt_advanced_settings() {
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Would you like to configure advanced settings? (y/N): ${NC}")" configure_advanced
|
||||
if [[ ! "$configure_advanced" =~ ^[yY]$ ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
while true; do
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Advanced Settings${NC}"
|
||||
echo -e "${YELLOW}1) Vector Store ${NC}${DEFAULT_FG}(default: faiss)${NC}"
|
||||
echo -e "${YELLOW}2) Embeddings ${NC}${DEFAULT_FG}(default: HuggingFace local)${NC}"
|
||||
echo -e "${YELLOW}3) Authentication ${NC}${DEFAULT_FG}(default: none)${NC}"
|
||||
echo -e "${YELLOW}4) Integrations ${NC}${DEFAULT_FG}(Google Drive, GitHub)${NC}"
|
||||
echo -e "${YELLOW}5) Document Processing ${NC}${DEFAULT_FG}(PDF as image, OCR)${NC}"
|
||||
echo -e "${YELLOW}6) Text-to-Speech ${NC}${DEFAULT_FG}(default: Google TTS)${NC}"
|
||||
echo -e "${YELLOW}s) Save and Continue with Docker setup${NC}"
|
||||
echo
|
||||
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-6, or s): ${NC}")" adv_choice
|
||||
|
||||
case "$adv_choice" in
|
||||
1) configure_vector_store ;;
|
||||
2) configure_embeddings ;;
|
||||
3) configure_auth ;;
|
||||
4) configure_integrations ;;
|
||||
5) configure_doc_processing ;;
|
||||
6) configure_tts ;;
|
||||
s|S) break ;;
|
||||
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
done
|
||||
}
|
||||
|
||||
# 1) Use DocsGPT Public API Endpoint (simple and free)
|
||||
use_docs_public_api_endpoint() {
|
||||
echo -e "\n${NC}Setting up DocsGPT Public API Endpoint...${NC}"
|
||||
@@ -177,6 +436,8 @@ use_docs_public_api_endpoint() {
|
||||
echo "VITE_API_STREAMING=true" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}.env file configured for DocsGPT Public API.${NC}"
|
||||
|
||||
prompt_advanced_settings
|
||||
|
||||
check_and_start_docker
|
||||
|
||||
echo -e "\n${NC}Starting Docker Compose...${NC}"
|
||||
@@ -229,11 +490,11 @@ serve_local_ollama() {
|
||||
docker_compose_file_suffix="gpu"
|
||||
get_model_name_ollama
|
||||
break ;;
|
||||
b|B) clear; return ;; # Back to Main Menu
|
||||
b|B) clear; return 1 ;; # Back to Main Menu
|
||||
*) echo -e "\n${RED}Invalid choice. Please choose y or b.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
;;
|
||||
b|B) clear; return ;; # Back to Main Menu
|
||||
b|B) clear; return 1 ;; # Back to Main Menu
|
||||
*) echo -e "\n${RED}Invalid choice. Please choose 1-2, or b.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
done
|
||||
@@ -248,6 +509,7 @@ serve_local_ollama() {
|
||||
echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> "$ENV_FILE"
|
||||
echo -e "${GREEN}.env file configured for Ollama ($(echo "$docker_compose_file_suffix" | tr '[:lower:]' '[:upper:]')${NC}${GREEN}).${NC}"
|
||||
|
||||
prompt_advanced_settings
|
||||
|
||||
check_and_start_docker
|
||||
local compose_files=(
|
||||
@@ -346,7 +608,7 @@ connect_local_inference_engine() {
|
||||
openai_base_url="http://host.docker.internal:23333/v1"
|
||||
get_model_name
|
||||
break ;;
|
||||
b|B) clear; return ;; # Back to Main Menu
|
||||
b|B) clear; return 1 ;; # Back to Main Menu
|
||||
*) echo -e "\n${RED}Invalid choice. Please choose 1-8, or b.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
done
|
||||
@@ -361,6 +623,8 @@ connect_local_inference_engine() {
|
||||
echo -e "${GREEN}.env file configured for ${BOLD}${engine_name}${NC}${GREEN} with OpenAI API format.${NC}"
|
||||
echo -e "${YELLOW}Note: MODEL_NAME is set to '${BOLD}$model_name${NC}${YELLOW}'. You can change it later in the .env file.${NC}"
|
||||
|
||||
prompt_advanced_settings
|
||||
|
||||
check_and_start_docker
|
||||
|
||||
echo -e "\n${NC}Starting Docker Compose...${NC}"
|
||||
@@ -431,6 +695,11 @@ connect_cloud_api_provider() {
|
||||
llm_provider="azure_openai"
|
||||
model_name="gpt-4o"
|
||||
get_api_key
|
||||
echo -e "\n${DEFAULT_FG}${BOLD}Azure OpenAI requires additional configuration:${NC}"
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Azure OpenAI API base URL (e.g. https://your-resource.openai.azure.com/): ${NC}")" azure_api_base
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Azure OpenAI API version (e.g. 2024-02-15-preview): ${NC}")" azure_api_version
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Azure deployment name for chat: ${NC}")" azure_deployment
|
||||
read -p "$(echo -e "${DEFAULT_FG}Enter Azure deployment name for embeddings (leave empty to skip): ${NC}")" azure_emb_deployment
|
||||
break ;;
|
||||
7) # Novita
|
||||
provider_name="Novita"
|
||||
@@ -438,8 +707,8 @@ connect_cloud_api_provider() {
|
||||
model_name="deepseek/deepseek-r1"
|
||||
get_api_key
|
||||
break ;;
|
||||
b|B) clear; return ;; # Clear screen and Back to Main Menu
|
||||
*) echo -e "\n${RED}Invalid choice. Please choose 1-6, or b.${NC}" ; sleep 1 ;;
|
||||
b|B) clear; return 1 ;; # Clear screen and Back to Main Menu
|
||||
*) echo -e "\n${RED}Invalid choice. Please choose 1-7, or b.${NC}" ; sleep 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
@@ -448,8 +717,19 @@ connect_cloud_api_provider() {
|
||||
echo "LLM_PROVIDER=$llm_provider" >> "$ENV_FILE"
|
||||
echo "LLM_NAME=$model_name" >> "$ENV_FILE"
|
||||
echo "VITE_API_STREAMING=true" >> "$ENV_FILE"
|
||||
|
||||
# Azure OpenAI additional settings
|
||||
if [ "$llm_provider" = "azure_openai" ]; then
|
||||
[ -n "$azure_api_base" ] && echo "OPENAI_API_BASE=$azure_api_base" >> "$ENV_FILE"
|
||||
[ -n "$azure_api_version" ] && echo "OPENAI_API_VERSION=$azure_api_version" >> "$ENV_FILE"
|
||||
[ -n "$azure_deployment" ] && echo "AZURE_DEPLOYMENT_NAME=$azure_deployment" >> "$ENV_FILE"
|
||||
[ -n "$azure_emb_deployment" ] && echo "AZURE_EMBEDDINGS_DEPLOYMENT_NAME=$azure_emb_deployment" >> "$ENV_FILE"
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}.env file configured for ${BOLD}${provider_name}${NC}${GREEN}.${NC}"
|
||||
|
||||
prompt_advanced_settings
|
||||
|
||||
check_and_start_docker
|
||||
|
||||
echo -e "\n${NC}Starting Docker Compose...${NC}"
|
||||
@@ -472,6 +752,21 @@ connect_cloud_api_provider() {
|
||||
# Main script execution
|
||||
animate_dino
|
||||
|
||||
# Check if .env file exists and is not empty
|
||||
if [ -f "$ENV_FILE" ] && [ -s "$ENV_FILE" ]; then
|
||||
echo -e "\n${YELLOW}${BOLD}Warning:${NC}${YELLOW} An existing .env file was found with the following settings:${NC}"
|
||||
head -3 "$ENV_FILE" | while IFS= read -r line; do echo -e "${DEFAULT_FG} $line${NC}"; done
|
||||
total_lines=$(wc -l < "$ENV_FILE")
|
||||
if [ "$total_lines" -gt 3 ]; then
|
||||
echo -e "${DEFAULT_FG} ... and $((total_lines - 3)) more lines${NC}"
|
||||
fi
|
||||
echo
|
||||
read -p "$(echo -e "${YELLOW}Running setup will overwrite this file. Continue? (y/N): ${NC}")" confirm_overwrite
|
||||
if [[ ! "$confirm_overwrite" =~ ^[yY]$ ]]; then
|
||||
echo -e "${GREEN}Setup cancelled. Your .env file was not modified.${NC}"
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
while true; do # Main menu loop
|
||||
clear # Clear screen before showing main menu again
|
||||
@@ -479,18 +774,15 @@ while true; do # Main menu loop
|
||||
|
||||
case $main_choice in
|
||||
1) # Use DocsGPT Public API Endpoint (Docker Hub images)
|
||||
COMPOSE_FILE="$(dirname "$(readlink -f "$0")")/deployment/docker-compose-hub.yaml"
|
||||
COMPOSE_FILE="${SCRIPT_DIR}/deployment/docker-compose-hub.yaml"
|
||||
use_docs_public_api_endpoint
|
||||
break ;;
|
||||
2) # Serve Local (with Ollama)
|
||||
serve_local_ollama
|
||||
break ;;
|
||||
serve_local_ollama && break ;;
|
||||
3) # Connect Local Inference Engine
|
||||
connect_local_inference_engine
|
||||
break ;;
|
||||
connect_local_inference_engine && break ;;
|
||||
4) # Connect Cloud API Provider
|
||||
connect_cloud_api_provider
|
||||
break ;;
|
||||
connect_cloud_api_provider && break ;;
|
||||
5) # Advanced: Build images locally
|
||||
echo -e "\n${YELLOW}You have selected to build images locally. This is recommended for developers or if you want to test local changes.${NC}"
|
||||
COMPOSE_FILE="$COMPOSE_FILE_LOCAL"
|
||||
|
||||
332
tests/agents/test_workflow_engine.py
Normal file
332
tests/agents/test_workflow_engine.py
Normal file
@@ -0,0 +1,332 @@
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from application.api.user.workflows import routes as workflow_routes
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
NodeType,
|
||||
Workflow,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.api.user.workflows.routes import validate_workflow_structure
|
||||
|
||||
|
||||
class StubNodeAgent:
|
||||
def __init__(self, events):
|
||||
self.events = events
|
||||
|
||||
def gen(self, _prompt):
|
||||
yield from self.events
|
||||
|
||||
|
||||
def create_engine() -> WorkflowEngine:
|
||||
graph = WorkflowGraph(workflow=Workflow(name="Engine Test"), nodes=[], edges=[])
|
||||
agent = SimpleNamespace(
|
||||
endpoint="stream",
|
||||
llm_name="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
api_key="test-key",
|
||||
chat_history=[],
|
||||
decoded_token={"sub": "user-1"},
|
||||
)
|
||||
return WorkflowEngine(graph, agent)
|
||||
|
||||
|
||||
def create_agent_node(
|
||||
node_id: str,
|
||||
output_variable: str = "",
|
||||
json_schema: Optional[Dict[str, Any]] = None,
|
||||
) -> WorkflowNode:
|
||||
config = {
|
||||
"agent_type": "classic",
|
||||
"system_prompt": "You are a helpful assistant.",
|
||||
"prompt_template": "",
|
||||
"stream_to_user": False,
|
||||
"tools": [],
|
||||
}
|
||||
if output_variable:
|
||||
config["output_variable"] = output_variable
|
||||
if json_schema is not None:
|
||||
config["json_schema"] = json_schema
|
||||
|
||||
return WorkflowNode(
|
||||
id=node_id,
|
||||
workflow_id="workflow-1",
|
||||
type=NodeType.AGENT,
|
||||
title="Agent",
|
||||
position={"x": 0, "y": 0},
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
||||
def test_execute_agent_node_saves_structured_output_as_json(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_1",
|
||||
output_variable="result",
|
||||
json_schema={"type": "object"},
|
||||
)
|
||||
node_events = [
|
||||
{"answer": '{"summary":"ok",', "structured": True},
|
||||
{"answer": '"score":2}', "structured": True},
|
||||
]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
expected_output = {"summary": "ok", "score": 2}
|
||||
assert engine.state["node_agent_1_output"] == expected_output
|
||||
assert engine.state["result"] == expected_output
|
||||
|
||||
|
||||
def test_execute_agent_node_normalizes_wrapped_schema_before_agent_create(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_wrapped",
|
||||
json_schema={"schema": {"type": "object"}},
|
||||
)
|
||||
node_events = [{"answer": '{"summary":"ok"}', "structured": True}]
|
||||
captured: Dict[str, Any] = {}
|
||||
|
||||
def create_node_agent(**kwargs):
|
||||
captured["json_schema"] = kwargs.get("json_schema")
|
||||
return StubNodeAgent(node_events)
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(create_node_agent),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
assert captured["json_schema"] == {"type": "object"}
|
||||
assert engine.state["node_agent_wrapped_output"] == {"summary": "ok"}
|
||||
|
||||
|
||||
def test_execute_agent_node_falls_back_to_text_when_schema_not_configured(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(node_id="agent_2", output_variable="result")
|
||||
node_events = [{"answer": "plain text answer"}]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
assert engine.state["node_agent_2_output"] == "plain text answer"
|
||||
assert engine.state["result"] == "plain text answer"
|
||||
|
||||
|
||||
def test_validate_workflow_structure_rejects_invalid_agent_json_schema():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {"json_schema": "invalid"},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert any(
|
||||
"Agent node 'Agent' JSON schema must be a valid JSON object" in err
|
||||
for err in errors
|
||||
)
|
||||
|
||||
|
||||
def test_validate_workflow_structure_accepts_valid_agent_json_schema():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {"json_schema": {"type": "object"}},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_workflow_structure_accepts_wrapped_agent_json_schema():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {"json_schema": {"schema": {"type": "object"}}},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_workflow_structure_accepts_output_variable_and_schema_together():
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {
|
||||
"output_variable": "answer",
|
||||
"json_schema": {"type": "object"},
|
||||
},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert errors == []
|
||||
|
||||
|
||||
def test_validate_workflow_structure_rejects_unsupported_structured_output_model(
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
workflow_routes,
|
||||
"get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": False},
|
||||
)
|
||||
|
||||
nodes = [
|
||||
{"id": "start", "type": "start", "title": "Start", "data": {}},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "agent",
|
||||
"title": "Agent",
|
||||
"data": {
|
||||
"model_id": "some-model",
|
||||
"json_schema": {"type": "object"},
|
||||
},
|
||||
},
|
||||
{"id": "end", "type": "end", "title": "End", "data": {}},
|
||||
]
|
||||
edges = [
|
||||
{"id": "edge_1", "source": "start", "target": "agent"},
|
||||
{"id": "edge_2", "source": "agent", "target": "end"},
|
||||
]
|
||||
|
||||
errors = validate_workflow_structure(nodes, edges)
|
||||
|
||||
assert any(
|
||||
"Agent node 'Agent' selected model does not support structured output"
|
||||
in err
|
||||
for err in errors
|
||||
)
|
||||
|
||||
|
||||
def test_execute_agent_node_raises_when_structured_output_violates_schema(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_3",
|
||||
json_schema={
|
||||
"type": "object",
|
||||
"properties": {"summary": {"type": "string"}},
|
||||
"required": ["summary"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
node_events = [{"answer": '{"score":2}', "structured": True}]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Structured output did not match schema"):
|
||||
list(engine._execute_agent_node(node))
|
||||
|
||||
|
||||
def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeypatch):
|
||||
engine = create_engine()
|
||||
node = create_agent_node(
|
||||
node_id="agent_4",
|
||||
json_schema={"type": "object"},
|
||||
)
|
||||
node_events = [{"answer": "not-json"}]
|
||||
|
||||
monkeypatch.setattr(
|
||||
WorkflowNodeAgentFactory,
|
||||
"create",
|
||||
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_api_key_for_provider",
|
||||
lambda _provider: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"application.core.model_utils.get_model_capabilities",
|
||||
lambda _model_id: {"supports_structured_output": True},
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Structured output was expected but response was not valid JSON",
|
||||
):
|
||||
list(engine._execute_agent_node(node))
|
||||
63
tests/agents/test_workflow_template.py
Normal file
63
tests/agents/test_workflow_template.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from application.agents.workflows.schemas import Workflow, WorkflowGraph
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
|
||||
|
||||
def create_engine() -> WorkflowEngine:
|
||||
graph = WorkflowGraph(workflow=Workflow(name="Template Test"), nodes=[], edges=[])
|
||||
agent = SimpleNamespace(
|
||||
user="user-1",
|
||||
request_id="req-1",
|
||||
retrieved_docs=[
|
||||
{"title": "Doc A", "text": "Summary A"},
|
||||
{"title": "Doc B", "text": "Summary B"},
|
||||
],
|
||||
)
|
||||
return WorkflowEngine(graph, agent)
|
||||
|
||||
|
||||
def test_workflow_template_supports_agent_namespace_and_legacy_variables():
|
||||
engine = create_engine()
|
||||
engine.state = {"query": "Hello", "chat_history": "[]", "ticket_id": 42}
|
||||
|
||||
rendered = engine._format_template(
|
||||
"{{ agent.query }}|{{ agent.ticket_id }}|{{ query }}|{{ ticket_id }}"
|
||||
)
|
||||
|
||||
assert rendered == "Hello|42|Hello|42"
|
||||
|
||||
|
||||
def test_workflow_template_supports_global_namespaces():
|
||||
engine = create_engine()
|
||||
engine.state = {"query": "Hello"}
|
||||
|
||||
rendered = engine._format_template(
|
||||
"{{ source.count }}|{{ source.summaries }}|{{ system.request_id }}"
|
||||
)
|
||||
|
||||
assert rendered.startswith("2|")
|
||||
assert "Doc A" in rendered
|
||||
assert "Summary A" in rendered
|
||||
assert rendered.endswith("|req-1")
|
||||
|
||||
|
||||
def test_workflow_template_handles_namespace_conflicts_with_agent_prefix():
|
||||
engine = create_engine()
|
||||
engine.state = {"source": "user-defined-source"}
|
||||
|
||||
rendered = engine._format_template(
|
||||
"{{ agent.source }}|{{ agent_source }}|{{ source.count }}"
|
||||
)
|
||||
|
||||
assert rendered.startswith("user-defined-source|user-defined-source|")
|
||||
|
||||
|
||||
def test_workflow_template_gracefully_handles_invalid_template_syntax():
|
||||
engine = create_engine()
|
||||
engine.state = {"query": "Hello"}
|
||||
|
||||
invalid_template = "{{ agent.query "
|
||||
rendered = engine._format_template(invalid_template)
|
||||
|
||||
assert rendered == invalid_template
|
||||
@@ -199,6 +199,7 @@ class TestStreamProcessorAgentConfiguration:
|
||||
try:
|
||||
processor._configure_agent()
|
||||
assert processor.agent_config is not None
|
||||
assert processor.agent_id == str(agent_id)
|
||||
except Exception as e:
|
||||
assert "Invalid API Key" in str(e)
|
||||
|
||||
@@ -211,6 +212,7 @@ class TestStreamProcessorAgentConfiguration:
|
||||
processor._configure_agent()
|
||||
|
||||
assert isinstance(processor.agent_config, dict)
|
||||
assert processor.agent_id is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
|
||||
326
tests/test_usage.py
Normal file
326
tests/test_usage.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from application.usage import (
|
||||
_count_tokens,
|
||||
gen_token_usage,
|
||||
stream_token_usage,
|
||||
update_token_usage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_count_tokens_includes_tool_call_payloads():
|
||||
payload = [
|
||||
{
|
||||
"function_call": {
|
||||
"name": "search_docs",
|
||||
"args": {"query": "pricing limits"},
|
||||
"call_id": "call_1",
|
||||
}
|
||||
},
|
||||
{
|
||||
"function_response": {
|
||||
"name": "search_docs",
|
||||
"response": {"result": "Found 3 docs"},
|
||||
"call_id": "call_1",
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
assert _count_tokens(payload) > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_token_usage_counts_structured_tool_content(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
captured["decoded_token"] = decoded_token
|
||||
captured["user_api_key"] = user_api_key
|
||||
captured["token_usage"] = token_usage.copy()
|
||||
captured["agent_id"] = agent_id
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@gen_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
return {
|
||||
"tool_calls": [
|
||||
{"name": "read_webpage", "arguments": {"url": "https://example.com"}}
|
||||
]
|
||||
}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": "search_docs",
|
||||
"args": {"query": "pricing"},
|
||||
"call_id": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": "search_docs",
|
||||
"response": {"result": "Found docs"},
|
||||
"call_id": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
wrapped(llm, "gpt-4o", messages, False, None)
|
||||
|
||||
assert captured["decoded_token"] == {"sub": "user_123"}
|
||||
assert captured["user_api_key"] == "api_key_123"
|
||||
assert captured["agent_id"] == "agent_123"
|
||||
assert captured["token_usage"]["prompt_tokens"] > 0
|
||||
assert captured["token_usage"]["generated_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_stream_token_usage_counts_tool_call_chunks(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
captured["token_usage"] = token_usage.copy()
|
||||
captured["agent_id"] = agent_id
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class ToolChunk:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location":"Seattle"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@stream_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
yield ToolChunk()
|
||||
yield "done"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"function_call": {
|
||||
"name": "get_weather",
|
||||
"args": {"location": "Seattle"},
|
||||
"call_id": "1",
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
list(wrapped(llm, "gpt-4o", messages, True, None))
|
||||
|
||||
assert captured["agent_id"] == "agent_123"
|
||||
assert captured["token_usage"]["prompt_tokens"] > 0
|
||||
assert captured["token_usage"]["generated_tokens"] > 0
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_gen_token_usage_counts_tools_and_image_inputs(monkeypatch):
|
||||
captured = []
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
_ = (decoded_token, user_api_key, agent_id)
|
||||
captured.append(token_usage.copy())
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@gen_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
return "ok"
|
||||
|
||||
messages = [{"role": "user", "content": "What is in this image?"}]
|
||||
tools_payload = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "describe_image",
|
||||
"description": "Describe image content",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"detail": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
usage_attachments = [
|
||||
{
|
||||
"mime_type": "image/png",
|
||||
"path": "attachments/example.png",
|
||||
"data": "abc123",
|
||||
}
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
wrapped(llm, "gpt-4o", messages, False, None)
|
||||
wrapped(
|
||||
llm,
|
||||
"gpt-4o",
|
||||
messages,
|
||||
False,
|
||||
tools_payload,
|
||||
_usage_attachments=usage_attachments,
|
||||
)
|
||||
|
||||
assert len(captured) == 2
|
||||
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_stream_token_usage_counts_tools_and_image_inputs(monkeypatch):
|
||||
captured = []
|
||||
|
||||
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
|
||||
_ = (decoded_token, user_api_key, agent_id)
|
||||
captured.append(token_usage.copy())
|
||||
|
||||
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
|
||||
|
||||
class DummyLLM:
|
||||
decoded_token = {"sub": "user_123"}
|
||||
user_api_key = "api_key_123"
|
||||
agent_id = "agent_123"
|
||||
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
|
||||
|
||||
@stream_token_usage
|
||||
def wrapped(self, model, messages, stream, tools, **kwargs):
|
||||
_ = (model, messages, stream, tools, kwargs)
|
||||
yield "ok"
|
||||
|
||||
messages = [{"role": "user", "content": "What is in this image?"}]
|
||||
tools_payload = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "describe_image",
|
||||
"description": "Describe image content",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"detail": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
usage_attachments = [
|
||||
{
|
||||
"mime_type": "image/png",
|
||||
"path": "attachments/example.png",
|
||||
"data": "abc123",
|
||||
}
|
||||
]
|
||||
|
||||
llm = DummyLLM()
|
||||
list(wrapped(llm, "gpt-4o", messages, True, None))
|
||||
list(
|
||||
wrapped(
|
||||
llm,
|
||||
"gpt-4o",
|
||||
messages,
|
||||
True,
|
||||
tools_payload,
|
||||
_usage_attachments=usage_attachments,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(captured) == 2
|
||||
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_update_token_usage_inserts_with_agent_id_only(monkeypatch):
|
||||
inserted_docs = []
|
||||
|
||||
class FakeCollection:
|
||||
def insert_one(self, doc):
|
||||
inserted_docs.append(doc)
|
||||
|
||||
modules_without_pytest = dict(sys.modules)
|
||||
modules_without_pytest.pop("pytest", None)
|
||||
|
||||
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
|
||||
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
|
||||
|
||||
update_token_usage(
|
||||
decoded_token=None,
|
||||
user_api_key=None,
|
||||
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
|
||||
agent_id="agent_123",
|
||||
)
|
||||
|
||||
assert len(inserted_docs) == 1
|
||||
assert inserted_docs[0]["agent_id"] == "agent_123"
|
||||
assert inserted_docs[0]["user_id"] is None
|
||||
assert inserted_docs[0]["api_key"] is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_update_token_usage_skips_when_all_ids_missing(monkeypatch):
|
||||
inserted_docs = []
|
||||
|
||||
class FakeCollection:
|
||||
def insert_one(self, doc):
|
||||
inserted_docs.append(doc)
|
||||
|
||||
modules_without_pytest = dict(sys.modules)
|
||||
modules_without_pytest.pop("pytest", None)
|
||||
|
||||
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
|
||||
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
|
||||
|
||||
update_token_usage(
|
||||
decoded_token=None,
|
||||
user_api_key=None,
|
||||
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
|
||||
agent_id=None,
|
||||
)
|
||||
|
||||
assert inserted_docs == []
|
||||
Reference in New Issue
Block a user