Compare commits

...

10 Commits

Author SHA1 Message Date
dependabot[bot]
4c7a6a78aa chore(deps): bump docker/setup-qemu-action from 3 to 4
Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 3 to 4.
- [Release notes](https://github.com/docker/setup-qemu-action/releases)
- [Commits](https://github.com/docker/setup-qemu-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: docker/setup-qemu-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-04 20:54:17 +00:00
Alex
a6625ec5de fix: mini workflow fixes 2026-02-22 11:10:42 +00:00
Alex
1a2104f474 fix: token calc (#2285) 2026-02-20 17:37:47 +00:00
Alex
444abb8283 fix search nextra 2026-02-18 18:03:03 +00:00
Alex
ee86537f21 docs: add llms.txt and enable copy code button in nextra 2026-02-18 17:54:25 +00:00
Alex
17a736a927 docs: migrate to Nextra 4 and Next.js App Router 2026-02-18 17:13:24 +00:00
Alex
6b5779054d Merge branch 'main' of https://github.com/arc53/DocsGPT 2026-02-17 18:46:35 +00:00
Alex
14296632ef build(docs): upgrade nextra to v3 and update config 2026-02-17 18:46:28 +00:00
Siddhant Rai
2a3f0e455a feat: condition node functionality with CEL evaluation in Workflows (#2280)
* feat: add condition node functionality with CEL evaluation

- Introduced ConditionNode to support conditional branching in workflows.
- Implemented CEL evaluation for state updates and condition expressions.
- Updated WorkflowEngine to handle condition nodes and their execution logic.
- Enhanced validation for workflows to ensure condition nodes have at least two outgoing edges and valid expressions.
- Modified frontend components to support new condition node type and its configuration.
- Added necessary types and interfaces for condition cases and state operations.
- Updated requirements to include cel-python for expression evaluation.

* mini-fixes

* feat(workflow): improve UX

---------

Co-authored-by: Alex <a@tushynski.me>
2026-02-17 17:29:48 +00:00
Pavel
8aa44c415b Advanced settings (#2281)
Add additional settings to setup scripts
2026-02-17 11:54:59 +00:00
96 changed files with 8053 additions and 3036 deletions

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -27,7 +27,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

.gitignore vendored (1 line changed)
View File

@@ -71,6 +71,7 @@ instance/
# Sphinx documentation
docs/_build/
docs/public/_pagefind/
# PyBuilder
target/

View File

@@ -7,6 +7,10 @@ from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
@@ -23,6 +27,7 @@ class BaseAgent(ABC):
llm_name: str,
model_id: str,
api_key: str,
agent_id: Optional[str] = None,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
@@ -40,6 +45,7 @@ class BaseAgent(ABC):
self.llm_name = llm_name
self.model_id = model_id
self.api_key = api_key
self.agent_id = agent_id
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
@@ -54,13 +60,19 @@ class BaseAgent(ABC):
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
self.retrieved_docs = retrieved_docs or []
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = json_schema
self.json_schema = None
if json_schema is not None:
try:
self.json_schema = normalize_json_schema_payload(json_schema)
except JsonSchemaValidationError as exc:
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
@@ -263,6 +275,11 @@ class BaseAgent(ABC):
tool_config=tool_config,
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
@@ -292,11 +309,19 @@ class BaseAgent(ABC):
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
@@ -304,7 +329,11 @@ class BaseAgent(ABC):
def _get_truncated_tool_calls(self):
return [
{
**tool_call,
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
@@ -576,6 +605,9 @@ class BaseAgent(ABC):
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (
hasattr(self.llm, "_supports_tools")

View File

@@ -211,8 +211,21 @@ class WorkflowAgent(BaseAgent):
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
if isinstance(value, (str, int, float, bool, type(None))):
serialized[key] = value
else:
serialized[key] = str(value)
serialized[key] = self._serialize_state_value(value)
return serialized
def _serialize_state_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {
str(dict_key): self._serialize_state_value(dict_value)
for dict_key, dict_value in value.items()
}
if isinstance(value, list):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, tuple):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)
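As a rough illustration of what this serialization yields, here is a standalone sketch that mirrors the `_serialize_state_value` logic above; the free function name and sample state are only for demonstration:

```python
from datetime import datetime, timezone
from typing import Any

def serialize_state_value(value: Any) -> Any:
    # Mirrors _serialize_state_value: recurse into containers,
    # ISO-format datetimes, pass primitives through, stringify the rest.
    if isinstance(value, dict):
        return {str(k): serialize_state_value(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [serialize_state_value(item) for item in value]
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    return str(value)

state = {
    "query": "hello",
    "started_at": datetime(2026, 2, 17, tzinfo=timezone.utc),
    "docs": [{"title": "a", "score": 0.9}, ("x", 1)],
}
print(serialize_state_value(state))
# {'query': 'hello', 'started_at': '2026-02-17T00:00:00+00:00',
#  'docs': [{'title': 'a', 'score': 0.9}, ['x', 1]]}
```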

View File

@@ -0,0 +1,64 @@
from typing import Any, Dict
import celpy
import celpy.celtypes
class CelEvaluationError(Exception):
pass
def _convert_value(value: Any) -> Any:
if isinstance(value, bool):
return celpy.celtypes.BoolType(value)
if isinstance(value, int):
return celpy.celtypes.IntType(value)
if isinstance(value, float):
return celpy.celtypes.DoubleType(value)
if isinstance(value, str):
return celpy.celtypes.StringType(value)
if isinstance(value, list):
return celpy.celtypes.ListType([_convert_value(item) for item in value])
if isinstance(value, dict):
return celpy.celtypes.MapType(
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
)
if value is None:
return celpy.celtypes.BoolType(False)
return celpy.celtypes.StringType(str(value))
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
return {k: _convert_value(v) for k, v in state.items()}
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
if not expression or not expression.strip():
raise CelEvaluationError("Empty expression")
try:
env = celpy.Environment()
ast = env.compile(expression)
program = env.program(ast)
activation = build_activation(state)
result = program.evaluate(activation)
except celpy.CELEvalError as exc:
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
except Exception as exc:
raise CelEvaluationError(f"CEL error: {exc}") from exc
return cel_to_python(result)
def cel_to_python(value: Any) -> Any:
if isinstance(value, celpy.celtypes.BoolType):
return bool(value)
if isinstance(value, celpy.celtypes.IntType):
return int(value)
if isinstance(value, celpy.celtypes.DoubleType):
return float(value)
if isinstance(value, celpy.celtypes.StringType):
return str(value)
if isinstance(value, celpy.celtypes.ListType):
return [cel_to_python(item) for item in value]
if isinstance(value, celpy.celtypes.MapType):
return {str(k): cel_to_python(v) for k, v in value.items()}
return value
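As a usage sketch (assuming `cel-python` is installed and the module lives at `application.agents.workflows.cel_evaluator` as added above), a condition expression is evaluated against the current workflow state like this:

```python
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel

state = {"score": 0.87, "labels": ["spam", "urgent"], "user": {"tier": "pro"}}

try:
    # Boolean results drive condition-node branching; any CEL expression is allowed.
    print(evaluate_cel('score > 0.8 && "urgent" in labels', state))  # True
    print(evaluate_cel('user.tier == "free"', state))                # False
except CelEvaluationError as exc:
    # Invalid or failing expressions surface as CelEvaluationError.
    print(f"expression rejected: {exc}")
```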

View File

@@ -1,6 +1,6 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -12,6 +12,7 @@ class NodeType(str, Enum):
AGENT = "agent"
NOTE = "note"
STATE = "state"
CONDITION = "condition"
class AgentType(str, Enum):
@@ -48,6 +49,25 @@ class AgentNodeConfig(BaseModel):
json_schema: Optional[Dict[str, Any]] = None
class ConditionCase(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
name: Optional[str] = None
expression: str = ""
source_handle: str = Field(..., alias="sourceHandle")
class ConditionNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal["simple", "advanced"] = "simple"
cases: List[ConditionCase] = Field(default_factory=list)
class StateOperation(BaseModel):
model_config = ConfigDict(extra="forbid")
expression: str = ""
target_variable: str = ""
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str

View File

@@ -1,16 +1,30 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
@@ -18,6 +32,7 @@ logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
@@ -28,6 +43,9 @@ class WorkflowEngine:
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
@@ -98,6 +116,10 @@ class WorkflowEngine:
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
if current_node_id is None and node.type != NodeType.END:
logger.warning(
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
@@ -121,10 +143,20 @@ class WorkflowEngine:
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
node = self.graph.get_node_by_id(current_node_id)
edges = self.graph.get_outgoing_edges(current_node_id)
if edges:
return edges[0].target_id
return None
if not edges:
return None
if node and node.type == NodeType.CONDITION and self._condition_result:
target_handle = self._condition_result
self._condition_result = None
for edge in edges:
if edge.source_handle == target_handle:
return edge.target_id
return None
return edges[0].target_id
def _execute_node(
self, node: WorkflowNode
@@ -136,6 +168,7 @@ class WorkflowEngine:
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.CONDITION: self._execute_condition_node,
NodeType.END: self._execute_end_node,
}
@@ -156,35 +189,62 @@ class WorkflowEngine:
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import get_api_key_for_provider
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_llm_name = node_config.llm_name or self.agent.llm_name
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(node_model_id or "")
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(node_model_id)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_config.model_id or self.agent.model_id,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_config.json_schema,
json_schema=node_json_schema,
)
full_response = ""
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
full_response += event["answer"]
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
@@ -194,72 +254,189 @@ class WorkflowEngine:
if node_config.stream_to_user:
self._has_streamed = True
output_key = node_config.output_variable or f"node_{node.id}_output"
self.state[output_key] = full_response
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
operations = config.get("operations", [])
config = node.config.get("config", node.config)
for op in config.get("operations", []):
expression = op.get("expression", "")
target_variable = op.get("target_variable", "")
if expression and target_variable:
self.state[target_variable] = evaluate_cel(expression, self.state)
yield from ()
if operations:
for op in operations:
key = op.get("key")
operation = op.get("operation", "set")
value = op.get("value")
def _execute_condition_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = ConditionNodeConfig(**node.config.get("config", node.config))
matched_handle = None
if not key:
continue
if operation == "set":
formatted_value = (
self._format_template(str(value))
if isinstance(value, str)
else value
)
self.state[key] = formatted_value
elif operation == "increment":
current = self.state.get(key, 0)
try:
self.state[key] = int(current) + int(value or 1)
except (ValueError, TypeError):
self.state[key] = 1
elif operation == "append":
if key not in self.state:
self.state[key] = []
if isinstance(self.state[key], list):
self.state[key].append(value)
else:
updates = config.get("updates", {})
if not updates:
var_name = config.get("variable")
var_value = config.get("value")
if var_name and isinstance(var_name, str):
updates = {var_name: var_value or ""}
if isinstance(updates, dict):
for key, value in updates.items():
if isinstance(value, str):
self.state[key] = self._format_template(value)
else:
self.state[key] = value
for case in config.cases:
if not case.expression.strip():
continue
try:
if evaluate_cel(case.expression, self.state):
matched_handle = case.source_handle
break
except CelEvaluationError:
continue
self._condition_result = matched_handle or "else"
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config
config = node.config.get("config", node.config)
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
formatted = template
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
placeholder = f"{{{{{key}}}}}"
if placeholder in formatted and value is not None:
formatted = formatted.replace(placeholder, str(value))
return formatted
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [

View File

@@ -42,6 +42,7 @@ class AnswerResource(Resource, BaseAnswerResource):
),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -100,6 +101,9 @@ class AnswerResource(Resource, BaseAnswerResource):
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)

View File

@@ -46,6 +46,27 @@ class BaseAnswerResource:
return missing_fields
return None
@staticmethod
def _prepare_tool_calls_for_logging(
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
) -> List[Dict[str, Any]]:
if not tool_calls:
return []
prepared = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
prepared.append({"result": str(tool_call)[:max_chars]})
continue
item = dict(tool_call)
for key in ("result", "result_full"):
value = item.get(key)
if isinstance(value, str) and len(value) > max_chars:
item[key] = value[:max_chars]
prepared.append(item)
return prepared
def check_usage(self, agent_config: Dict) -> Optional[Response]:
"""Check if there is a usage limit and if it is exceeded
@@ -246,6 +267,7 @@ class BaseAnswerResource:
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
if should_save_conversation:
@@ -292,14 +314,20 @@ class BaseAnswerResource:
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
)
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"agent_id": agent_id,
"question": question,
"response": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls_for_logging,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
@@ -330,6 +358,7 @@ class BaseAnswerResource:
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
agent_id=agent_id,
)
self.conversation_service.save_conversation(
conversation_id,

View File

@@ -42,6 +42,7 @@ class StreamResource(Resource, BaseAnswerResource):
),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
@@ -107,7 +108,7 @@ class StreamResource(Resource, BaseAnswerResource):
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=data.get("agent_id"),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,

View File

@@ -134,6 +134,7 @@ class CompressionOrchestrator:
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
)
# Create compression service with DB update capability

View File

@@ -90,6 +90,7 @@ class StreamProcessor:
self.retriever_config = {}
self.is_shared_usage = False
self.shared_token = None
self.agent_id = self.data.get("agent_id")
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
@@ -355,10 +356,13 @@ class StreamProcessor:
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
@@ -387,6 +391,8 @@ class StreamProcessor:
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
@@ -459,6 +465,7 @@ class StreamProcessor:
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
model_id=self.model_id,
user_api_key=self.agent_config["user_api_key"],
agent_id=self.agent_id,
decoded_token=self.decoded_token,
)
@@ -754,6 +761,7 @@ class StreamProcessor:
"llm_name": provider or settings.LLM_PROVIDER,
"model_id": self.model_id,
"api_key": system_api_key,
"agent_id": self.agent_id,
"user_api_key": self.agent_config["user_api_key"],
"prompt": rendered_prompt,
"chat_history": self.history,

View File

@@ -23,6 +23,10 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.settings import settings
from application.utils import (
check_required_fields,
@@ -479,41 +483,15 @@ class CreateAgent(Resource):
data["models"] = []
print(f"Received data: {data}")
# Validate JSON schema if provided
if data.get("json_schema"):
# Validate and normalize JSON schema if provided
if "json_schema" in data:
try:
# Basic validation - ensure it's a valid JSON structure
json_schema = data.get("json_schema")
if not isinstance(json_schema, dict):
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid JSON object",
}
),
400,
)
# Validate that it has either a 'schema' property or is itself a schema
if "schema" not in json_schema and "type" not in json_schema:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
}
),
400,
)
except Exception as e:
current_app.logger.error(f"Invalid JSON schema: {e}")
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{"success": False, "message": "Invalid JSON schema format"}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -732,6 +710,8 @@ class UpdateAgent(Resource):
),
400,
)
if data.get("json_schema") == "":
data["json_schema"] = None
except Exception as err:
current_app.logger.error(
f"Error parsing request data: {err}", exc_info=True
@@ -892,17 +872,15 @@ class UpdateAgent(Resource):
elif field == "json_schema":
json_schema = data.get("json_schema")
if json_schema is not None:
if not isinstance(json_schema, dict):
try:
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError as exc:
return make_response(
jsonify(
{
"success": False,
"message": "JSON schema must be a valid object",
}
),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
update_fields[field] = json_schema
else:
update_fields[field] = None
elif field == "limited_token_mode":

View File

@@ -1,7 +1,7 @@
"""Workflow management routes."""
from datetime import datetime, timezone
from typing import Dict, List
from typing import Any, Dict, List, Optional, Set
from flask import current_app, request
from flask_restx import Namespace, Resource
@@ -11,6 +11,11 @@ from application.api.user.base import (
workflow_nodes_collection,
workflows_collection,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
check_resource_ownership,
error_response,
@@ -85,6 +90,50 @@ def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> L
return docs
def validate_json_schema_payload(
json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
"""Validate and normalize optional JSON schema payload for structured output."""
if json_schema is None:
return None, None
try:
return normalize_json_schema_payload(json_schema), None
except JsonSchemaValidationError as exc:
return None, str(exc)
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
"""Normalize agent-node JSON schema payloads before persistence."""
normalized_nodes: List[Dict] = []
for node in nodes:
if not isinstance(node, dict):
normalized_nodes.append(node)
continue
normalized_node = dict(node)
if normalized_node.get("type") != "agent":
normalized_nodes.append(normalized_node)
continue
raw_config = normalized_node.get("data")
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
normalized_nodes.append(normalized_node)
continue
normalized_config = dict(raw_config)
try:
normalized_config["json_schema"] = normalize_json_schema_payload(
raw_config.get("json_schema")
)
except JsonSchemaValidationError:
# Validation runs before normalization; keep original on unexpected shape.
normalized_config["json_schema"] = raw_config.get("json_schema")
normalized_node["data"] = normalized_config
normalized_nodes.append(normalized_node)
return normalized_nodes
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
"""Validate workflow graph structure."""
errors = []
@@ -102,6 +151,9 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
errors.append("Workflow must have at least one end node")
node_ids = {n.get("id") for n in nodes}
node_map = {n.get("id"): n for n in nodes}
end_ids = {n.get("id") for n in end_nodes}
for edge in edges:
source_id = edge.get("source")
target_id = edge.get("target")
@@ -115,6 +167,126 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
if not any(e.get("source") == start_id for e in edges):
errors.append("Start node must have at least one outgoing edge")
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
for cnode in condition_nodes:
cnode_id = cnode.get("id")
cnode_title = cnode.get("title", cnode_id)
outgoing = [e for e in edges if e.get("source") == cnode_id]
if len(outgoing) < 2:
errors.append(
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
)
node_data = cnode.get("data", {}) or {}
cases = node_data.get("cases", [])
if not isinstance(cases, list):
cases = []
if not cases or not any(
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
):
errors.append(
f"Condition node '{cnode_title}' must have at least one case with an expression"
)
case_handles: Set[str] = set()
duplicate_case_handles: Set[str] = set()
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
errors.append(
f"Condition node '{cnode_title}' has a case without a branch handle"
)
continue
if handle in case_handles:
duplicate_case_handles.add(handle)
case_handles.add(handle)
for handle in duplicate_case_handles:
errors.append(
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
)
outgoing_by_handle: Dict[str, List[Dict]] = {}
for out_edge in outgoing:
raw_handle = out_edge.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
outgoing_by_handle.setdefault(handle, []).append(out_edge)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
errors.append(
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
)
continue
if handle != "else" and handle not in case_handles:
errors.append(
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
)
if len(handle_edges) > 1:
errors.append(
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
)
if "else" not in outgoing_by_handle:
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
for case in cases:
if not isinstance(case, dict):
continue
raw_handle = case.get("sourceHandle", "")
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
if not handle:
continue
raw_expression = case.get("expression", "")
has_expression = isinstance(raw_expression, str) and bool(
raw_expression.strip()
)
has_outgoing = bool(outgoing_by_handle.get(handle))
if has_expression and not has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
)
if not has_expression and has_outgoing:
errors.append(
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
)
for handle, handle_edges in outgoing_by_handle.items():
if not handle:
continue
for out_edge in handle_edges:
target = out_edge.get("target")
if target and not _can_reach_end(target, edges, node_map, end_ids):
errors.append(
f"Branch '{handle}' of condition '{cnode_title}' "
f"must eventually reach an end node"
)
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
for agent_node in agent_nodes:
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
raw_config = agent_node.get("data", {}) or {}
if not isinstance(raw_config, dict):
errors.append(f"Agent node '{agent_title}' has invalid configuration")
continue
normalized_schema, schema_error = validate_json_schema_payload(
raw_config.get("json_schema")
)
has_json_schema = normalized_schema is not None
model_id = raw_config.get("model_id")
if has_json_schema and isinstance(model_id, str) and model_id.strip():
capabilities = get_model_capabilities(model_id.strip())
if capabilities and not capabilities.get("supports_structured_output", False):
errors.append(
f"Agent node '{agent_title}' selected model does not support structured output"
)
if schema_error:
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
for node in nodes:
if not node.get("id"):
errors.append("All nodes must have an id")
@@ -124,6 +296,20 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[st
return errors
def _can_reach_end(
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
) -> bool:
if visited is None:
visited = set()
if node_id in end_ids:
return True
if node_id in visited or node_id not in node_map:
return False
visited.add(node_id)
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
def create_workflow_nodes(
workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
@@ -186,6 +372,7 @@ class WorkflowList(Resource):
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
now = datetime.now(timezone.utc)
workflow_doc = {
@@ -276,6 +463,7 @@ class WorkflowDetail(Resource):
return error_response(
"Workflow validation failed", errors=validation_errors
)
nodes_data = normalize_agent_node_json_schemas(nodes_data)
current_graph_version = get_workflow_graph_version(workflow)
next_graph_version = current_graph_version + 1
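A hedged sketch of a graph payload that satisfies the condition-node checks above: every case has an expression and a unique `sourceHandle`, each declared handle plus the mandatory `else` has exactly one outgoing edge, and every branch can reach the end node. Field names follow the validation code; the surrounding shape and node titles are illustrative only.

```python
nodes = [
    {"id": "start", "type": "start", "title": "Start", "data": {}},
    {
        "id": "route",
        "type": "condition",
        "title": "Route by score",
        "data": {
            "mode": "simple",
            "cases": [
                {"name": "high", "expression": "score > 0.8", "sourceHandle": "case-high"},
            ],
        },
    },
    {"id": "agent_hi", "type": "agent", "title": "Handle high", "data": {}},
    {"id": "agent_lo", "type": "agent", "title": "Handle low", "data": {}},
    {"id": "end", "type": "end", "title": "End", "data": {}},
]

edges = [
    {"id": "e1", "source": "start", "target": "route"},
    # One edge per declared case handle...
    {"id": "e2", "source": "route", "target": "agent_hi", "sourceHandle": "case-high"},
    # ...plus the required 'else' branch.
    {"id": "e3", "source": "route", "target": "agent_lo", "sourceHandle": "else"},
    {"id": "e4", "source": "agent_hi", "target": "end"},
    {"id": "e5", "source": "agent_lo", "target": "end"},
]

# validate_workflow_structure(nodes, edges) would return no errors for this graph.
```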

View File

@@ -0,0 +1,34 @@
from typing import Any, Dict, Optional
class JsonSchemaValidationError(ValueError):
"""Raised when a JSON schema payload is invalid."""
def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
"""
Normalize accepted JSON schema payload shapes to a plain schema object.
Accepted inputs:
- None
- A raw schema object with a top-level "type"
- A wrapped payload with a top-level "schema" object
"""
if json_schema is None:
return None
if not isinstance(json_schema, dict):
raise JsonSchemaValidationError("must be a valid JSON object")
wrapped_schema = json_schema.get("schema")
if wrapped_schema is not None:
if not isinstance(wrapped_schema, dict):
raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
return wrapped_schema
if "type" not in json_schema:
raise JsonSchemaValidationError(
'must include either a "type" or "schema" field'
)
return json_schema
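A quick sketch of the accepted payload shapes, assuming the helper is importable from `application.core.json_schema_utils` as added above:

```python
from application.core.json_schema_utils import (
    JsonSchemaValidationError,
    normalize_json_schema_payload,
)

raw = {"type": "object", "properties": {"answer": {"type": "string"}}}
wrapped = {"name": "my_schema", "schema": raw}

# Both shapes normalize to the same plain schema object; None passes through.
assert normalize_json_schema_payload(raw) == raw
assert normalize_json_schema_payload(wrapped) == raw
assert normalize_json_schema_payload(None) is None

try:
    normalize_json_schema_payload({"properties": {}})  # neither "type" nor "schema"
except JsonSchemaValidationError as exc:
    print(exc)  # must include either a "type" or "schema" field
```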

View File

@@ -13,10 +13,12 @@ class BaseLLM(ABC):
def __init__(
self,
decoded_token=None,
agent_id=None,
model_id=None,
base_url=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@@ -33,9 +35,10 @@ class BaseLLM(ABC):
self._fallback_llm = LLMCreator.create_llm(
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
user_api_key=None,
user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=settings.FALLBACK_LLM_NAME,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"

View File

@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
def __init__(
self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
super().__init__(decoded_token=decoded_token, *args, **kwargs)
self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
self.user_api_key = user_api_key

View File

@@ -567,6 +567,7 @@ class LLMHandler(ABC):
getattr(agent, "user_api_key", None),
getattr(agent, "decoded_token", None),
model_id=compression_model,
agent_id=getattr(agent, "agent_id", None),
)
# Create service without DB persistence capability

View File

@@ -31,7 +31,15 @@ class LLMCreator:
@classmethod
def create_llm(
cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
cls,
type,
api_key,
user_api_key,
decoded_token,
model_id=None,
agent_id=None,
*args,
**kwargs,
):
from application.core.model_utils import get_base_url_for_model
@@ -49,6 +57,7 @@ class LLMCreator:
user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
base_url=base_url,
*args,
**kwargs,

View File

@@ -1,6 +1,7 @@
anthropic==0.75.0
boto3==1.42.17
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.0
cryptography==46.0.3
dataclasses-json==0.6.7

View File

@@ -18,6 +18,7 @@ class ClassicRAG(BaseRetriever):
doc_token_limit=50000,
model_id="docsgpt-local",
user_api_key=None,
agent_id=None,
llm_name=settings.LLM_PROVIDER,
api_key=settings.API_KEY,
decoded_token=None,
@@ -43,6 +44,7 @@ class ClassicRAG(BaseRetriever):
self.model_id = model_id
self.doc_token_limit = doc_token_limit
self.user_api_key = user_api_key
self.agent_id = agent_id
self.llm_name = llm_name
self.api_key = api_key
self.llm = LLMCreator.create_llm(
@@ -50,6 +52,7 @@ class ClassicRAG(BaseRetriever):
api_key=self.api_key,
user_api_key=self.user_api_key,
decoded_token=decoded_token,
agent_id=self.agent_id,
)
if "active_docs" in source and source["active_docs"] is not None:

View File

@@ -1,22 +1,104 @@
import sys
import logging
from datetime import datetime
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
usage_collection = db["token_usage"]
def update_token_usage(decoded_token, user_api_key, token_usage):
def _serialize_for_token_count(value):
"""Normalize payloads into token-countable primitives."""
if isinstance(value, str):
# Avoid counting large binary payloads in data URLs as text tokens.
if value.startswith("data:") and ";base64," in value:
return ""
return value
if value is None:
return ""
if isinstance(value, list):
return [_serialize_for_token_count(item) for item in value]
if isinstance(value, dict):
serialized = {}
for key, raw in value.items():
key_lower = str(key).lower()
# Skip raw binary-like fields; keep textual tool-call fields.
if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
continue
if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
continue
serialized[key] = _serialize_for_token_count(raw)
return serialized
if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
return _serialize_for_token_count(value.model_dump())
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
return _serialize_for_token_count(value.to_dict())
if hasattr(value, "__dict__"):
return _serialize_for_token_count(vars(value))
return str(value)
def _count_tokens(value):
serialized = _serialize_for_token_count(value)
if isinstance(serialized, str):
return num_tokens_from_string(serialized)
return num_tokens_from_object_or_list(serialized)
def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
prompt_tokens = 0
for message in messages or []:
if not isinstance(message, dict):
prompt_tokens += _count_tokens(message)
continue
prompt_tokens += _count_tokens(message.get("content"))
# Include tool-related message fields for providers that use OpenAI-native format.
prompt_tokens += _count_tokens(message.get("tool_calls"))
prompt_tokens += _count_tokens(message.get("tool_call_id"))
prompt_tokens += _count_tokens(message.get("function_call"))
prompt_tokens += _count_tokens(message.get("function_response"))
# Count tool schema payload passed to the model.
prompt_tokens += _count_tokens(tools)
# Count structured-output/schema payloads when provided.
prompt_tokens += _count_tokens(kwargs.get("response_format"))
prompt_tokens += _count_tokens(kwargs.get("response_schema"))
# Optional usage-only attachment context (not forwarded to provider).
prompt_tokens += _count_tokens(usage_attachments)
return prompt_tokens
def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
if "pytest" in sys.modules:
return
if decoded_token:
user_id = decoded_token["sub"]
else:
user_id = None
user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
normalized_agent_id = str(agent_id) if agent_id else None
if not user_id and not user_api_key and not normalized_agent_id:
logger.warning(
"Skipping token usage insert: missing user_id, api_key, and agent_id"
)
return
usage_data = {
"user_id": user_id,
"api_key": user_api_key,
@@ -24,24 +106,31 @@ def update_token_usage(decoded_token, user_api_key, token_usage):
"generated_tokens": token_usage["generated_tokens"],
"timestamp": datetime.now(),
}
if normalized_agent_id:
usage_data["agent_id"] = normalized_agent_id
usage_collection.insert_one(usage_data)
def gen_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
if message["content"]:
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
usage_attachments = kwargs.pop("_usage_attachments", None)
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
call_usage["prompt_tokens"] += _count_prompt_tokens(
messages,
tools=tools,
usage_attachments=usage_attachments,
**kwargs,
)
result = func(self, model, messages, stream, tools, **kwargs)
if isinstance(result, str):
self.token_usage["generated_tokens"] += num_tokens_from_string(result)
else:
self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
result
)
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
call_usage["generated_tokens"] += _count_tokens(result)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return result
return wrapper
@@ -49,17 +138,28 @@ def gen_token_usage(func):
def stream_token_usage(func):
def wrapper(self, model, messages, stream, tools, **kwargs):
for message in messages:
self.token_usage["prompt_tokens"] += num_tokens_from_string(
message["content"]
)
usage_attachments = kwargs.pop("_usage_attachments", None)
call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
call_usage["prompt_tokens"] += _count_prompt_tokens(
messages,
tools=tools,
usage_attachments=usage_attachments,
**kwargs,
)
batch = []
result = func(self, model, messages, stream, tools, **kwargs)
for r in result:
batch.append(r)
yield r
for line in batch:
self.token_usage["generated_tokens"] += num_tokens_from_string(line)
update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
call_usage["generated_tokens"] += _count_tokens(line)
self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
update_token_usage(
self.decoded_token,
self.user_api_key,
call_usage,
getattr(self, "agent_id", None),
)
return wrapper
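A sketch of how this counting treats multimodal payloads: base64 data URLs and binary-like fields are dropped before counting, so image bytes are not billed as text tokens. The import path is an assumption, since the diff does not show the module's filename.

```python
# Hypothetical import path for the usage module shown above.
from application.usage import _count_prompt_tokens

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            # The "url" entry contains ";base64," and is skipped by
            # _serialize_for_token_count, so only the text parts are counted.
            {"type": "image_url", "url": "data:image/png;base64,iVBORw0KGgo..."},
        ],
    },
]

tools = [{"type": "function", "function": {"name": "search", "parameters": {}}}]

print(_count_prompt_tokens(messages, tools=tools))
```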

View File

@@ -322,6 +322,7 @@ def run_agent_logic(agent_config, input_data):
chunks = int(agent_config.get("chunks", 2))
prompt_id = agent_config.get("prompt_id", "default")
user_api_key = agent_config["key"]
agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
agent_type = agent_config.get("agent_type", "classic")
decoded_token = {"sub": agent_config.get("user")}
json_schema = agent_config.get("json_schema")
@@ -352,6 +353,7 @@ def run_agent_logic(agent_config, input_data):
doc_token_limit=doc_token_limit,
model_id=model_id,
user_api_key=user_api_key,
agent_id=agent_id,
decoded_token=decoded_token,
)
@@ -370,6 +372,7 @@ def run_agent_logic(agent_config, input_data):
llm_name=provider or settings.LLM_PROVIDER,
model_id=model_id,
api_key=system_api_key,
agent_id=agent_id,
user_api_key=user_api_key,
prompt=prompt,
chat_history=[],

View File

@@ -0,0 +1,25 @@
import { generateStaticParamsFor, importPage } from 'nextra/pages';
import { useMDXComponents } from '../../mdx-components';
export const generateStaticParams = generateStaticParamsFor('mdxPath');
export async function generateMetadata(props) {
const params = await props.params;
const { metadata } = await importPage(params?.mdxPath);
return metadata;
}
const Wrapper = useMDXComponents().wrapper;
export default async function Page(props) {
const params = await props.params;
const result = await importPage(params?.mdxPath);
const { default: MDXContent, metadata, sourceCode, toc } = result;
return (
<Wrapper metadata={metadata} sourceCode={sourceCode} toc={toc}>
<MDXContent {...props} params={params} />
</Wrapper>
);
}

docs/app/layout.jsx (new file, 86 lines)
View File

@@ -0,0 +1,86 @@
import Image from 'next/image';
import { Analytics } from '@vercel/analytics/react';
import { Banner, Head } from 'nextra/components';
import { getPageMap } from 'nextra/page-map';
import { Footer, Layout, Navbar } from 'nextra-theme-docs';
import 'nextra-theme-docs/style.css';
import CuteLogo from '../public/cute-docsgpt.png';
import themeConfig from '../theme.config';
const github = 'https://github.com/arc53/DocsGPT';
export const metadata = {
title: {
default: 'DocsGPT Documentation',
template: '%s - DocsGPT Documentation',
},
description:
'Use DocsGPT to chat with your data. DocsGPT is a GPT-powered chatbot that can answer questions about your data.',
};
const navbar = (
<Navbar
logo={
<div style={{ alignItems: 'center', display: 'flex', gap: '8px' }}>
<Image src={CuteLogo} alt="DocsGPT logo" width={28} height={28} />
<span style={{ fontWeight: 'bold', fontSize: 18 }}>DocsGPT Docs</span>
</div>
}
projectLink={github}
chatLink="https://discord.com/invite/n5BX8dh8rU"
/>
);
const footer = (
<Footer>
<span>MIT {new Date().getFullYear()} © </span>
<a href="https://www.docsgpt.cloud/" target="_blank" rel="noreferrer">
DocsGPT
</a>
{' | '}
<a href="https://github.com/arc53/DocsGPT" target="_blank" rel="noreferrer">
GitHub
</a>
{' | '}
<a href="https://blog.docsgpt.cloud/" target="_blank" rel="noreferrer">
Blog
</a>
</Footer>
);
export default async function RootLayout({ children }) {
return (
<html lang="en" dir="ltr" suppressHydrationWarning>
<Head>
<link
rel="apple-touch-icon"
sizes="180x180"
href="/favicons/apple-touch-icon.png"
/>
<link rel="icon" type="image/png" sizes="32x32" href="/favicons/favicon-32x32.png" />
<link rel="icon" type="image/png" sizes="16x16" href="/favicons/favicon-16x16.png" />
<link rel="manifest" href="/favicons/site.webmanifest" />
<meta httpEquiv="Content-Language" content="en" />
</Head>
<body>
<Layout
banner={
<Banner storageKey="docs-launch">
<div className="flex justify-center items-center gap-2">
Welcome to the new DocsGPT docs!
</div>
</Banner>
}
navbar={navbar}
footer={footer}
pageMap={await getPageMap()}
{...themeConfig}
>
{children}
</Layout>
<Analytics />
</body>
</html>
);
}

View File

@@ -1,3 +1,5 @@
'use client';
import Image from 'next/image';
const iconMap = {
@@ -117,4 +119,4 @@ export function DeploymentCards({ items }) {
`}</style>
</>
);
}
}

View File

@@ -1,3 +1,5 @@
'use client';
import Image from 'next/image';
const iconMap = {
@@ -114,4 +116,4 @@ export function ToolCards({ items }) {
`}</style>
</>
);
}
}

View File

@@ -1,4 +1,4 @@
{
export default {
"basics": {
"title": "🤖 Agent Basics",
"href": "/Agents/basics"
@@ -10,5 +10,9 @@
"webhooks": {
"title": "🪝 Agent Webhooks",
"href": "/Agents/webhooks"
},
"nodes": {
"title": "🧩 Workflow Nodes",
"href": "/Agents/nodes"
}
}

docs/content/Agents/api.mdx (new file, 342 lines)
View File

@@ -0,0 +1,342 @@
---
title: Interacting with Agents via API
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
---
import { Callout, Tabs } from 'nextra/components';
# Interacting with Agents via API
DocsGPT Agents can be accessed programmatically through API endpoints. This page covers:
- Non-streaming answers (`/api/answer`)
- Streaming answers over SSE (`/stream`)
- File/image attachments (`/api/store_attachment` + `/api/task_status` + `/stream`)
When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.
## Base URL
<Callout type="info">
For DocsGPT Cloud, use `https://gptcloud.arc53.com` as the base URL.
</Callout>
- Local: `http://localhost:7091`
- Cloud: `https://gptcloud.arc53.com`
## How Request Resolution Works
DocsGPT resolves your request in this order:
1. If `api_key` is provided, DocsGPT loads the mapped agent and executes with that config.
2. If `agent_id` is provided (typically with JWT auth), DocsGPT loads that agent if allowed.
3. If neither is provided, DocsGPT uses request-level fields (`prompt_id`, `active_docs`, `retriever`, etc.).
Authentication:
- Agent API-key flow: include `api_key` in JSON/form payload.
- JWT flow (if auth enabled): include `Authorization: Bearer <token>`.
## Endpoints
- `POST /api/answer` (non-streaming)
- `POST /stream` (SSE streaming)
- `POST /api/store_attachment` (multipart upload)
- `GET /api/task_status?task_id=...` (Celery task polling)
## Request Parameters
Common request body fields:
| Field | Type | Required | Applies to | Notes |
| --- | --- | --- | --- | --- |
| `question` | `string` | Yes | `/api/answer`, `/stream` | User query. |
| `api_key` | `string` | Usually | `/api/answer`, `/stream` | Recommended for agent API use; loads the agent's configuration from the key. |
| `conversation_id` | `string` | No | `/api/answer`, `/stream` | Continue an existing conversation. |
| `history` | `string` (JSON-encoded array) | No | `/api/answer`, `/stream` | Used for new conversations. Format: `[{\"prompt\":\"...\",\"response\":\"...\"}]`. |
| `model_id` | `string` | No | `/api/answer`, `/stream` | Override model for this request. |
| `save_conversation` | `boolean` | No | `/api/answer`, `/stream` | Default `true`. If `false`, no conversation is persisted. |
| `passthrough` | `object` | No | `/api/answer`, `/stream` | Dynamic values injected into prompt templates. |
| `prompt_id` | `string` | No | `/api/answer`, `/stream` | Ignored when `api_key` already defines the prompt. |
| `active_docs` | `string` or `string[]` | No | `/api/answer`, `/stream` | Overrides active docs when not using key-owned source config. |
| `retriever` | `string` | No | `/api/answer`, `/stream` | Retriever type (for example `classic`). |
| `chunks` | `number` | No | `/api/answer`, `/stream` | Retrieval chunk count, default `2`. |
| `isNoneDoc` | `boolean` | No | `/api/answer`, `/stream` | Skip document retrieval. |
| `agent_id` | `string` | No | `/api/answer`, `/stream` | Alternative to `api_key` when using authenticated user context. |
Streaming-only fields:
| Field | Type | Required | Notes |
| --- | --- | --- | --- |
| `attachments` | `string[]` | No | List of attachment IDs from `/api/task_status` success result. |
| `index` | `number` | No | Update an existing query index. If provided, `conversation_id` is required. |
## Non-Streaming API (`/api/answer`)
`/api/answer` waits for completion and returns one JSON response.
<Callout type="info">
`attachments` are currently handled through `/stream`. For file/image-attached queries, use the streaming endpoint.
</Callout>
Response fields:
- `conversation_id`
- `answer`
- `sources`
- `tool_calls`
- `thought`
- Optional structured output metadata (`structured`, `schema`) when enabled
### Examples
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/api/answer \
-H "Content-Type: application/json" \
-d '{"question":"your question here","api_key":"your_agent_api_key"}'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
API_URL = "http://localhost:7091/api/answer"
API_KEY = "your_agent_api_key"
QUESTION = "your question here"
response = requests.post(
API_URL,
json={"question": QUESTION, "api_key": API_KEY}
)
if response.status_code == 200:
print(response.json())
else:
print(f"Error: {response.status_code}")
print(response.text)
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/api/answer';
const apiKey = 'your_agent_api_key';
const question = 'your question here';
async function getAnswer() {
try {
const response = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ question, api_key: apiKey }),
});
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data = await response.json();
console.log(data);
} catch (error) {
console.error("Failed to fetch answer:", error);
}
}
getAnswer();
```
</Tabs.Tab>
</Tabs>
---
## Streaming API (`/stream`)
`/stream` returns a Server-Sent Events (SSE) stream so you can render output token-by-token.
### SSE Event Types
Each `data:` frame is JSON with `type`:
- `answer`: incremental answer chunk
- `source`: source list/chunks
- `tool_calls`: tool invocation results/metadata
- `thought`: reasoning/thought chunk (agent dependent)
- `structured_answer`: final structured payload (when schema mode is active)
- `id`: final conversation ID
- `error`: error message
- `end`: stream is complete
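On the wire, each frame is a `data:` line carrying one JSON object with a `type` field. Payload keys other than `type` vary by event, so the sketch below only shows the general shape:
```text
data: {"type": "answer", ...}

data: {"type": "source", ...}

data: {"type": "id", ...}

data: {"type": "end", ...}
```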
### Examples
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/stream \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{"question":"your question here","api_key":"your_agent_api_key"}'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
import json
API_URL = "http://localhost:7091/stream"
payload = {
"question": "your question here",
"api_key": "your_agent_api_key"
}
with requests.post(API_URL, json=payload, stream=True) as r:
for line in r.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data: '):
try:
data = json.loads(decoded_line[6:])
print(data)
except json.JSONDecodeError:
pass
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/stream';
const apiKey = 'your_agent_api_key';
const question = 'your question here';
async function getStream() {
try {
const response = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'text/event-stream'
},
body: JSON.stringify({ question, api_key: apiKey }),
});
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
// Note: This parsing method assumes each chunk contains whole lines.
// For a more robust production implementation, buffer the chunks
// and process them line by line.
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.substring(6));
console.log(data);
} catch (e) {
console.error("Failed to parse JSON from SSE event:", e);
}
}
}
}
} catch (error) {
console.error("Failed to fetch stream:", error);
}
}
getStream();
```
</Tabs.Tab>
</Tabs>
---
## Attachments API (Including Images)
To attach an image (or other file) to a query:
1. Upload file(s) to `/api/store_attachment` (multipart/form-data).
2. Poll `/api/task_status` until `status=SUCCESS`.
3. Read `result.attachment_id` from the task result.
4. Send that ID in `/stream` as `attachments: ["..."]`.
<Callout type="warning">
Attachments are processed asynchronously. Do not call `/stream` with an attachment until its task has finished with `SUCCESS`.
</Callout>
### Step 1: Upload Attachment
`POST /api/store_attachment`
- Content type: `multipart/form-data`
- Form fields:
- `file` (required, can be repeated for multi-file upload)
- `api_key` (optional if JWT is present; useful for API-key-only flows)
Example upload (single image):
```bash
curl -X POST http://localhost:7091/api/store_attachment \
-F "file=@/absolute/path/to/image.png" \
-F "api_key=your_agent_api_key"
```
Possible response (single-file upload):
```json
{
"success": true,
"task_id": "34f1cb56-7c7f-4d5f-a973-4ea7e65f7a10",
"message": "File uploaded successfully. Processing started."
}
```
### Step 2: Poll Task Status
```bash
curl "http://localhost:7091/api/task_status?task_id=34f1cb56-7c7f-4d5f-a973-4ea7e65f7a10"
```
When complete:
```json
{
"status": "SUCCESS",
"result": {
"attachment_id": "67b4f8f2618dc9f19384a9e1",
"filename": "image.png",
"mime_type": "image/png"
}
}
```
### Step 3: Attach to `/stream` Request
Use the `attachment_id` in `attachments`.
```bash
curl -X POST http://localhost:7091/stream \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{
"question": "Describe this image",
"api_key": "your_agent_api_key",
"attachments": ["67b4f8f2618dc9f19384a9e1"]
}'
```
### Image/Attachment Behavior Notes
- Typical image MIME types supported for native vision flows: `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif`.
- If the selected model/provider does not support a file type natively, DocsGPT falls back to parsed text content.
- For providers that support images but not native PDF file attachments, DocsGPT can convert PDF pages to images (synthetic PDF support).
- Attachments are user-scoped. Upload and query must be done under the same user context (same API key owner or same JWT user).
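Putting the three steps together, a minimal Python sketch (using the same endpoints and fields as above) might look like this. The polling loop assumes the task reports standard Celery-style states; the `FAILURE` branch in particular is an assumption rather than documented behavior:
```python
import json
import time

import requests

BASE_URL = "http://localhost:7091"
API_KEY = "your_agent_api_key"

# Step 1: upload the image
with open("image.png", "rb") as f:
    upload = requests.post(
        f"{BASE_URL}/api/store_attachment",
        files={"file": f},
        data={"api_key": API_KEY},
    )
upload.raise_for_status()
task_id = upload.json()["task_id"]

# Step 2: poll until the attachment has been processed
while True:
    status = requests.get(
        f"{BASE_URL}/api/task_status", params={"task_id": task_id}
    ).json()
    if status.get("status") == "SUCCESS":
        attachment_id = status["result"]["attachment_id"]
        break
    if status.get("status") == "FAILURE":  # assumed Celery failure state
        raise RuntimeError(f"Attachment processing failed: {status}")
    time.sleep(1)

# Step 3: ask about the image over the streaming endpoint
payload = {
    "question": "Describe this image",
    "api_key": API_KEY,
    "attachments": [attachment_id],
}
with requests.post(f"{BASE_URL}/stream", json=payload, stream=True) as r:
    for line in r.iter_lines():
        if line:
            decoded = line.decode("utf-8")
            if decoded.startswith("data: "):
                try:
                    print(json.loads(decoded[6:]))
                except json.JSONDecodeError:
                    pass
```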

View File

@@ -0,0 +1,65 @@
# Workflow Nodes
DocsGPT workflows are composed of **Nodes** that are connected to form a processing graph. These nodes interact with a **Shared State**—a global dictionary of variables that persists throughout the execution of the workflow.
## The Shared State
Every workflow run maintains a state object (a JSON-like dictionary).
- **Initial State**: Contains the user's input query (`{{query}}`) and chat history (`{{chat_history}}`).
- **Accessing Variables**: You can access any variable in the state using the double-curly braces syntax: `{{variable_name}}`.
- **Modifying State**: Nodes read from this state and write their outputs back to it.
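For instance, after an AI Agent node whose output variable is set to `summary` has run, the state might look roughly like this (keys and values are illustrative):
```json
{
  "query": "Summarize the release notes",
  "chat_history": [],
  "summary": "The release adds condition nodes, ..."
}
```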
---
## AI Agent Node
The **AI Agent Node** is the core processing unit. It uses a Large Language Model (LLM) to generate text, answer questions, or perform tasks using tools.
### Inputs (Template Variables)
The primary input is the **Prompt Template**. This field supports variable substitution.
- **Prompt Template**: The text sent to the model.
- *Example*: `"Summarize the following text: {{user_input_text}}"`
- If left empty, it defaults to the initial user query (`{{query}}`).
- **System Prompt**: Instructions that define the agent's persona and constraints.
- **Tools**: A list of tools the agent can use (e.g., search, calculator).
- **LLM Settings**: Specific provider, model name, and parameters.
### Outputs (Emissions)
When the agent completes its task, it stores the result in the shared state.
- **Output Variable**: The name of the variable where the result will be saved.
- *Default*: If not specified, it is saved as `node_{node_id}_output`.
- *Custom*: You can set this to something meaningful, like `summary` or `translated_text`.
- **Streaming**: If "Stream to user" is enabled, the output is sent to the user in real-time as it is generated, in addition to being saved to the state.
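A downstream node can then reference that variable in its own prompt template. For example, assuming the output variable above was named `summary`:
```text
Translate the following summary into French:

{{summary}}
```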
---
## Set State Node
The **Set State Node** allows you to manipulate variables within the shared state directly without calling an LLM. This is useful for initialization, formatting, or control flow logic.
### Operations
You can define multiple operations in a single node. Each operation targets a specific **Key** (variable name).
1. **Set**: Assigns a specific value to a variable.
- *Value*: Can be a static string or a template using variables.
- *Example*: Set `current_step` to `1`.
- *Example*: Set `formatted_response` to `Analysis: {{analysis_result}}`.
2. **Increment**: Increases the value of a numeric variable.
- *Value*: The amount to add (default is 1).
- *Example*: Increment `retry_count` by `1`.
3. **Append**: Adds a value to a list variable.
- *Value*: The item to add to the list.
- *Example*: Append `{{last_result}}` to `history_list`.
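As a concrete illustration, suppose the state already contains `last_result`, `retry_count`, and `history_list`. A single Set State node that increments `retry_count` by `1` and appends `{{last_result}}` to `history_list` would transform the state roughly as follows (values are illustrative):
```json
{
  "last_result": "draft v2",
  "retry_count": 1,
  "history_list": ["draft v1"]
}
```
becomes
```json
{
  "last_result": "draft v2",
  "retry_count": 2,
  "history_list": ["draft v1", "draft v2"]
}
```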
### Usage Examples
- **Loop Counters**: Use a *Set State* node to initialize a counter (`i = 0`) before a loop, and another to increment it inside the loop.
- **Accumulators**: Use *Append* to collect results from multiple parallel branches into a single list.
- **Renaming**: Copy the output of a previous node to a more generic name (e.g., set `context` to `{{search_results}}`) so subsequent nodes can use a standard variable name.

View File

@@ -1,4 +1,4 @@
{
export default {
"DocsGPT-Settings": {
"title": "⚙️ App Configuration",
"href": "/Deploying/DocsGPT-Settings"
@@ -29,4 +29,4 @@
"href": "/Deploying/Railway",
"display": "hidden"
}
}
}

View File

@@ -1,4 +1,4 @@
{
export default {
"api-key-guide": {
"title": "🔑 Getting API key",
"href": "/Extensions/api-key-guide"
@@ -19,4 +19,4 @@
"title": "🗣️ Chatwoot Extension",
"href": "/Extensions/Chatwoot-extension"
}
}
}

View File

@@ -1,4 +1,4 @@
{
export default {
"google-drive-connector": {
"title": "🔗 Google Drive",
"href": "/Guides/Integrations/google-drive-connector"

View File

@@ -1,4 +1,4 @@
{
export default {
"Customising-prompts": {
"title": "️💻 Customising Prompts",
"href": "/Guides/Customising-prompts"

View File

@@ -1,4 +1,4 @@
{
export default {
"cloud-providers": {
"title": "☁️ Cloud Providers",
"href": "/Models/cloud-providers"
@@ -11,4 +11,4 @@
"title": "📝 Embeddings",
"href": "/Models/embeddings"
}
}
}

View File

@@ -1,4 +1,4 @@
{
export default {
"basics": {
"title": "🔧 Tools Basics",
"href": "/Tools/basics"
@@ -11,4 +11,4 @@
"title": "🛠️ Creating a Custom Tool",
"href": "/Tools/creating-a-tool"
}
}
}

View File

@@ -1,5 +1,4 @@
{
export default {
"index": "Home",
"quickstart": "Quickstart",
"Deploying": "Deploying",
@@ -9,12 +8,11 @@
"Extensions": "Extensions",
"https://gptcloud.arc53.com/": {
"title": "API",
"href": "https://gptcloud.arc53.com/",
"newWindow": true
"href": "https://gptcloud.arc53.com/"
},
"Guides": "Guides",
"changelog": {
"title": "Changelog",
"display": "hidden"
}
}
}

View File

@@ -2,7 +2,7 @@
title: 'Home'
description: Documentation of DocsGPT - quickstart, deployment guides, model configuration, and widget integration documentation.
---
import { Cards, Card } from 'nextra/components'
import { Cards } from 'nextra/components'
import Image from 'next/image'
export const allGuides = {
@@ -85,7 +85,7 @@ Try it yourself: [https://www.docsgpt.cloud/](https://www.docsgpt.cloud/)
<Cards
num={3}
children={Object.keys(allGuides).map((key, i) => (
<Card
<Cards.Card
key={i}
title={allGuides[key].title}
href={allGuides[key].href}

8
docs/mdx-components.jsx Normal file
View File

@@ -0,0 +1,8 @@
import { useMDXComponents as getThemeComponents } from 'nextra-theme-docs';
export function useMDXComponents(components) {
return {
...getThemeComponents(),
...components,
};
}

View File

@@ -1,9 +1,9 @@
const withNextra = require('nextra')({
theme: 'nextra-theme-docs',
themeConfig: './theme.config.jsx'
})
const nextra = require('nextra').default;
module.exports = withNextra()
// If you have other Next.js configurations, you can pass them as the parameter:
// module.exports = withNextra({ /* other next.js config */ })
const withNextra = nextra({
defaultShowCopyCode: true,
});
module.exports = withNextra({
reactStrictMode: true,
});

5580
docs/package-lock.json generated

File diff suppressed because it is too large

View File

@@ -2,6 +2,7 @@
"scripts": {
"dev": "next dev",
"build": "next build",
"postbuild": "pagefind --site .next/server/app --output-path public/_pagefind",
"start": "next start"
},
"license": "MIT",
@@ -9,9 +10,13 @@
"@vercel/analytics": "^1.1.1",
"docsgpt-react": "^0.5.1",
"next": "^15.5.9",
"nextra": "^2.13.2",
"nextra-theme-docs": "^2.13.2",
"nextra": "^4.6.1",
"nextra-theme-docs": "^4.6.1",
"react": "^18.2.0",
"react-dom": "^18.2.0"
},
"devDependencies": {
"pagefind": "^1.3.0",
"typescript": "^5.9.3"
}
}

View File

@@ -1,227 +0,0 @@
---
title: Interacting with Agents via API
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
---
import { Callout, Tabs } from 'nextra/components';
# Interacting with Agents via API
DocsGPT Agents can be accessed programmatically through a dedicated API, allowing you to integrate their specialized capabilities into your own applications, scripts, and workflows. This guide covers the two primary methods for interacting with an agent: the streaming API for real-time responses and the non-streaming API for a single, consolidated answer.
When you use an API key generated for a specific agent, you do not need to pass `prompt`, `tools` etc. The agent's configuration (including its prompt, selected tools, and knowledge sources) is already associated with its unique API key.
### API Endpoints
- **Non-Streaming:** `http://localhost:7091/api/answer`
- **Streaming:** `http://localhost:7091/stream`
<Callout type="info">
For DocsGPT Cloud, use `https://gptcloud.arc53.com/` as the base URL.
</Callout>
For more technical details, you can explore the API swagger documentation available for the cloud version or your local instance.
---
## Non-Streaming API (`/api/answer`)
This is a standard synchronous endpoint. It waits for the agent to fully process the request and returns a single JSON object with the complete answer. This is the simplest method and is ideal for backend processes where a real-time feed is not required.
### Request
- **Endpoint:** `/api/answer`
- **Method:** `POST`
- **Payload:**
- `question` (string, required): The user's query or input for the agent.
- `api_key` (string, required): The unique API key for the agent you wish to interact with.
- `history` (string, optional): A JSON string representing the conversation history, e.g., `[{\"prompt\": \"first question\", \"answer\": \"first answer\"}]`.
### Response
A single JSON object containing:
- `answer`: The complete, final answer from the agent.
- `sources`: A list of sources the agent consulted.
- `conversation_id`: The unique ID for the interaction.
### Examples
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/api/answer \
-H "Content-Type: application/json" \
-d '{
"question": "your question here",
"api_key": "your_agent_api_key"
}'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
API_URL = "http://localhost:7091/api/answer"
API_KEY = "your_agent_api_key"
QUESTION = "your question here"
response = requests.post(
API_URL,
json={"question": QUESTION, "api_key": API_KEY}
)
if response.status_code == 200:
print(response.json())
else:
print(f"Error: {response.status_code}")
print(response.text)
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/api/answer';
const apiKey = 'your_agent_api_key';
const question = 'your question here';
async function getAnswer() {
try {
const response = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ question, api_key: apiKey }),
});
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const data = await response.json();
console.log(data);
} catch (error) {
console.error("Failed to fetch answer:", error);
}
}
getAnswer();
```
</Tabs.Tab>
</Tabs>
---
## Streaming API (`/stream`)
The `/stream` endpoint uses Server-Sent Events (SSE) to push data in real-time. This is ideal for applications where you want to display the response as it's being generated, such as in a live chatbot interface.
### Request
- **Endpoint:** `/stream`
- **Method:** `POST`
- **Payload:** Same as the non-streaming API.
### Response (SSE Stream)
The stream consists of multiple `data:` events, each containing a JSON object. Your client should listen for these events and process them based on their `type`.
**Event Types:**
- `answer`: A chunk of the agent's final answer.
- `source`: A document or source used by the agent.
- `thought`: A reasoning step from the agent (for ReAct agents).
- `id`: The unique `conversation_id` for the interaction.
- `error`: An error message.
- `end`: A final message indicating the stream has concluded.
### Examples
<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/stream \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{
"question": "your question here",
"api_key": "your_agent_api_key"
}'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
import json
API_URL = "http://localhost:7091/stream"
payload = {
"question": "your question here",
"api_key": "your_agent_api_key"
}
with requests.post(API_URL, json=payload, stream=True) as r:
for line in r.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data: '):
try:
data = json.loads(decoded_line[6:])
print(data)
except json.JSONDecodeError:
pass
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/stream';
const apiKey = 'your_agent_api_key';
const question = 'your question here';
async function getStream() {
try {
const response = await fetch(apiUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'text/event-stream'
},
// Corrected line: 'apiKey' is changed to 'api_key'
body: JSON.stringify({ question, api_key: apiKey }),
});
if (!response.ok) {
throw new Error(`HTTP error! Status: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
// Note: This parsing method assumes each chunk contains whole lines.
// For a more robust production implementation, buffer the chunks
// and process them line by line.
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith('data: ')) {
try {
const data = JSON.parse(line.substring(6));
console.log(data);
} catch (e) {
console.error("Failed to parse JSON from SSE event:", e);
}
}
}
}
} catch (error) {
console.error("Failed to fetch stream:", error);
}
}
getStream();
```
</Tabs.Tab>
</Tabs>

View File

@@ -1,10 +0,0 @@
import { DocsGPTWidget } from "docsgpt-react";
export default function MyApp({ Component, pageProps }) {
return (
<>
<Component {...pageProps} />
<DocsGPTWidget showSources={true} apiKey="5d8270cb-735f-484e-9dc9-5b407f24e652" theme="dark" size="medium" />
</>
)
}

63
docs/public/llms.txt Normal file
View File

@@ -0,0 +1,63 @@
# DocsGPT
> DocsGPT is an open-source platform for building AI agents and assistants with document retrieval, tools, and multi-model support.
This file is a curated map of DocsGPT documentation for LLM and agent use.
Prioritize Core, Deploying, and Agents for implementation tasks.
## Core
- [Docs Home](https://docs.docsgpt.cloud/): Main documentation landing page.
- [Quickstart](https://docs.docsgpt.cloud/quickstart): Fastest path to run DocsGPT locally.
- [Architecture](https://docs.docsgpt.cloud/Guides/Architecture): High-level system architecture.
- [Development Environment](https://docs.docsgpt.cloud/Deploying/Development-Environment): Backend and frontend local setup.
- [DocsGPT Settings](https://docs.docsgpt.cloud/Deploying/DocsGPT-Settings): Environment variables and core app configuration.
## Deploying
- [Docker Deployment](https://docs.docsgpt.cloud/Deploying/Docker-Deploying): Run DocsGPT with Docker and Docker Compose.
- [Kubernetes Deployment](https://docs.docsgpt.cloud/Deploying/Kubernetes-Deploying): Deploy DocsGPT on Kubernetes clusters.
- [Hosting DocsGPT](https://docs.docsgpt.cloud/Deploying/Hosting-the-app): Hosting overview with cloud options.
## Agents
- [Agent Basics](https://docs.docsgpt.cloud/Agents/basics): Core concepts for building and managing agents.
- [Workflow Nodes](https://docs.docsgpt.cloud/Agents/nodes): Node types and behavior in agent workflows.
- [Agent API](https://docs.docsgpt.cloud/Agents/api): Programmatic agent interaction (streaming and non-streaming).
- [Agent Webhooks](https://docs.docsgpt.cloud/Agents/webhooks): Trigger and automate agents with webhooks.
## Tools
- [Tools Basics](https://docs.docsgpt.cloud/Tools/basics): How tools extend agent capabilities.
- [Generic API Tool](https://docs.docsgpt.cloud/Tools/api-tool): Configure API calls without custom code.
- [Creating a Custom Tool](https://docs.docsgpt.cloud/Tools/creating-a-tool): Build custom Python tools for DocsGPT.
## Models
- [Cloud LLM Providers](https://docs.docsgpt.cloud/Models/cloud-providers): Configure hosted model providers.
- [Local Inference](https://docs.docsgpt.cloud/Models/local-inference): Connect DocsGPT to local inference backends.
- [Embeddings](https://docs.docsgpt.cloud/Models/embeddings): Select and configure embedding models.
## Extensions
- [API Keys for Integrations](https://docs.docsgpt.cloud/Extensions/api-key-guide): Generate and use DocsGPT API keys.
- [Chat Widget](https://docs.docsgpt.cloud/Extensions/chat-widget): Embed the DocsGPT chat widget.
- [Search Widget](https://docs.docsgpt.cloud/Extensions/search-widget): Embed the DocsGPT search widget.
- [Chrome Extension](https://docs.docsgpt.cloud/Extensions/Chrome-extension): Install and use the browser extension.
- [Chatwoot Extension](https://docs.docsgpt.cloud/Extensions/Chatwoot-extension): Integrate DocsGPT with Chatwoot.
## Integrations
- [Google Drive Connector](https://docs.docsgpt.cloud/Guides/Integrations/google-drive-connector): Ingest and sync files from Google Drive.
## Optional
- [Customizing Prompts](https://docs.docsgpt.cloud/Guides/Customising-prompts): Template-based prompt customization.
- [How to Train on Other Documentation](https://docs.docsgpt.cloud/Guides/How-to-train-on-other-documentation): Add additional documentation sources.
- [Context Compression](https://docs.docsgpt.cloud/Guides/compression): Reduce context while preserving key information.
- [OCR for Sources and Attachments](https://docs.docsgpt.cloud/Guides/ocr): OCR behavior for ingestion and chat uploads.
- [How to Use Different LLMs](https://docs.docsgpt.cloud/Guides/How-to-use-different-LLM): Additional model-selection guidance.
- [Avoiding Hallucinations](https://docs.docsgpt.cloud/Guides/My-AI-answers-questions-using-external-knowledge): Improve answer grounding with external knowledge.
- [Amazon Lightsail Deployment](https://docs.docsgpt.cloud/Deploying/Amazon-Lightsail): Deploy DocsGPT on AWS Lightsail.
- [Railway Deployment](https://docs.docsgpt.cloud/Deploying/Railway): Deploy DocsGPT on Railway.
- [Changelog](https://docs.docsgpt.cloud/changelog): Project release history.

View File

@@ -1,161 +1,20 @@
import Image from 'next/image'
import { Analytics } from '@vercel/analytics/react';
const github = 'https://github.com/arc53/DocsGPT';
import { useConfig, useTheme } from 'nextra-theme-docs';
import CuteLogo from './public/cute-docsgpt.png';
const Logo = ({ height, width }) => {
const { theme } = useTheme();
return (
<div style={{ alignItems: 'center', display: 'flex', gap: '8px' }}>
<Image src={CuteLogo} alt="DocsGPT logo" width={width} height={height} />
<span style={{ fontWeight: 'bold', fontSize: 18 }}>DocsGPT Docs</span>
</div>
);
};
const isDevelopment = process.env.NODE_ENV === 'development';
const config = {
docsRepositoryBase: `${github}/blob/main/docs`,
chat: {
link: 'https://discord.com/invite/n5BX8dh8rU',
darkMode: true,
search: isDevelopment ? null : undefined,
nextThemes: {
defaultTheme: 'dark',
},
banner: {
key: 'docs-launch',
text: (
<div className="flex justify-center items-center gap-2">
Welcome to the new DocsGPT 🦖 docs! 👋
</div>
),
sidebar: {
defaultMenuCollapseLevel: 1,
},
toc: {
float: true,
},
project: {
link: github,
},
darkMode: true,
nextThemes: {
defaultTheme: 'dark',
},
primaryHue: {
dark: 207,
light: 212,
},
footer: {
text: (
<div>
<span>MIT {new Date().getFullYear()} © </span>
<a href="https://www.docsgpt.cloud/" target="_blank">
DocsGPT
</a>
{' | '}
<a href="https://github.com/arc53/DocsGPT" target="_blank">
GitHub
</a>
{' | '}
<a href="https://blog.docsgpt.cloud/" target="_blank">
Blog
</a>
</div>
),
},
editLink: {
content: 'Edit this page on GitHub',
},
logo() {
return (
<div className="flex items-center gap-2">
<Logo width={28} height={28} />
</div>
);
},
useNextSeoProps() {
return {
titleTemplate: `%s - DocsGPT Documentation`,
};
},
head() {
const { frontMatter } = useConfig();
const { theme } = useTheme();
const title = frontMatter?.title || 'Chat with your data with DocsGPT';
const description =
frontMatter?.description ||
'Use DocsGPT to chat with your data. DocsGPT is a GPT powered chatbot that can answer questions about your data.'
const image = '/cute-docsgpt.png';
const composedTitle = `${title} DocsGPT Documentation`;
return (
<>
<link
rel="apple-touch-icon"
sizes="180x180"
href={`/favicons/apple-touch-icon.png`}
/>
<link
rel="icon"
type="image/png"
sizes="32x32"
href={`/favicons/favicon-32x32.png`}
/>
<link
rel="icon"
type="image/png"
sizes="16x16"
href={`/favicons/favicon-16x16.png`}
/>
<meta name="theme-color" content="#ffffff" />
<meta name="msapplication-TileColor" content="#00a300" />
<link rel="manifest" href={`/favicons/site.webmanifest`} />
<meta httpEquiv="Content-Language" content="en" />
<meta name="title" content={composedTitle} />
<meta name="description" content={description} />
<meta name="twitter:card" content="summary_large_image" />
<meta name="twitter:site" content="@ATushynski" />
<meta name="twitter:image" content={image} />
<meta property="og:description" content={description} />
<meta property="og:title" content={composedTitle} />
<meta property="og:image" content={image} />
<meta property="og:type" content="website" />
<meta
name="apple-mobile-web-app-title"
content="DocsGPT Documentation"
/>
</>
);
},
sidebar: {
defaultMenuCollapseLevel: 1,
titleComponent: ({ title, type }) =>
type === 'separator' ? (
<div className="flex items-center gap-2">
<Logo height={10} width={10} />
{title}
<Analytics />
</div>
) : (
<>{title}
<Analytics />
</>
),
},
gitTimestamp: ({ timestamp }) => (
<>Last updated on {timestamp.toLocaleDateString()}</>
),
editLink: 'Edit this page on GitHub',
};
export default config;
export default config;

View File

@@ -3,4 +3,5 @@ VITE_BASE_URL=http://localhost:5173
VITE_API_HOST=http://127.0.0.1:7091
VITE_API_STREAMING=true
VITE_NOTIFICATION_TEXT="What's new in 0.15.0 — Changelog"
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-15-masters-long-term-memory-and-tooling/"
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-15-masters-long-term-memory-and-tooling/"
VITE_GOOGLE_CLIENT_ID=896376503572-u46l78n8ctgtdr4dlei4u06jv6rbpqc5.apps.googleusercontent.com

View File

@@ -439,10 +439,24 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
const data = await response.json();
const transformed = modelService.transformModels(data.models || []);
setAvailableModels(transformed);
if (mode === 'new' && transformed.length > 0) {
const preferredDefaultModelId =
transformed.find((model) => model.id === data.default_model_id)?.id ||
transformed[0].id;
if (preferredDefaultModelId) {
setSelectedModelIds((prevSelectedModelIds) =>
prevSelectedModelIds.size > 0
? prevSelectedModelIds
: new Set([preferredDefaultModelId]),
);
}
}
};
getTools();
getModels();
}, [token]);
}, [token, mode]);
// Validate folder_id from URL against user's folders
useEffect(() => {

View File

@@ -1,4 +1,20 @@
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state';
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state' | 'condition';
export interface ConditionCase {
name?: string;
expression: string;
sourceHandle: string;
}
export interface ConditionNodeConfig {
mode: 'simple' | 'advanced';
cases: ConditionCase[];
}
export interface StateOperationConfig {
expression: string;
target_variable: string;
}
export interface WorkflowEdge {
id: string;

File diff suppressed because it is too large

View File

@@ -4,6 +4,7 @@ import {
Circle,
Database,
Flag,
GitBranch,
Loader2,
MessageSquare,
Play,
@@ -53,6 +54,7 @@ const NODE_ICONS: Record<string, React.ReactNode> = {
end: <Flag className="h-3 w-3" />,
note: <StickyNote className="h-3 w-3" />,
state: <Database className="h-3 w-3" />,
condition: <GitBranch className="h-3 w-3" />,
};
const NODE_COLORS: Record<string, string> = {
@@ -61,6 +63,7 @@ const NODE_COLORS: Record<string, string> = {
end: 'text-gray-600 dark:text-gray-400',
note: 'text-yellow-600 dark:text-yellow-400',
state: 'text-blue-600 dark:text-blue-400',
condition: 'text-orange-600 dark:text-orange-400',
};
function ExecutionDetails({
@@ -84,7 +87,9 @@ function ExecutionDetails({
const formatValue = (value: unknown): string => {
if (typeof value === 'string') return value;
return JSON.stringify(value, null, 2);
if (value === undefined) return '';
const formatted = JSON.stringify(value, null, 2);
return formatted ?? String(value);
};
return (
@@ -133,6 +138,11 @@ function ExecutionDetails({
if (text.length <= maxLength) return text;
return text.slice(0, maxLength) + '...';
};
const hasOutput =
step.output !== undefined &&
step.output !== null &&
formatValue(step.output) !== '';
const formattedOutput = hasOutput ? formatValue(step.output) : '';
return (
<div
@@ -168,15 +178,15 @@ function ExecutionDetails({
)}
</div>
</div>
{(step.output || step.error || stateVars.length > 0) && (
{(hasOutput || step.error || stateVars.length > 0) && (
<div className="mt-3 space-y-2 text-sm">
{step.output && (
{hasOutput && (
<div className="rounded-lg bg-white p-2 dark:bg-[#2A2A2A]">
<span className="font-medium text-gray-600 dark:text-gray-400">
Output:{' '}
</span>
<span className="wrap-break-word whitespace-pre-wrap text-gray-900 dark:text-gray-100">
{truncateText(step.output, 300)}
{truncateText(formattedOutput, 300)}
</span>
</div>
)}
@@ -251,7 +261,7 @@ function WorkflowMiniMap({
return step?.status || 'pending';
};
const getStatusColor = (nodeId: string, nodeType: string) => {
const getStatusColor = (nodeId: string) => {
const status = getNodeStatus(nodeId);
const isActive = nodeId === activeNodeId;
@@ -267,26 +277,33 @@ function WorkflowMiniMap({
case 'failed':
return 'bg-red-100 dark:bg-red-900/30 border-red-300 dark:border-red-700';
default:
if (nodeType === 'start') {
return 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800';
}
if (nodeType === 'agent') {
return 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800';
}
if (nodeType === 'end') {
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
}
return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
}
};
const sortedNodes = [...nodes].sort((a, b) => {
if (a.type === 'start') return -1;
if (b.type === 'start') return 1;
if (a.type === 'end') return 1;
if (b.type === 'end') return -1;
return (a.position?.y || 0) - (b.position?.y || 0);
});
const executedOrder = new Map(executionSteps.map((s, i) => [s.nodeId, i]));
const startNode = nodes.find((node) => node.type === 'start');
const visibleNodeIds = new Set(executionSteps.map((step) => step.nodeId));
if (activeNodeId) {
visibleNodeIds.add(activeNodeId);
}
if (startNode) {
visibleNodeIds.add(startNode.id);
}
const sortedNodes = nodes
.filter((node) => visibleNodeIds.has(node.id))
.sort((a, b) => {
if (a.type === 'start') return -1;
if (b.type === 'start') return 1;
const aIdx = executedOrder.get(a.id);
const bIdx = executedOrder.get(b.id);
if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
if (aIdx !== undefined) return -1;
if (bIdx !== undefined) return 1;
return (a.position?.y || 0) - (b.position?.y || 0);
});
const hasStepData = (nodeId: string) => {
const step = executionSteps.find((s) => s.nodeId === nodeId);
@@ -306,7 +323,7 @@ function WorkflowMiniMap({
disabled={!hasStepData(node.id)}
className={cn(
'flex h-12 w-full items-center gap-2 rounded-lg border px-3 text-xs transition-all',
getStatusColor(node.id, node.type),
getStatusColor(node.id),
hasStepData(node.id) && 'cursor-pointer hover:opacity-80',
)}
>
@@ -533,6 +550,10 @@ export default function WorkflowPreview({
const querySteps = query.executionSteps || [];
const hasResponse = !!(query.response || query.error);
const isLastQuery = index === queries.length - 1;
const isStreamingLastQuery =
status === 'loading' && isLastQuery;
const shouldShowThought =
!isStreamingLastQuery && Boolean(query.thought);
const isOpen =
openDetailsIndex === index ||
(!hasResponse && isLastQuery && querySteps.length > 0);
@@ -567,17 +588,19 @@ export default function WorkflowPreview({
{/* Response bubble */}
{(query.response ||
query.thought ||
shouldShowThought ||
query.tool_calls) && (
<ConversationBubble
className={isLastQuery ? 'mb-32' : 'mb-7'}
message={query.response}
type="ANSWER"
thought={query.thought}
thought={
shouldShowThought ? query.thought : undefined
}
sources={query.sources}
toolCalls={query.tool_calls}
feedback={query.feedback}
isStreaming={status === 'loading' && isLastQuery}
isStreaming={isStreamingLastQuery}
/>
)}

View File

@@ -9,10 +9,71 @@ import {
} from '@/components/ui/popover';
interface WorkflowVariable {
name: string;
label: string;
templatePath: string;
section: string;
}
const GLOBAL_CONTEXT_VARIABLES: WorkflowVariable[] = [
{
label: 'source.content',
templatePath: 'source.content',
section: 'Global context',
},
{
label: 'source.summaries',
templatePath: 'source.summaries',
section: 'Global context',
},
{
label: 'source.documents',
templatePath: 'source.documents',
section: 'Global context',
},
{
label: 'source.count',
templatePath: 'source.count',
section: 'Global context',
},
{
label: 'system.date',
templatePath: 'system.date',
section: 'Global context',
},
{
label: 'system.time',
templatePath: 'system.time',
section: 'Global context',
},
{
label: 'system.timestamp',
templatePath: 'system.timestamp',
section: 'Global context',
},
{
label: 'system.request_id',
templatePath: 'system.request_id',
section: 'Global context',
},
{
label: 'system.user_id',
templatePath: 'system.user_id',
section: 'Global context',
},
];
function toAgentTemplatePath(variableName: string): string {
const trimmed = variableName.trim();
if (!trimmed) return 'agent';
if (/^[A-Za-z_][A-Za-z0-9_]*$/.test(trimmed)) {
return `agent.${trimmed}`;
}
const escaped = trimmed.replace(/\\/g, '\\\\').replace(/'/g, "\\'");
return `agent['${escaped}']`;
}
function getUpstreamNodeIds(nodeId: string, edges: Edge[]): Set<string> {
const upstream = new Set<string>();
const queue = [nodeId];
@@ -36,32 +97,69 @@ function extractUpstreamVariables(
selectedNodeId: string,
): WorkflowVariable[] {
const variables: WorkflowVariable[] = [
{ name: 'query', section: 'Workflow input' },
{ name: 'chat_history', section: 'Workflow input' },
{
label: 'agent.query',
templatePath: 'agent.query',
section: 'Workflow input',
},
{
label: 'agent.chat_history',
templatePath: 'agent.chat_history',
section: 'Workflow input',
},
...GLOBAL_CONTEXT_VARIABLES,
];
const seen = new Set(['query', 'chat_history']);
const seen = new Set(variables.map((variable) => variable.templatePath));
const upstreamIds = getUpstreamNodeIds(selectedNodeId, edges);
for (const node of nodes) {
if (!upstreamIds.has(node.id)) continue;
if (node.type === 'agent' && node.data?.config?.output_variable) {
const name = node.data.config.output_variable;
if (!seen.has(name)) {
seen.add(name);
if (node.type === 'agent') {
const defaultOutputTemplatePath = toAgentTemplatePath(
`node_${node.id}_output`,
);
if (!seen.has(defaultOutputTemplatePath)) {
seen.add(defaultOutputTemplatePath);
variables.push({
name,
label: defaultOutputTemplatePath,
templatePath: defaultOutputTemplatePath,
section: node.data.title || node.data.label || 'Agent',
});
}
const outputVariable = String(
node.data?.config?.output_variable || '',
).trim();
if (outputVariable) {
const templatePath = toAgentTemplatePath(outputVariable);
if (!seen.has(templatePath)) {
seen.add(templatePath);
variables.push({
label: templatePath,
templatePath,
section: node.data.title || node.data.label || 'Agent',
});
}
}
}
if (node.type === 'state' && node.data?.variable) {
const name = node.data.variable;
if (!seen.has(name)) {
seen.add(name);
if (node.type === 'state') {
const operations = node.data?.config?.operations;
if (!Array.isArray(operations)) continue;
for (const operation of operations) {
const targetVariable = String(operation?.target_variable || '').trim();
if (!targetVariable) continue;
const templatePath = toAgentTemplatePath(targetVariable);
if (seen.has(templatePath)) continue;
seen.add(templatePath);
variables.push({
name,
section: 'Set State',
label: templatePath,
templatePath,
section: node.data.title || node.data.label || 'Set State',
});
}
}
@@ -106,14 +204,16 @@ function VariableListWithSearch({
onSelect,
}: {
variables: WorkflowVariable[];
onSelect: (name: string) => void;
onSelect: (templatePath: string) => void;
}) {
const [search, setSearch] = useState('');
const filtered = useMemo(
() =>
variables.filter((v) =>
v.name.toLowerCase().includes(search.toLowerCase()),
`${v.label} ${v.templatePath}`
.toLowerCase()
.includes(search.toLowerCase()),
),
[variables, search],
);
@@ -146,17 +246,17 @@ function VariableListWithSearch({
</div>
{vars.map((v) => (
<button
key={v.name}
key={`${section}-${v.templatePath}`}
onMouseDown={(e) => {
e.preventDefault();
e.stopPropagation();
onSelect(v.name);
onSelect(v.templatePath);
}}
className="flex w-full cursor-pointer items-center gap-2 px-3 py-1.5 text-left text-sm transition-colors hover:bg-gray-50 dark:hover:bg-[#383838]"
>
<Braces className="text-violets-are-blue h-3.5 w-3.5 shrink-0" />
<span className="truncate font-medium text-gray-800 dark:text-gray-200">
{v.name}
{v.label}
</span>
</button>
))}
@@ -206,7 +306,9 @@ export default function PromptTextArea({
const filtered = useMemo(
() =>
variables.filter((v) =>
v.name.toLowerCase().includes(filterText.toLowerCase()),
`${v.label} ${v.templatePath}`
.toLowerCase()
.includes(filterText.toLowerCase()),
),
[variables, filterText],
);
@@ -217,10 +319,12 @@ export default function PromptTextArea({
const cursorPos = textarea.selectionStart;
const textBeforeCursor = value.slice(0, cursorPos);
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
const triggerMatch = textBeforeCursor.match(
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
);
if (triggerMatch) {
setFilterText(triggerMatch[1]);
setFilterText(triggerMatch[1].trim());
setCursorInsertPos(cursorPos);
const wrapper = wrapperRef.current;
@@ -237,15 +341,17 @@ export default function PromptTextArea({
}, [value]);
const insertVariable = useCallback(
(varName: string) => {
(templatePath: string) => {
if (cursorInsertPos === null) return;
const textBeforeCursor = value.slice(0, cursorInsertPos);
const triggerMatch = textBeforeCursor.match(/\{\{(\w*)$/);
const triggerMatch = textBeforeCursor.match(
/\{\{\s*([A-Za-z0-9_.[\]'"]*)$/,
);
if (!triggerMatch) return;
const startPos = cursorInsertPos - triggerMatch[0].length;
const insertion = `{{${varName}}}`;
const insertion = `{{ ${templatePath} }}`;
const newValue =
value.slice(0, startPos) + insertion + value.slice(cursorInsertPos);
@@ -262,10 +368,10 @@ export default function PromptTextArea({
);
const insertVariableFromButton = useCallback(
(varName: string) => {
(templatePath: string) => {
const textarea = textareaRef.current;
const cursorPos = textarea?.selectionStart ?? value.length;
const insertion = `{{${varName}}}`;
const insertion = `{{ ${templatePath} }}`;
const newValue =
value.slice(0, cursorPos) + insertion + value.slice(cursorPos);

View File

@@ -5,7 +5,7 @@ interface BaseNodeProps {
title: string;
children?: ReactNode;
selected?: boolean;
type?: 'start' | 'end' | 'default' | 'state' | 'agent';
type?: 'start' | 'end' | 'default' | 'state' | 'agent' | 'condition';
icon?: ReactNode;
handles?: {
source?: boolean;
@@ -40,6 +40,9 @@ export const BaseNode: React.FC<BaseNodeProps> = ({
} else if (type === 'state') {
iconBg = 'bg-gray-100 dark:bg-gray-800';
iconColor = 'text-gray-600 dark:text-gray-400';
} else if (type === 'condition') {
iconBg = 'bg-orange-100 dark:bg-orange-900/30';
iconColor = 'text-orange-600 dark:text-orange-400';
}
return (

View File

@@ -0,0 +1,118 @@
import { GitBranch } from 'lucide-react';
import { memo } from 'react';
import { Handle, NodeProps, Position } from 'reactflow';
import { ConditionCase } from '../../types/workflow';
type ConditionNodeData = {
label?: string;
title?: string;
config?: {
mode?: 'simple' | 'advanced';
cases?: ConditionCase[];
};
};
const ROW_HEIGHT = 18;
const HEADER_HEIGHT = 52;
const PADDING_BOTTOM = 8;
function getNodeHeight(caseCount: number): number {
return (
HEADER_HEIGHT + Math.max(caseCount + 1, 2) * ROW_HEIGHT + PADDING_BOTTOM
);
}
function getHandleTop(index: number, total: number): string {
const offset = HEADER_HEIGHT;
return `${offset + ROW_HEIGHT * index + ROW_HEIGHT / 2}px`;
}
const ConditionNode = ({ data, selected }: NodeProps<ConditionNodeData>) => {
const title = data.title || data.label || 'If / Else';
const cases = data.config?.cases || [];
const totalOutputs = cases.length + 1;
const height = getNodeHeight(cases.length);
return (
<div
className={`relative rounded-2xl border bg-white shadow-md transition-all dark:bg-[#2C2C2C] ${
selected
? 'border-violets-are-blue dark:ring-violets-are-blue scale-105 ring-2 ring-purple-300'
: 'border-gray-200 hover:shadow-lg dark:border-[#3A3A3A]'
}`}
style={{ minWidth: 180, maxWidth: 220, height }}
>
<Handle
type="target"
position={Position.Left}
isConnectable
className="hover:bg-violets-are-blue! top-1/2! -left-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
/>
<div className="flex items-center gap-3 px-3 py-2">
<div className="flex h-9 w-9 shrink-0 items-center justify-center rounded-full bg-orange-100 text-orange-600 dark:bg-orange-900/30 dark:text-orange-400">
<GitBranch size={14} />
</div>
<div className="min-w-0 flex-1 pr-2">
<div
className="truncate text-sm font-semibold text-gray-900 dark:text-white"
title={title}
>
{title}
</div>
<div className="text-[10px] text-gray-500 uppercase">
{data.config?.mode || 'simple'}
</div>
</div>
</div>
<div className="flex flex-col px-3">
{cases.map((c, i) => (
<div
key={c.sourceHandle}
className="flex items-center gap-1"
style={{ height: ROW_HEIGHT }}
>
<span className="shrink-0 text-xs font-medium text-orange-600 dark:text-orange-400">
{i === 0 ? 'If' : 'Else if'}
</span>
{c.name && (
<span
className="truncate text-xs text-gray-600 dark:text-gray-400"
title={c.name}
>
{c.name}
</span>
)}
</div>
))}
<div className="flex items-center gap-1" style={{ height: ROW_HEIGHT }}>
<span className="text-xs font-medium text-gray-500">Else</span>
</div>
</div>
{cases.map((c, i) => (
<Handle
key={c.sourceHandle}
type="source"
position={Position.Right}
id={c.sourceHandle}
isConnectable
style={{ top: getHandleTop(i, totalOutputs) }}
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-orange-400! transition-colors dark:border-[#2C2C2C]!"
/>
))}
<Handle
type="source"
position={Position.Right}
id="else"
isConnectable
style={{ top: getHandleTop(cases.length, totalOutputs) }}
className="hover:bg-violets-are-blue! -right-1! h-3! w-3! rounded-full! border-2! border-white! bg-gray-400! transition-colors dark:border-[#2C2C2C]!"
/>
</div>
);
};
export default memo(ConditionNode);

View File

@@ -2,6 +2,7 @@ import { Database } from 'lucide-react';
import { memo } from 'react';
import { NodeProps } from 'reactflow';
import { StateOperationConfig } from '../../types/workflow';
import { BaseNode } from './BaseNode';
type SetStateNodeData = {
@@ -9,10 +10,16 @@ type SetStateNodeData = {
title?: string;
variable?: string;
value?: string;
config?: {
operations?: StateOperationConfig[];
};
};
const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
const title = data.title || data.label || 'Set State';
const operations = data.config?.operations || [];
const hasLegacy = !operations.length && data.variable;
return (
<BaseNode
title={title}
@@ -22,22 +29,31 @@ const SetStateNode = ({ data, selected }: NodeProps<SetStateNodeData>) => {
handles={{ source: true, target: true }}
>
<div className="flex flex-col gap-1">
{data.variable && (
{operations.length > 0 ? (
<div
className="truncate text-[10px] text-gray-500 uppercase"
title={`Variable: ${data.variable}`}
className="truncate text-[10px] text-gray-500"
title={`${operations.length} operation(s)`}
>
{data.variable}
{operations.length} variable{operations.length !== 1 ? 's' : ''}
</div>
)}
{data.value && (
<div
className="truncate text-xs text-blue-600 dark:text-blue-400"
title={`Value: ${data.value}`}
>
{data.value}
</div>
)}
) : hasLegacy ? (
<>
<div
className="truncate text-[10px] text-gray-500 uppercase"
title={`Variable: ${data.variable}`}
>
{data.variable}
</div>
{data.value && (
<div
className="truncate text-xs text-blue-600 dark:text-blue-400"
title={`Value: ${data.value}`}
>
{data.value}
</div>
)}
</>
) : null}
</div>
</BaseNode>
);

View File

@@ -1,7 +1,9 @@
import React, { memo } from 'react';
import { Bot, Flag, Play, StickyNote } from 'lucide-react';
import { memo } from 'react';
import { BaseNode } from './BaseNode';
import ConditionNode from './ConditionNode';
import SetStateNode from './SetStateNode';
import { Play, Bot, StickyNote, Flag } from 'lucide-react';
export const StartNode = memo(function StartNode({
selected,
@@ -83,10 +85,10 @@ export const AgentNode = memo(function AgentNode({
)}
{config.output_variable && (
<div
className="truncate text-xs text-green-600 dark:text-green-400"
title={`Output ${config.output_variable}`}
className="truncate text-xs text-gray-500 dark:text-gray-400"
title={`Output: ${config.output_variable}`}
>
Output {config.output_variable}
Output: {config.output_variable}
</div>
)}
</div>
@@ -142,3 +144,4 @@ export const NoteNode = memo(function NoteNode({
});
export { SetStateNode };
export { ConditionNode };

View File

@@ -13,7 +13,7 @@ export interface WorkflowExecutionStep {
startedAt?: number;
completedAt?: number;
stateSnapshot?: Record<string, unknown>;
output?: string;
output?: unknown;
error?: string;
}
@@ -321,7 +321,9 @@ export const workflowPreviewSlice = createSlice({
}
const querySteps = state.queries[index].executionSteps!;
const existingIndex = querySteps.findIndex((s) => s.nodeId === step.nodeId);
const existingIndex = querySteps.findIndex(
(s) => s.nodeId === step.nodeId,
);
const updatedStep: WorkflowExecutionStep = {
nodeId: step.nodeId,
@@ -332,7 +334,10 @@ export const workflowPreviewSlice = createSlice({
stateSnapshot: step.stateSnapshot,
output: step.output,
error: step.error,
startedAt: existingIndex !== -1 ? querySteps[existingIndex].startedAt : Date.now(),
startedAt:
existingIndex !== -1
? querySteps[existingIndex].startedAt
: Date.now(),
completedAt:
step.status === 'completed' || step.status === 'failed'
? Date.now()
@@ -342,7 +347,8 @@ export const workflowPreviewSlice = createSlice({
};
if (existingIndex !== -1) {
updatedStep.stateSnapshot = step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
updatedStep.stateSnapshot =
step.stateSnapshot ?? querySteps[existingIndex].stateSnapshot;
updatedStep.output = step.output ?? querySteps[existingIndex].output;
updatedStep.error = step.error ?? querySteps[existingIndex].error;
querySteps[existingIndex] = updatedStep;
@@ -350,7 +356,9 @@ export const workflowPreviewSlice = createSlice({
querySteps.push(updatedStep);
}
const globalIndex = state.executionSteps.findIndex((s) => s.nodeId === step.nodeId);
const globalIndex = state.executionSteps.findIndex(
(s) => s.nodeId === step.nodeId,
);
if (globalIndex !== -1) {
state.executionSteps[globalIndex] = updatedStep;
} else {

View File

@@ -43,7 +43,7 @@ export default function WrapperModal({
const modalContent = (
<div
className="fixed top-0 left-0 z-30 flex h-screen w-screen items-center justify-center"
className="fixed top-0 left-0 z-[100] flex h-screen w-screen items-center justify-center"
onClick={(e: React.MouseEvent) => e.stopPropagation()}
onMouseDown={(e: React.MouseEvent) => e.stopPropagation()}
>

343
setup.ps1
View File

@@ -286,6 +286,301 @@ function Prompt-OllamaOptions {
$script:ollama_choice = Read-Host "Choose option (1-2, or b)"
}
# ========================
# Advanced Settings Functions
# ========================
# Vector Store configuration
function Configure-VectorStore {
Write-Host ""
Write-ColorText "Vector Store Configuration" -ForegroundColor "White" -Bold
Write-ColorText "Choose your vector store:" -ForegroundColor "White"
Write-ColorText "1) FAISS (default, local)" -ForegroundColor "Yellow"
Write-ColorText "2) Elasticsearch" -ForegroundColor "Yellow"
Write-ColorText "3) Qdrant" -ForegroundColor "Yellow"
Write-ColorText "4) Milvus" -ForegroundColor "Yellow"
Write-ColorText "5) LanceDB" -ForegroundColor "Yellow"
Write-ColorText "6) PGVector" -ForegroundColor "Yellow"
Write-ColorText "b) Back" -ForegroundColor "Yellow"
Write-Host ""
$vs_choice = Read-Host "Choose option (1-6, or b)"
switch ($vs_choice) {
"1" {
"VECTOR_STORE=faiss" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Vector store set to FAISS." -ForegroundColor "Green"
}
"2" {
"VECTOR_STORE=elasticsearch" | Add-Content -Path $ENV_FILE -Encoding utf8
$elastic_url = Read-Host "Enter Elasticsearch URL (e.g. http://localhost:9200)"
if ($elastic_url) { "ELASTIC_URL=$elastic_url" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$elastic_cloud_id = Read-Host "Enter Elasticsearch Cloud ID (leave empty if using URL)"
if ($elastic_cloud_id) { "ELASTIC_CLOUD_ID=$elastic_cloud_id" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$elastic_user = Read-Host "Enter Elasticsearch username (leave empty if none)"
if ($elastic_user) { "ELASTIC_USERNAME=$elastic_user" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$elastic_pass = Read-Host "Enter Elasticsearch password (leave empty if none)"
if ($elastic_pass) { "ELASTIC_PASSWORD=$elastic_pass" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$elastic_index = Read-Host "Enter Elasticsearch index name (default: docsgpt)"
if ([string]::IsNullOrEmpty($elastic_index)) { $elastic_index = "docsgpt" }
"ELASTIC_INDEX=$elastic_index" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Vector store set to Elasticsearch." -ForegroundColor "Green"
}
"3" {
"VECTOR_STORE=qdrant" | Add-Content -Path $ENV_FILE -Encoding utf8
$qdrant_url = Read-Host "Enter Qdrant URL (e.g. http://localhost:6333)"
if ($qdrant_url) { "QDRANT_URL=$qdrant_url" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$qdrant_key = Read-Host "Enter Qdrant API key (leave empty if none)"
if ($qdrant_key) { "QDRANT_API_KEY=$qdrant_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$qdrant_collection = Read-Host "Enter Qdrant collection name (default: docsgpt)"
if ([string]::IsNullOrEmpty($qdrant_collection)) { $qdrant_collection = "docsgpt" }
"QDRANT_COLLECTION_NAME=$qdrant_collection" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Vector store set to Qdrant." -ForegroundColor "Green"
}
"4" {
"VECTOR_STORE=milvus" | Add-Content -Path $ENV_FILE -Encoding utf8
$milvus_uri = Read-Host "Enter Milvus URI (default: ./milvus_local.db)"
if ([string]::IsNullOrEmpty($milvus_uri)) { $milvus_uri = "./milvus_local.db" }
"MILVUS_URI=$milvus_uri" | Add-Content -Path $ENV_FILE -Encoding utf8
$milvus_token = Read-Host "Enter Milvus token (leave empty if none)"
if ($milvus_token) { "MILVUS_TOKEN=$milvus_token" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$milvus_collection = Read-Host "Enter Milvus collection name (default: docsgpt)"
if ([string]::IsNullOrEmpty($milvus_collection)) { $milvus_collection = "docsgpt" }
"MILVUS_COLLECTION_NAME=$milvus_collection" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Vector store set to Milvus." -ForegroundColor "Green"
}
"5" {
"VECTOR_STORE=lancedb" | Add-Content -Path $ENV_FILE -Encoding utf8
$lancedb_path = Read-Host "Enter LanceDB path (default: ./data/lancedb)"
if ([string]::IsNullOrEmpty($lancedb_path)) { $lancedb_path = "./data/lancedb" }
"LANCEDB_PATH=$lancedb_path" | Add-Content -Path $ENV_FILE -Encoding utf8
$lancedb_table = Read-Host "Enter LanceDB table name (default: docsgpts)"
if ([string]::IsNullOrEmpty($lancedb_table)) { $lancedb_table = "docsgpts" }
"LANCEDB_TABLE_NAME=$lancedb_table" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Vector store set to LanceDB." -ForegroundColor "Green"
}
"6" {
"VECTOR_STORE=pgvector" | Add-Content -Path $ENV_FILE -Encoding utf8
$pgvector_conn = Read-Host "Enter PGVector connection string (e.g. postgresql://user:pass@host:5432/db)"
if ($pgvector_conn) { "PGVECTOR_CONNECTION_STRING=$pgvector_conn" | Add-Content -Path $ENV_FILE -Encoding utf8 }
Write-ColorText "Vector store set to PGVector." -ForegroundColor "Green"
}
{$_ -eq "b" -or $_ -eq "B"} { return }
default {
Write-Host ""
Write-ColorText "Invalid choice." -ForegroundColor "Red"
Start-Sleep -Seconds 1
}
}
}
# Embeddings configuration
function Configure-Embeddings {
Write-Host ""
Write-ColorText "Embeddings Configuration" -ForegroundColor "White" -Bold
Write-ColorText "Choose your embeddings provider:" -ForegroundColor "White"
Write-ColorText "1) HuggingFace (default, local)" -ForegroundColor "Yellow"
Write-ColorText "2) OpenAI Embeddings" -ForegroundColor "Yellow"
Write-ColorText "3) Custom Remote Embeddings (OpenAI-compatible API)" -ForegroundColor "Yellow"
Write-ColorText "b) Back" -ForegroundColor "Yellow"
Write-Host ""
$emb_choice = Read-Host "Choose option (1-3, or b)"
switch ($emb_choice) {
"1" {
"EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Embeddings set to HuggingFace (local)." -ForegroundColor "Green"
}
"2" {
"EMBEDDINGS_NAME=openai_text-embedding-ada-002" | Add-Content -Path $ENV_FILE -Encoding utf8
$emb_key = Read-Host "Enter Embeddings API key (leave empty to reuse LLM API_KEY)"
if ($emb_key) { "EMBEDDINGS_KEY=$emb_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
Write-ColorText "Embeddings set to OpenAI." -ForegroundColor "Green"
}
"3" {
$emb_name = Read-Host "Enter embeddings model name"
if ($emb_name) { "EMBEDDINGS_NAME=$emb_name" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$emb_url = Read-Host "Enter remote embeddings API base URL"
if ($emb_url) { "EMBEDDINGS_BASE_URL=$emb_url" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$emb_key = Read-Host "Enter embeddings API key (leave empty if none)"
if ($emb_key) { "EMBEDDINGS_KEY=$emb_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
Write-ColorText "Custom remote embeddings configured." -ForegroundColor "Green"
}
{$_ -eq "b" -or $_ -eq "B"} { return }
default {
Write-Host ""
Write-ColorText "Invalid choice." -ForegroundColor "Red"
Start-Sleep -Seconds 1
}
}
}
# Authentication configuration
function Configure-Auth {
Write-Host ""
Write-ColorText "Authentication Configuration" -ForegroundColor "White" -Bold
Write-ColorText "Choose authentication type:" -ForegroundColor "White"
Write-ColorText "1) None (default, no authentication)" -ForegroundColor "Yellow"
Write-ColorText "2) Simple JWT" -ForegroundColor "Yellow"
Write-ColorText "3) Session JWT" -ForegroundColor "Yellow"
Write-ColorText "b) Back" -ForegroundColor "Yellow"
Write-Host ""
$auth_choice = Read-Host "Choose option (1-3, or b)"
switch ($auth_choice) {
"1" {
Write-ColorText "Authentication disabled (default)." -ForegroundColor "Green"
}
"2" {
"AUTH_TYPE=simple_jwt" | Add-Content -Path $ENV_FILE -Encoding utf8
$jwt_key = Read-Host "Enter JWT secret key (leave empty to auto-generate)"
if ([string]::IsNullOrEmpty($jwt_key)) {
$bytes = New-Object byte[] 32
[System.Security.Cryptography.RandomNumberGenerator]::Fill($bytes)
$jwt_key = [System.BitConverter]::ToString($bytes).Replace("-", "").ToLower()
Write-ColorText "Auto-generated JWT secret key." -ForegroundColor "Yellow"
}
"JWT_SECRET_KEY=$jwt_key" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Authentication set to Simple JWT." -ForegroundColor "Green"
}
"3" {
"AUTH_TYPE=session_jwt" | Add-Content -Path $ENV_FILE -Encoding utf8
$jwt_key = Read-Host "Enter JWT secret key (leave empty to auto-generate)"
if ([string]::IsNullOrEmpty($jwt_key)) {
$bytes = New-Object byte[] 32
[System.Security.Cryptography.RandomNumberGenerator]::Fill($bytes)
$jwt_key = [System.BitConverter]::ToString($bytes).Replace("-", "").ToLower()
Write-ColorText "Auto-generated JWT secret key." -ForegroundColor "Yellow"
}
"JWT_SECRET_KEY=$jwt_key" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Authentication set to Session JWT." -ForegroundColor "Green"
}
{$_ -eq "b" -or $_ -eq "B"} { return }
default {
Write-Host ""
Write-ColorText "Invalid choice." -ForegroundColor "Red"
Start-Sleep -Seconds 1
}
}
}
# Integrations configuration
function Configure-Integrations {
Write-Host ""
Write-ColorText "Integrations Configuration" -ForegroundColor "White" -Bold
Write-ColorText "1) Google Drive" -ForegroundColor "Yellow"
Write-ColorText "2) GitHub" -ForegroundColor "Yellow"
Write-ColorText "b) Back" -ForegroundColor "Yellow"
Write-Host ""
$int_choice = Read-Host "Choose option (1-2, or b)"
switch ($int_choice) {
"1" {
$google_id = Read-Host "Enter Google OAuth Client ID"
if ($google_id) { "GOOGLE_CLIENT_ID=$google_id" | Add-Content -Path $ENV_FILE -Encoding utf8 }
$google_secret = Read-Host "Enter Google OAuth Client Secret"
if ($google_secret) { "GOOGLE_CLIENT_SECRET=$google_secret" | Add-Content -Path $ENV_FILE -Encoding utf8 }
Write-ColorText "Google Drive integration configured." -ForegroundColor "Green"
}
"2" {
$github_token = Read-Host "Enter GitHub Personal Access Token (with repo read access)"
if ($github_token) { "GITHUB_ACCESS_TOKEN=$github_token" | Add-Content -Path $ENV_FILE -Encoding utf8 }
Write-ColorText "GitHub integration configured." -ForegroundColor "Green"
}
{$_ -eq "b" -or $_ -eq "B"} { return }
default {
Write-Host ""
Write-ColorText "Invalid choice." -ForegroundColor "Red"
Start-Sleep -Seconds 1
}
}
}
# Document Processing configuration
function Configure-DocProcessing {
Write-Host ""
Write-ColorText "Document Processing Configuration" -ForegroundColor "White" -Bold
$pdf_image = Read-Host "Parse PDF pages as images for better table/chart extraction? (y/N)"
if ($pdf_image -eq "y" -or $pdf_image -eq "Y") {
"PARSE_PDF_AS_IMAGE=true" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "PDF-as-image parsing enabled." -ForegroundColor "Green"
}
$ocr_enabled = Read-Host "Enable OCR for document processing (Docling)? (y/N)"
if ($ocr_enabled -eq "y" -or $ocr_enabled -eq "Y") {
"DOCLING_OCR_ENABLED=true" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "Docling OCR enabled." -ForegroundColor "Green"
}
}
# Text-to-Speech configuration
function Configure-TTS {
Write-Host ""
Write-ColorText "Text-to-Speech Configuration" -ForegroundColor "White" -Bold
Write-ColorText "Choose TTS provider:" -ForegroundColor "White"
Write-ColorText "1) Google TTS (default, free)" -ForegroundColor "Yellow"
Write-ColorText "2) ElevenLabs" -ForegroundColor "Yellow"
Write-ColorText "b) Back" -ForegroundColor "Yellow"
Write-Host ""
$tts_choice = Read-Host "Choose option (1-2, or b)"
switch ($tts_choice) {
"1" {
"TTS_PROVIDER=google_tts" | Add-Content -Path $ENV_FILE -Encoding utf8
Write-ColorText "TTS set to Google TTS." -ForegroundColor "Green"
}
"2" {
"TTS_PROVIDER=elevenlabs" | Add-Content -Path $ENV_FILE -Encoding utf8
$elevenlabs_key = Read-Host "Enter ElevenLabs API key"
if ($elevenlabs_key) { "ELEVENLABS_API_KEY=$elevenlabs_key" | Add-Content -Path $ENV_FILE -Encoding utf8 }
Write-ColorText "TTS set to ElevenLabs." -ForegroundColor "Green"
}
{$_ -eq "b" -or $_ -eq "B"} { return }
default {
Write-Host ""
Write-ColorText "Invalid choice." -ForegroundColor "Red"
Start-Sleep -Seconds 1
}
}
}
# Main advanced settings menu
function Prompt-AdvancedSettings {
Write-Host ""
$configure_advanced = Read-Host "Would you like to configure advanced settings? (y/N)"
if ($configure_advanced -ne "y" -and $configure_advanced -ne "Y") {
return
}
:advancedMenu while ($true) {
Write-Host ""
Write-ColorText "Advanced Settings" -ForegroundColor "White" -Bold
Write-ColorText "1) Vector Store (default: faiss)" -ForegroundColor "Yellow"
Write-ColorText "2) Embeddings (default: HuggingFace local)" -ForegroundColor "Yellow"
Write-ColorText "3) Authentication (default: none)" -ForegroundColor "Yellow"
Write-ColorText "4) Integrations (Google Drive, GitHub)" -ForegroundColor "Yellow"
Write-ColorText "5) Document Processing (PDF as image, OCR)" -ForegroundColor "Yellow"
Write-ColorText "6) Text-to-Speech (default: Google TTS)" -ForegroundColor "Yellow"
Write-ColorText "s) Save and Continue with Docker setup" -ForegroundColor "Yellow"
Write-Host ""
$adv_choice = Read-Host "Choose option (1-6, or s)"
switch ($adv_choice) {
"1" { Configure-VectorStore }
"2" { Configure-Embeddings }
"3" { Configure-Auth }
"4" { Configure-Integrations }
"5" { Configure-DocProcessing }
"6" { Configure-TTS }
{$_ -eq "s" -or $_ -eq "S"} { break advancedMenu }  # a plain break would only exit the switch, not this menu loop
default {
Write-Host ""
Write-ColorText "Invalid choice." -ForegroundColor "Red"
Start-Sleep -Seconds 1
}
}
}
}
# 1) Use DocsGPT Public API Endpoint (simple and free)
function Use-DocsPublicAPIEndpoint {
Write-Host ""
@@ -297,6 +592,8 @@ function Use-DocsPublicAPIEndpoint {
Write-ColorText ".env file configured for DocsGPT Public API." -ForegroundColor "Green"
Prompt-AdvancedSettings
# Start Docker if needed
$dockerRunning = Check-AndStartDocker
if (-not $dockerRunning) {
@@ -369,6 +666,7 @@ function Serve-LocalOllama {
break
}
elseif ($confirm_gpu -eq "b" -or $confirm_gpu -eq "B") {
$script:ollama_choice = "b"
Clear-Host
return
}
@@ -406,6 +704,8 @@ function Serve-LocalOllama {
Write-ColorText ".env file configured for Ollama ($($docker_compose_file_suffix.ToUpper()))." -ForegroundColor "Green"
Write-ColorText "Note: MODEL_NAME is set to '$model_name'. You can change it later in the .env file." -ForegroundColor "Yellow"
Prompt-AdvancedSettings
# Start Docker if needed
$dockerRunning = Check-AndStartDocker
if (-not $dockerRunning) {
@@ -569,6 +869,8 @@ function Connect-LocalInferenceEngine {
Write-ColorText ".env file configured for $engine_name with OpenAI API format." -ForegroundColor "Green"
Write-ColorText "Note: MODEL_NAME is set to '$model_name'. You can change it later in the .env file." -ForegroundColor "Yellow"
Prompt-AdvancedSettings
# Start Docker if needed
$dockerRunning = Check-AndStartDocker
if (-not $dockerRunning) {
@@ -665,6 +967,12 @@ function Connect-CloudAPIProvider {
$script:llm_name = "azure_openai"
$script:model_name = "gpt-4o"
Get-APIKey
Write-Host ""
Write-ColorText "Azure OpenAI requires additional configuration:" -ForegroundColor "White" -Bold
$script:azure_api_base = Read-Host "Enter Azure OpenAI API base URL (e.g. https://your-resource.openai.azure.com/)"
$script:azure_api_version = Read-Host "Enter Azure OpenAI API version (e.g. 2024-02-15-preview)"
$script:azure_deployment = Read-Host "Enter Azure deployment name for chat"
$script:azure_emb_deployment = Read-Host "Enter Azure deployment name for embeddings (leave empty to skip)"
break
}
"7" { # Novita
@@ -696,9 +1004,19 @@ function Connect-CloudAPIProvider {
"LLM_PROVIDER=$llm_name" | Add-Content -Path $ENV_FILE -Encoding utf8
"LLM_NAME=$model_name" | Add-Content -Path $ENV_FILE -Encoding utf8
"VITE_API_STREAMING=true" | Add-Content -Path $ENV_FILE -Encoding utf8
# Azure OpenAI additional settings
if ($llm_name -eq "azure_openai") {
if ($azure_api_base) { "OPENAI_API_BASE=$azure_api_base" | Add-Content -Path $ENV_FILE -Encoding utf8 }
if ($azure_api_version) { "OPENAI_API_VERSION=$azure_api_version" | Add-Content -Path $ENV_FILE -Encoding utf8 }
if ($azure_deployment) { "AZURE_DEPLOYMENT_NAME=$azure_deployment" | Add-Content -Path $ENV_FILE -Encoding utf8 }
if ($azure_emb_deployment) { "AZURE_EMBEDDINGS_DEPLOYMENT_NAME=$azure_emb_deployment" | Add-Content -Path $ENV_FILE -Encoding utf8 }
}
Write-ColorText ".env file configured for $provider_name." -ForegroundColor "Green"
Prompt-AdvancedSettings
# Start Docker if needed
$dockerRunning = Check-AndStartDocker
if (-not $dockerRunning) {
@@ -709,15 +1027,15 @@ function Connect-CloudAPIProvider {
try {
Write-Host ""
Write-ColorText "Starting Docker Compose..." -ForegroundColor "White"
# Run Docker compose commands
& docker compose --env-file "$ENV_FILE" -f "$COMPOSE_FILE" pull
if ($LASTEXITCODE -ne 0) {
throw "Docker compose pull failed with exit code $LASTEXITCODE"
}
& docker compose --env-file "$ENV_FILE" -f "$COMPOSE_FILE" up -d
Write-Host ""
Write-ColorText "DocsGPT is now configured to use $provider_name on http://localhost:5173" -ForegroundColor "Green"
Write-ColorText "You can stop the application by running: docker compose -f `"$COMPOSE_FILE`" down" -ForegroundColor "Yellow"
@@ -734,6 +1052,23 @@ function Connect-CloudAPIProvider {
# Main script execution
Animate-Dino
# Check if .env file exists and is not empty
if ((Test-Path $ENV_FILE) -and ((Get-Item $ENV_FILE).Length -gt 0)) {
Write-Host ""
Write-ColorText "Warning: An existing .env file was found with the following settings:" -ForegroundColor "Yellow" -Bold
$envLines = Get-Content $ENV_FILE
$envLines | Select-Object -First 3 | ForEach-Object { Write-Host " $_" }
if ($envLines.Count -gt 3) {
Write-Host " ... and $($envLines.Count - 3) more lines"
}
Write-Host ""
$confirm_overwrite = Read-Host "Running setup will overwrite this file. Continue? (y/N)"
if ($confirm_overwrite -ne "y" -and $confirm_overwrite -ne "Y") {
Write-ColorText "Setup cancelled. Your .env file was not modified." -ForegroundColor "Green"
exit 0
}
}
while ($true) {
Clear-Host
Prompt-MainMenu
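Note: both setup scripts auto-generate the JWT secret as a hex string when the prompt is left empty (RandomNumberGenerator in PowerShell, openssl rand -hex 32 or /dev/urandom in the shell script below). A minimal Python sketch of the same step, if you prefer to generate the key yourself; the variable name is illustrative only:
import secrets
# 32 random bytes as 64 hex characters, matching the value the setup
# scripts write into JWT_SECRET_KEY when the prompt is left empty.
jwt_secret_key = secrets.token_hex(32)
print(f"JWT_SECRET_KEY={jwt_secret_key}")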

setup.sh

@@ -9,9 +9,10 @@ NC='\033[0m'
BOLD='\033[1m'
# Base Compose file (relative to script location)
COMPOSE_FILE="$(dirname "$(readlink -f "$0")")/deployment/docker-compose-hub.yaml"
COMPOSE_FILE_LOCAL="$(dirname "$(readlink -f "$0")")/deployment/docker-compose.yaml"
ENV_FILE="$(dirname "$(readlink -f "$0")")/.env"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd -P)"
COMPOSE_FILE="${SCRIPT_DIR}/deployment/docker-compose-hub.yaml"
COMPOSE_FILE_LOCAL="${SCRIPT_DIR}/deployment/docker-compose.yaml"
ENV_FILE="${SCRIPT_DIR}/.env"
# Animation function
animate_dino() {
@@ -155,7 +156,7 @@ prompt_cloud_api_provider_options() {
echo -e "${YELLOW}7) Novita${NC}"
echo -e "${YELLOW}b) Back to Main Menu${NC}"
echo
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-6, or b): ${NC}")" provider_choice
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-7, or b): ${NC}")" provider_choice
}
# Function to prompt for Ollama CPU/GPU options
@@ -170,6 +171,264 @@ prompt_ollama_options() {
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-2, or b): ${NC}")" ollama_choice
}
# ========================
# Advanced Settings Functions
# ========================
# Vector Store configuration
configure_vector_store() {
echo -e "\n${DEFAULT_FG}${BOLD}Vector Store Configuration${NC}"
echo -e "${DEFAULT_FG}Choose your vector store:${NC}"
echo -e "${YELLOW}1) FAISS (default, local)${NC}"
echo -e "${YELLOW}2) Elasticsearch${NC}"
echo -e "${YELLOW}3) Qdrant${NC}"
echo -e "${YELLOW}4) Milvus${NC}"
echo -e "${YELLOW}5) LanceDB${NC}"
echo -e "${YELLOW}6) PGVector${NC}"
echo -e "${YELLOW}b) Back${NC}"
echo
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-6, or b): ${NC}")" vs_choice
case "$vs_choice" in
1)
echo "VECTOR_STORE=faiss" >> "$ENV_FILE"
echo -e "${GREEN}Vector store set to FAISS.${NC}"
;;
2)
echo "VECTOR_STORE=elasticsearch" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch URL (e.g. http://localhost:9200): ${NC}")" elastic_url
[ -n "$elastic_url" ] && echo "ELASTIC_URL=$elastic_url" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch Cloud ID (leave empty if using URL): ${NC}")" elastic_cloud_id
[ -n "$elastic_cloud_id" ] && echo "ELASTIC_CLOUD_ID=$elastic_cloud_id" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch username (leave empty if none): ${NC}")" elastic_user
[ -n "$elastic_user" ] && echo "ELASTIC_USERNAME=$elastic_user" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch password (leave empty if none): ${NC}")" elastic_pass
[ -n "$elastic_pass" ] && echo "ELASTIC_PASSWORD=$elastic_pass" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Elasticsearch index name (default: docsgpt): ${NC}")" elastic_index
echo "ELASTIC_INDEX=${elastic_index:-docsgpt}" >> "$ENV_FILE"
echo -e "${GREEN}Vector store set to Elasticsearch.${NC}"
;;
3)
echo "VECTOR_STORE=qdrant" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Qdrant URL (e.g. http://localhost:6333): ${NC}")" qdrant_url
[ -n "$qdrant_url" ] && echo "QDRANT_URL=$qdrant_url" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Qdrant API key (leave empty if none): ${NC}")" qdrant_key
[ -n "$qdrant_key" ] && echo "QDRANT_API_KEY=$qdrant_key" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Qdrant collection name (default: docsgpt): ${NC}")" qdrant_collection
echo "QDRANT_COLLECTION_NAME=${qdrant_collection:-docsgpt}" >> "$ENV_FILE"
echo -e "${GREEN}Vector store set to Qdrant.${NC}"
;;
4)
echo "VECTOR_STORE=milvus" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Milvus URI (default: ./milvus_local.db): ${NC}")" milvus_uri
echo "MILVUS_URI=${milvus_uri:-./milvus_local.db}" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Milvus token (leave empty if none): ${NC}")" milvus_token
[ -n "$milvus_token" ] && echo "MILVUS_TOKEN=$milvus_token" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Milvus collection name (default: docsgpt): ${NC}")" milvus_collection
echo "MILVUS_COLLECTION_NAME=${milvus_collection:-docsgpt}" >> "$ENV_FILE"
echo -e "${GREEN}Vector store set to Milvus.${NC}"
;;
5)
echo "VECTOR_STORE=lancedb" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter LanceDB path (default: ./data/lancedb): ${NC}")" lancedb_path
echo "LANCEDB_PATH=${lancedb_path:-./data/lancedb}" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter LanceDB table name (default: docsgpts): ${NC}")" lancedb_table
echo "LANCEDB_TABLE_NAME=${lancedb_table:-docsgpts}" >> "$ENV_FILE"
echo -e "${GREEN}Vector store set to LanceDB.${NC}"
;;
6)
echo "VECTOR_STORE=pgvector" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter PGVector connection string (e.g. postgresql://user:pass@host:5432/db): ${NC}")" pgvector_conn
[ -n "$pgvector_conn" ] && echo "PGVECTOR_CONNECTION_STRING=$pgvector_conn" >> "$ENV_FILE"
echo -e "${GREEN}Vector store set to PGVector.${NC}"
;;
b|B) return ;;
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
esac
}
# Embeddings configuration
configure_embeddings() {
echo -e "\n${DEFAULT_FG}${BOLD}Embeddings Configuration${NC}"
echo -e "${DEFAULT_FG}Choose your embeddings provider:${NC}"
echo -e "${YELLOW}1) HuggingFace (default, local)${NC}"
echo -e "${YELLOW}2) OpenAI Embeddings${NC}"
echo -e "${YELLOW}3) Custom Remote Embeddings (OpenAI-compatible API)${NC}"
echo -e "${YELLOW}b) Back${NC}"
echo
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-3, or b): ${NC}")" emb_choice
case "$emb_choice" in
1)
echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> "$ENV_FILE"
echo -e "${GREEN}Embeddings set to HuggingFace (local).${NC}"
;;
2)
echo "EMBEDDINGS_NAME=openai_text-embedding-ada-002" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Embeddings API key (leave empty to reuse LLM API_KEY): ${NC}")" emb_key
[ -n "$emb_key" ] && echo "EMBEDDINGS_KEY=$emb_key" >> "$ENV_FILE"
echo -e "${GREEN}Embeddings set to OpenAI.${NC}"
;;
3)
read -p "$(echo -e "${DEFAULT_FG}Enter embeddings model name: ${NC}")" emb_name
[ -n "$emb_name" ] && echo "EMBEDDINGS_NAME=$emb_name" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter remote embeddings API base URL: ${NC}")" emb_url
[ -n "$emb_url" ] && echo "EMBEDDINGS_BASE_URL=$emb_url" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter embeddings API key (leave empty if none): ${NC}")" emb_key
[ -n "$emb_key" ] && echo "EMBEDDINGS_KEY=$emb_key" >> "$ENV_FILE"
echo -e "${GREEN}Custom remote embeddings configured.${NC}"
;;
b|B) return ;;
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
esac
}
# Authentication configuration
configure_auth() {
echo -e "\n${DEFAULT_FG}${BOLD}Authentication Configuration${NC}"
echo -e "${DEFAULT_FG}Choose authentication type:${NC}"
echo -e "${YELLOW}1) None (default, no authentication)${NC}"
echo -e "${YELLOW}2) Simple JWT${NC}"
echo -e "${YELLOW}3) Session JWT${NC}"
echo -e "${YELLOW}b) Back${NC}"
echo
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-3, or b): ${NC}")" auth_choice
case "$auth_choice" in
1)
echo -e "${GREEN}Authentication disabled (default).${NC}"
;;
2)
echo "AUTH_TYPE=simple_jwt" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter JWT secret key (leave empty to auto-generate): ${NC}")" jwt_key
if [ -n "$jwt_key" ]; then
echo "JWT_SECRET_KEY=$jwt_key" >> "$ENV_FILE"
else
generated_key=$(openssl rand -hex 32 2>/dev/null || head -c 64 /dev/urandom | od -An -tx1 | tr -d ' \n')
echo "JWT_SECRET_KEY=$generated_key" >> "$ENV_FILE"
echo -e "${YELLOW}Auto-generated JWT secret key.${NC}"
fi
echo -e "${GREEN}Authentication set to Simple JWT.${NC}"
;;
3)
echo "AUTH_TYPE=session_jwt" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter JWT secret key (leave empty to auto-generate): ${NC}")" jwt_key
if [ -n "$jwt_key" ]; then
echo "JWT_SECRET_KEY=$jwt_key" >> "$ENV_FILE"
else
generated_key=$(openssl rand -hex 32 2>/dev/null || head -c 64 /dev/urandom | od -An -tx1 | tr -d ' \n')
echo "JWT_SECRET_KEY=$generated_key" >> "$ENV_FILE"
echo -e "${YELLOW}Auto-generated JWT secret key.${NC}"
fi
echo -e "${GREEN}Authentication set to Session JWT.${NC}"
;;
b|B) return ;;
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
esac
}
# Integrations configuration
configure_integrations() {
echo -e "\n${DEFAULT_FG}${BOLD}Integrations Configuration${NC}"
echo -e "${YELLOW}1) Google Drive${NC}"
echo -e "${YELLOW}2) GitHub${NC}"
echo -e "${YELLOW}b) Back${NC}"
echo
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-2, or b): ${NC}")" int_choice
case "$int_choice" in
1)
read -p "$(echo -e "${DEFAULT_FG}Enter Google OAuth Client ID: ${NC}")" google_id
[ -n "$google_id" ] && echo "GOOGLE_CLIENT_ID=$google_id" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter Google OAuth Client Secret: ${NC}")" google_secret
[ -n "$google_secret" ] && echo "GOOGLE_CLIENT_SECRET=$google_secret" >> "$ENV_FILE"
echo -e "${GREEN}Google Drive integration configured.${NC}"
;;
2)
read -p "$(echo -e "${DEFAULT_FG}Enter GitHub Personal Access Token (with repo read access): ${NC}")" github_token
[ -n "$github_token" ] && echo "GITHUB_ACCESS_TOKEN=$github_token" >> "$ENV_FILE"
echo -e "${GREEN}GitHub integration configured.${NC}"
;;
b|B) return ;;
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
esac
}
# Document Processing configuration
configure_doc_processing() {
echo -e "\n${DEFAULT_FG}${BOLD}Document Processing Configuration${NC}"
read -p "$(echo -e "${DEFAULT_FG}Parse PDF pages as images for better table/chart extraction? (y/N): ${NC}")" pdf_image
if [[ "$pdf_image" =~ ^[yY]$ ]]; then
echo "PARSE_PDF_AS_IMAGE=true" >> "$ENV_FILE"
echo -e "${GREEN}PDF-as-image parsing enabled.${NC}"
fi
read -p "$(echo -e "${DEFAULT_FG}Enable OCR for document processing (Docling)? (y/N): ${NC}")" ocr_enabled
if [[ "$ocr_enabled" =~ ^[yY]$ ]]; then
echo "DOCLING_OCR_ENABLED=true" >> "$ENV_FILE"
echo -e "${GREEN}Docling OCR enabled.${NC}"
fi
}
# Text-to-Speech configuration
configure_tts() {
echo -e "\n${DEFAULT_FG}${BOLD}Text-to-Speech Configuration${NC}"
echo -e "${DEFAULT_FG}Choose TTS provider:${NC}"
echo -e "${YELLOW}1) Google TTS (default, free)${NC}"
echo -e "${YELLOW}2) ElevenLabs${NC}"
echo -e "${YELLOW}b) Back${NC}"
echo
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-2, or b): ${NC}")" tts_choice
case "$tts_choice" in
1)
echo "TTS_PROVIDER=google_tts" >> "$ENV_FILE"
echo -e "${GREEN}TTS set to Google TTS.${NC}"
;;
2)
echo "TTS_PROVIDER=elevenlabs" >> "$ENV_FILE"
read -p "$(echo -e "${DEFAULT_FG}Enter ElevenLabs API key: ${NC}")" elevenlabs_key
[ -n "$elevenlabs_key" ] && echo "ELEVENLABS_API_KEY=$elevenlabs_key" >> "$ENV_FILE"
echo -e "${GREEN}TTS set to ElevenLabs.${NC}"
;;
b|B) return ;;
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
esac
}
# Main advanced settings menu
prompt_advanced_settings() {
echo
read -p "$(echo -e "${DEFAULT_FG}Would you like to configure advanced settings? (y/N): ${NC}")" configure_advanced
if [[ ! "$configure_advanced" =~ ^[yY]$ ]]; then
return
fi
while true; do
echo -e "\n${DEFAULT_FG}${BOLD}Advanced Settings${NC}"
echo -e "${YELLOW}1) Vector Store ${NC}${DEFAULT_FG}(default: faiss)${NC}"
echo -e "${YELLOW}2) Embeddings ${NC}${DEFAULT_FG}(default: HuggingFace local)${NC}"
echo -e "${YELLOW}3) Authentication ${NC}${DEFAULT_FG}(default: none)${NC}"
echo -e "${YELLOW}4) Integrations ${NC}${DEFAULT_FG}(Google Drive, GitHub)${NC}"
echo -e "${YELLOW}5) Document Processing ${NC}${DEFAULT_FG}(PDF as image, OCR)${NC}"
echo -e "${YELLOW}6) Text-to-Speech ${NC}${DEFAULT_FG}(default: Google TTS)${NC}"
echo -e "${YELLOW}s) Save and Continue with Docker setup${NC}"
echo
read -p "$(echo -e "${DEFAULT_FG}Choose option (1-6, or s): ${NC}")" adv_choice
case "$adv_choice" in
1) configure_vector_store ;;
2) configure_embeddings ;;
3) configure_auth ;;
4) configure_integrations ;;
5) configure_doc_processing ;;
6) configure_tts ;;
s|S) break ;;
*) echo -e "\n${RED}Invalid choice.${NC}" ; sleep 1 ;;
esac
done
}
# 1) Use DocsGPT Public API Endpoint (simple and free)
use_docs_public_api_endpoint() {
echo -e "\n${NC}Setting up DocsGPT Public API Endpoint...${NC}"
@@ -177,6 +436,8 @@ use_docs_public_api_endpoint() {
echo "VITE_API_STREAMING=true" >> "$ENV_FILE"
echo -e "${GREEN}.env file configured for DocsGPT Public API.${NC}"
prompt_advanced_settings
check_and_start_docker
echo -e "\n${NC}Starting Docker Compose...${NC}"
@@ -229,11 +490,11 @@ serve_local_ollama() {
docker_compose_file_suffix="gpu"
get_model_name_ollama
break ;;
b|B) clear; return ;; # Back to Main Menu
b|B) clear; return 1 ;; # Back to Main Menu
*) echo -e "\n${RED}Invalid choice. Please choose y or b.${NC}" ; sleep 1 ;;
esac
;;
b|B) clear; return ;; # Back to Main Menu
b|B) clear; return 1 ;; # Back to Main Menu
*) echo -e "\n${RED}Invalid choice. Please choose 1-2, or b.${NC}" ; sleep 1 ;;
esac
done
@@ -248,6 +509,7 @@ serve_local_ollama() {
echo "EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2" >> "$ENV_FILE"
echo -e "${GREEN}.env file configured for Ollama ($(echo "$docker_compose_file_suffix" | tr '[:lower:]' '[:upper:]')${NC}${GREEN}).${NC}"
prompt_advanced_settings
check_and_start_docker
local compose_files=(
@@ -346,7 +608,7 @@ connect_local_inference_engine() {
openai_base_url="http://host.docker.internal:23333/v1"
get_model_name
break ;;
b|B) clear; return ;; # Back to Main Menu
b|B) clear; return 1 ;; # Back to Main Menu
*) echo -e "\n${RED}Invalid choice. Please choose 1-8, or b.${NC}" ; sleep 1 ;;
esac
done
@@ -361,6 +623,8 @@ connect_local_inference_engine() {
echo -e "${GREEN}.env file configured for ${BOLD}${engine_name}${NC}${GREEN} with OpenAI API format.${NC}"
echo -e "${YELLOW}Note: MODEL_NAME is set to '${BOLD}$model_name${NC}${YELLOW}'. You can change it later in the .env file.${NC}"
prompt_advanced_settings
check_and_start_docker
echo -e "\n${NC}Starting Docker Compose...${NC}"
@@ -431,6 +695,11 @@ connect_cloud_api_provider() {
llm_provider="azure_openai"
model_name="gpt-4o"
get_api_key
echo -e "\n${DEFAULT_FG}${BOLD}Azure OpenAI requires additional configuration:${NC}"
read -p "$(echo -e "${DEFAULT_FG}Enter Azure OpenAI API base URL (e.g. https://your-resource.openai.azure.com/): ${NC}")" azure_api_base
read -p "$(echo -e "${DEFAULT_FG}Enter Azure OpenAI API version (e.g. 2024-02-15-preview): ${NC}")" azure_api_version
read -p "$(echo -e "${DEFAULT_FG}Enter Azure deployment name for chat: ${NC}")" azure_deployment
read -p "$(echo -e "${DEFAULT_FG}Enter Azure deployment name for embeddings (leave empty to skip): ${NC}")" azure_emb_deployment
break ;;
7) # Novita
provider_name="Novita"
@@ -438,8 +707,8 @@ connect_cloud_api_provider() {
model_name="deepseek/deepseek-r1"
get_api_key
break ;;
b|B) clear; return ;; # Clear screen and Back to Main Menu
*) echo -e "\n${RED}Invalid choice. Please choose 1-6, or b.${NC}" ; sleep 1 ;;
b|B) clear; return 1 ;; # Clear screen and Back to Main Menu
*) echo -e "\n${RED}Invalid choice. Please choose 1-7, or b.${NC}" ; sleep 1 ;;
esac
done
@@ -448,8 +717,19 @@ connect_cloud_api_provider() {
echo "LLM_PROVIDER=$llm_provider" >> "$ENV_FILE"
echo "LLM_NAME=$model_name" >> "$ENV_FILE"
echo "VITE_API_STREAMING=true" >> "$ENV_FILE"
# Azure OpenAI additional settings
if [ "$llm_provider" = "azure_openai" ]; then
[ -n "$azure_api_base" ] && echo "OPENAI_API_BASE=$azure_api_base" >> "$ENV_FILE"
[ -n "$azure_api_version" ] && echo "OPENAI_API_VERSION=$azure_api_version" >> "$ENV_FILE"
[ -n "$azure_deployment" ] && echo "AZURE_DEPLOYMENT_NAME=$azure_deployment" >> "$ENV_FILE"
[ -n "$azure_emb_deployment" ] && echo "AZURE_EMBEDDINGS_DEPLOYMENT_NAME=$azure_emb_deployment" >> "$ENV_FILE"
fi
echo -e "${GREEN}.env file configured for ${BOLD}${provider_name}${NC}${GREEN}.${NC}"
prompt_advanced_settings
check_and_start_docker
echo -e "\n${NC}Starting Docker Compose...${NC}"
@@ -472,6 +752,21 @@ connect_cloud_api_provider() {
# Main script execution
animate_dino
# Check if .env file exists and is not empty
if [ -f "$ENV_FILE" ] && [ -s "$ENV_FILE" ]; then
echo -e "\n${YELLOW}${BOLD}Warning:${NC}${YELLOW} An existing .env file was found with the following settings:${NC}"
head -3 "$ENV_FILE" | while IFS= read -r line; do echo -e "${DEFAULT_FG} $line${NC}"; done
total_lines=$(wc -l < "$ENV_FILE")
if [ "$total_lines" -gt 3 ]; then
echo -e "${DEFAULT_FG} ... and $((total_lines - 3)) more lines${NC}"
fi
echo
read -p "$(echo -e "${YELLOW}Running setup will overwrite this file. Continue? (y/N): ${NC}")" confirm_overwrite
if [[ ! "$confirm_overwrite" =~ ^[yY]$ ]]; then
echo -e "${GREEN}Setup cancelled. Your .env file was not modified.${NC}"
exit 0
fi
fi
while true; do # Main menu loop
clear # Clear screen before showing main menu again
@@ -479,18 +774,15 @@ while true; do # Main menu loop
case $main_choice in
1) # Use DocsGPT Public API Endpoint (Docker Hub images)
COMPOSE_FILE="$(dirname "$(readlink -f "$0")")/deployment/docker-compose-hub.yaml"
COMPOSE_FILE="${SCRIPT_DIR}/deployment/docker-compose-hub.yaml"
use_docs_public_api_endpoint
break ;;
2) # Serve Local (with Ollama)
serve_local_ollama
break ;;
serve_local_ollama && break ;;
3) # Connect Local Inference Engine
connect_local_inference_engine
break ;;
connect_local_inference_engine && break ;;
4) # Connect Cloud API Provider
connect_cloud_api_provider
break ;;
connect_cloud_api_provider && break ;;
5) # Advanced: Build images locally
echo -e "\n${YELLOW}You have selected to build images locally. This is recommended for developers or if you want to test local changes.${NC}"
COMPOSE_FILE="$COMPOSE_FILE_LOCAL"
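Note: the Azure OpenAI branch above writes OPENAI_API_BASE, OPENAI_API_VERSION, AZURE_DEPLOYMENT_NAME and AZURE_EMBEDDINGS_DEPLOYMENT_NAME into the .env file. A minimal sketch of reading those values back from the environment in Python; this is illustrative only, not DocsGPT's actual settings loader:
import os
# Illustrative only: the keys match what the setup scripts append to .env.
azure_settings = {
    "api_base": os.environ.get("OPENAI_API_BASE", ""),
    "api_version": os.environ.get("OPENAI_API_VERSION", ""),
    "chat_deployment": os.environ.get("AZURE_DEPLOYMENT_NAME", ""),
    "embeddings_deployment": os.environ.get("AZURE_EMBEDDINGS_DEPLOYMENT_NAME", ""),
}
missing = [key for key, value in azure_settings.items() if not value]
if missing:
    print("Missing Azure OpenAI settings: " + ", ".join(missing))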


@@ -0,0 +1,332 @@
from types import SimpleNamespace
from typing import Any, Dict, Optional
import pytest
from application.api.user.workflows import routes as workflow_routes
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
NodeType,
Workflow,
WorkflowGraph,
WorkflowNode,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.api.user.workflows.routes import validate_workflow_structure
class StubNodeAgent:
def __init__(self, events):
self.events = events
def gen(self, _prompt):
yield from self.events
def create_engine() -> WorkflowEngine:
graph = WorkflowGraph(workflow=Workflow(name="Engine Test"), nodes=[], edges=[])
agent = SimpleNamespace(
endpoint="stream",
llm_name="openai",
model_id="gpt-4o-mini",
api_key="test-key",
chat_history=[],
decoded_token={"sub": "user-1"},
)
return WorkflowEngine(graph, agent)
def create_agent_node(
node_id: str,
output_variable: str = "",
json_schema: Optional[Dict[str, Any]] = None,
) -> WorkflowNode:
config = {
"agent_type": "classic",
"system_prompt": "You are a helpful assistant.",
"prompt_template": "",
"stream_to_user": False,
"tools": [],
}
if output_variable:
config["output_variable"] = output_variable
if json_schema is not None:
config["json_schema"] = json_schema
return WorkflowNode(
id=node_id,
workflow_id="workflow-1",
type=NodeType.AGENT,
title="Agent",
position={"x": 0, "y": 0},
config=config,
)
def test_execute_agent_node_saves_structured_output_as_json(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_1",
output_variable="result",
json_schema={"type": "object"},
)
node_events = [
{"answer": '{"summary":"ok",', "structured": True},
{"answer": '"score":2}', "structured": True},
]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
expected_output = {"summary": "ok", "score": 2}
assert engine.state["node_agent_1_output"] == expected_output
assert engine.state["result"] == expected_output
def test_execute_agent_node_normalizes_wrapped_schema_before_agent_create(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_wrapped",
json_schema={"schema": {"type": "object"}},
)
node_events = [{"answer": '{"summary":"ok"}', "structured": True}]
captured: Dict[str, Any] = {}
def create_node_agent(**kwargs):
captured["json_schema"] = kwargs.get("json_schema")
return StubNodeAgent(node_events)
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(create_node_agent),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
)
list(engine._execute_agent_node(node))
assert captured["json_schema"] == {"type": "object"}
assert engine.state["node_agent_wrapped_output"] == {"summary": "ok"}
def test_execute_agent_node_falls_back_to_text_when_schema_not_configured(monkeypatch):
engine = create_engine()
node = create_agent_node(node_id="agent_2", output_variable="result")
node_events = [{"answer": "plain text answer"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
list(engine._execute_agent_node(node))
assert engine.state["node_agent_2_output"] == "plain text answer"
assert engine.state["result"] == "plain text answer"
def test_validate_workflow_structure_rejects_invalid_agent_json_schema():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {"json_schema": "invalid"},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert any(
"Agent node 'Agent' JSON schema must be a valid JSON object" in err
for err in errors
)
def test_validate_workflow_structure_accepts_valid_agent_json_schema():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {"json_schema": {"type": "object"}},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert errors == []
def test_validate_workflow_structure_accepts_wrapped_agent_json_schema():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {"json_schema": {"schema": {"type": "object"}}},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert errors == []
def test_validate_workflow_structure_accepts_output_variable_and_schema_together():
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {
"output_variable": "answer",
"json_schema": {"type": "object"},
},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert errors == []
def test_validate_workflow_structure_rejects_unsupported_structured_output_model(
monkeypatch,
):
monkeypatch.setattr(
workflow_routes,
"get_model_capabilities",
lambda _model_id: {"supports_structured_output": False},
)
nodes = [
{"id": "start", "type": "start", "title": "Start", "data": {}},
{
"id": "agent",
"type": "agent",
"title": "Agent",
"data": {
"model_id": "some-model",
"json_schema": {"type": "object"},
},
},
{"id": "end", "type": "end", "title": "End", "data": {}},
]
edges = [
{"id": "edge_1", "source": "start", "target": "agent"},
{"id": "edge_2", "source": "agent", "target": "end"},
]
errors = validate_workflow_structure(nodes, edges)
assert any(
"Agent node 'Agent' selected model does not support structured output"
in err
for err in errors
)
def test_execute_agent_node_raises_when_structured_output_violates_schema(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_3",
json_schema={
"type": "object",
"properties": {"summary": {"type": "string"}},
"required": ["summary"],
"additionalProperties": False,
},
)
node_events = [{"answer": '{"score":2}', "structured": True}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
)
with pytest.raises(ValueError, match="Structured output did not match schema"):
list(engine._execute_agent_node(node))
def test_execute_agent_node_raises_when_schema_set_and_response_not_json(monkeypatch):
engine = create_engine()
node = create_agent_node(
node_id="agent_4",
json_schema={"type": "object"},
)
node_events = [{"answer": "not-json"}]
monkeypatch.setattr(
WorkflowNodeAgentFactory,
"create",
staticmethod(lambda **kwargs: StubNodeAgent(node_events)),
)
monkeypatch.setattr(
"application.core.model_utils.get_api_key_for_provider",
lambda _provider: None,
)
monkeypatch.setattr(
"application.core.model_utils.get_model_capabilities",
lambda _model_id: {"supports_structured_output": True},
)
with pytest.raises(
ValueError,
match="Structured output was expected but response was not valid JSON",
):
list(engine._execute_agent_node(node))
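Note: the tests above pin down two failure modes for agent nodes with a json_schema: output that is not valid JSON, and output that violates the schema. A minimal sketch of that validation step with the jsonschema package, assuming the concatenated "structured" answer chunks form a single JSON document; the function name is illustrative, not the engine's actual method:
import json
from jsonschema import ValidationError, validate
def parse_structured_output(raw_answer: str, json_schema: dict) -> dict:
    # Mirrors the error messages asserted in the tests above.
    try:
        parsed = json.loads(raw_answer)
    except json.JSONDecodeError as exc:
        raise ValueError(
            "Structured output was expected but response was not valid JSON"
        ) from exc
    try:
        validate(instance=parsed, schema=json_schema)
    except ValidationError as exc:
        raise ValueError("Structured output did not match schema") from exc
    return parsed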


@@ -0,0 +1,63 @@
from types import SimpleNamespace
from application.agents.workflows.schemas import Workflow, WorkflowGraph
from application.agents.workflows.workflow_engine import WorkflowEngine
def create_engine() -> WorkflowEngine:
graph = WorkflowGraph(workflow=Workflow(name="Template Test"), nodes=[], edges=[])
agent = SimpleNamespace(
user="user-1",
request_id="req-1",
retrieved_docs=[
{"title": "Doc A", "text": "Summary A"},
{"title": "Doc B", "text": "Summary B"},
],
)
return WorkflowEngine(graph, agent)
def test_workflow_template_supports_agent_namespace_and_legacy_variables():
engine = create_engine()
engine.state = {"query": "Hello", "chat_history": "[]", "ticket_id": 42}
rendered = engine._format_template(
"{{ agent.query }}|{{ agent.ticket_id }}|{{ query }}|{{ ticket_id }}"
)
assert rendered == "Hello|42|Hello|42"
def test_workflow_template_supports_global_namespaces():
engine = create_engine()
engine.state = {"query": "Hello"}
rendered = engine._format_template(
"{{ source.count }}|{{ source.summaries }}|{{ system.request_id }}"
)
assert rendered.startswith("2|")
assert "Doc A" in rendered
assert "Summary A" in rendered
assert rendered.endswith("|req-1")
def test_workflow_template_handles_namespace_conflicts_with_agent_prefix():
engine = create_engine()
engine.state = {"source": "user-defined-source"}
rendered = engine._format_template(
"{{ agent.source }}|{{ agent_source }}|{{ source.count }}"
)
assert rendered.startswith("user-defined-source|user-defined-source|")
def test_workflow_template_gracefully_handles_invalid_template_syntax():
engine = create_engine()
engine.state = {"query": "Hello"}
invalid_template = "{{ agent.query "
rendered = engine._format_template(invalid_template)
assert rendered == invalid_template
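Note: the template tests expect both legacy top-level variables ({{ query }}) and namespaced access ({{ agent.query }}, {{ source.count }}, {{ system.request_id }}), with invalid syntax returned unchanged. A minimal Jinja2 sketch of that rendering behaviour; the context construction and function name are illustrative, not the engine's implementation:
from jinja2 import Environment, TemplateSyntaxError
def render_workflow_template(template: str, state: dict, namespaces: dict) -> str:
    # Expose state keys directly, the same keys under "agent", plus extra
    # namespaces such as "source" and "system"; fall back to the raw text
    # when the template cannot be parsed.
    context = {**namespaces, **state, "agent": dict(state)}
    try:
        return Environment().from_string(template).render(**context)
    except TemplateSyntaxError:
        return template
state = {"query": "Hello", "ticket_id": 42}
namespaces = {"source": {"count": 2}, "system": {"request_id": "req-1"}}
print(render_workflow_template("{{ agent.query }}|{{ ticket_id }}", state, namespaces))
# -> Hello|42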


@@ -199,6 +199,7 @@ class TestStreamProcessorAgentConfiguration:
try:
processor._configure_agent()
assert processor.agent_config is not None
assert processor.agent_id == str(agent_id)
except Exception as e:
assert "Invalid API Key" in str(e)
@@ -211,6 +212,7 @@ class TestStreamProcessorAgentConfiguration:
processor._configure_agent()
assert isinstance(processor.agent_config, dict)
assert processor.agent_id is None
@pytest.mark.unit

tests/test_usage.py

@@ -0,0 +1,326 @@
import sys
import pytest
from application.usage import (
_count_tokens,
gen_token_usage,
stream_token_usage,
update_token_usage,
)
@pytest.mark.unit
def test_count_tokens_includes_tool_call_payloads():
payload = [
{
"function_call": {
"name": "search_docs",
"args": {"query": "pricing limits"},
"call_id": "call_1",
}
},
{
"function_response": {
"name": "search_docs",
"response": {"result": "Found 3 docs"},
"call_id": "call_1",
}
},
]
assert _count_tokens(payload) > 0
@pytest.mark.unit
def test_gen_token_usage_counts_structured_tool_content(monkeypatch):
captured = {}
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
captured["decoded_token"] = decoded_token
captured["user_api_key"] = user_api_key
captured["token_usage"] = token_usage.copy()
captured["agent_id"] = agent_id
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@gen_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
return {
"tool_calls": [
{"name": "read_webpage", "arguments": {"url": "https://example.com"}}
]
}
messages = [
{
"role": "assistant",
"content": [
{
"function_call": {
"name": "search_docs",
"args": {"query": "pricing"},
"call_id": "1",
}
}
],
},
{
"role": "tool",
"content": [
{
"function_response": {
"name": "search_docs",
"response": {"result": "Found docs"},
"call_id": "1",
}
}
],
},
]
llm = DummyLLM()
wrapped(llm, "gpt-4o", messages, False, None)
assert captured["decoded_token"] == {"sub": "user_123"}
assert captured["user_api_key"] == "api_key_123"
assert captured["agent_id"] == "agent_123"
assert captured["token_usage"]["prompt_tokens"] > 0
assert captured["token_usage"]["generated_tokens"] > 0
@pytest.mark.unit
def test_stream_token_usage_counts_tool_call_chunks(monkeypatch):
captured = {}
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
captured["token_usage"] = token_usage.copy()
captured["agent_id"] = agent_id
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class ToolChunk:
def model_dump(self):
return {
"delta": {
"tool_calls": [
{
"id": "call_1",
"function": {
"name": "get_weather",
"arguments": '{"location":"Seattle"}',
},
}
]
}
}
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@stream_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
yield ToolChunk()
yield "done"
messages = [
{
"role": "assistant",
"content": [
{
"function_call": {
"name": "get_weather",
"args": {"location": "Seattle"},
"call_id": "1",
}
}
],
}
]
llm = DummyLLM()
list(wrapped(llm, "gpt-4o", messages, True, None))
assert captured["agent_id"] == "agent_123"
assert captured["token_usage"]["prompt_tokens"] > 0
assert captured["token_usage"]["generated_tokens"] > 0
@pytest.mark.unit
def test_gen_token_usage_counts_tools_and_image_inputs(monkeypatch):
captured = []
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
_ = (decoded_token, user_api_key, agent_id)
captured.append(token_usage.copy())
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@gen_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
return "ok"
messages = [{"role": "user", "content": "What is in this image?"}]
tools_payload = [
{
"type": "function",
"function": {
"name": "describe_image",
"description": "Describe image content",
"parameters": {
"type": "object",
"properties": {"detail": {"type": "string"}},
},
},
}
]
usage_attachments = [
{
"mime_type": "image/png",
"path": "attachments/example.png",
"data": "abc123",
}
]
llm = DummyLLM()
wrapped(llm, "gpt-4o", messages, False, None)
wrapped(
llm,
"gpt-4o",
messages,
False,
tools_payload,
_usage_attachments=usage_attachments,
)
assert len(captured) == 2
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
@pytest.mark.unit
def test_stream_token_usage_counts_tools_and_image_inputs(monkeypatch):
captured = []
def fake_update(decoded_token, user_api_key, token_usage, agent_id=None):
_ = (decoded_token, user_api_key, agent_id)
captured.append(token_usage.copy())
monkeypatch.setattr("application.usage.update_token_usage", fake_update)
class DummyLLM:
decoded_token = {"sub": "user_123"}
user_api_key = "api_key_123"
agent_id = "agent_123"
token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
@stream_token_usage
def wrapped(self, model, messages, stream, tools, **kwargs):
_ = (model, messages, stream, tools, kwargs)
yield "ok"
messages = [{"role": "user", "content": "What is in this image?"}]
tools_payload = [
{
"type": "function",
"function": {
"name": "describe_image",
"description": "Describe image content",
"parameters": {
"type": "object",
"properties": {"detail": {"type": "string"}},
},
},
}
]
usage_attachments = [
{
"mime_type": "image/png",
"path": "attachments/example.png",
"data": "abc123",
}
]
llm = DummyLLM()
list(wrapped(llm, "gpt-4o", messages, True, None))
list(
wrapped(
llm,
"gpt-4o",
messages,
True,
tools_payload,
_usage_attachments=usage_attachments,
)
)
assert len(captured) == 2
assert captured[1]["prompt_tokens"] > captured[0]["prompt_tokens"]
@pytest.mark.unit
def test_update_token_usage_inserts_with_agent_id_only(monkeypatch):
inserted_docs = []
class FakeCollection:
def insert_one(self, doc):
inserted_docs.append(doc)
modules_without_pytest = dict(sys.modules)
modules_without_pytest.pop("pytest", None)
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
update_token_usage(
decoded_token=None,
user_api_key=None,
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
agent_id="agent_123",
)
assert len(inserted_docs) == 1
assert inserted_docs[0]["agent_id"] == "agent_123"
assert inserted_docs[0]["user_id"] is None
assert inserted_docs[0]["api_key"] is None
@pytest.mark.unit
def test_update_token_usage_skips_when_all_ids_missing(monkeypatch):
inserted_docs = []
class FakeCollection:
def insert_one(self, doc):
inserted_docs.append(doc)
modules_without_pytest = dict(sys.modules)
modules_without_pytest.pop("pytest", None)
monkeypatch.setattr("application.usage.sys.modules", modules_without_pytest)
monkeypatch.setattr("application.usage.usage_collection", FakeCollection())
update_token_usage(
decoded_token=None,
user_api_key=None,
token_usage={"prompt_tokens": 10, "generated_tokens": 5},
agent_id=None,
)
assert inserted_docs == []
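Note: the usage tests assert that _count_tokens also accounts for structured tool-call payloads (function_call / function_response entries) rather than only plain strings. A minimal sketch of such a counter with tiktoken, assuming non-string content is serialised to JSON before encoding; this illustrates the idea rather than the project's exact implementation:
import json
import tiktoken
def count_tokens(content, encoding_name: str = "cl100k_base") -> int:
    # Serialise structured payloads (tool calls, tool responses) to JSON text
    # so their names, arguments and results contribute to the token count.
    encoding = tiktoken.get_encoding(encoding_name)
    text = content if isinstance(content, str) else json.dumps(content, default=str)
    return len(encoding.encode(text))
payload = [{"function_call": {"name": "search_docs", "args": {"query": "pricing limits"}}}]
assert count_tokens(payload) > 0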