Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-10 12:31:21 +00:00)

Compare commits: dependabot...sharepoint (16 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 8b19141d02 |  |
|  | a680c07b2c |  |
|  | f7413f96dd |  |
|  | 4b95437d78 |  |
|  | 99e18a7d1b |  |
|  | e87da7e5a9 |  |
|  | 11e59540fb |  |
|  | 6257ca7935 |  |
|  | 96f8c1785d |  |
|  | ba9afd6033 |  |
|  | f744537fdd |  |
|  | 4d01a7df97 |  |
|  | cc06b83d32 |  |
|  | d9bc248522 |  |
|  | 80d8363541 |  |
|  | 1e0eaefba2 |  |
@@ -12,4 +12,17 @@ EMBEDDINGS_KEY=
 OPENAI_API_BASE=
 OPENAI_API_VERSION=
 AZURE_DEPLOYMENT_NAME=
 AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
+
+#Azure AD Application (client) ID
+MICROSOFT_CLIENT_ID=your-azure-ad-client-id
+#Azure AD Application client secret
+MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
+#Azure AD Tenant ID (or 'common' for multi-tenant)
+MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
+#If you are using a Microsoft Entra ID tenant,
+#configure the AUTHORITY variable as
+#"https://login.microsoftonline.com/TENANT_GUID"
+#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
+#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
+MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
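The MICROSOFT_* values above configure Microsoft Entra ID (Azure AD) authentication for the SharePoint connector. A minimal sketch of how such settings are typically consumed with the msal library; this is not DocsGPT code, and the Graph scope and error handling are assumptions:

```python
# Hedged sketch: client-credentials flow with msal, driven by the env vars above.
import os

import msal

# Fall back to the tenant-based authority when MICROSOFT_AUTHORITY is unset.
authority = os.environ.get(
    "MICROSOFT_AUTHORITY",
    f"https://login.microsoftonline.com/{os.environ.get('MICROSOFT_TENANT_ID', 'common')}",
)
app = msal.ConfidentialClientApplication(
    client_id=os.environ["MICROSOFT_CLIENT_ID"],
    client_credential=os.environ["MICROSOFT_CLIENT_SECRET"],
    authority=authority,
)
# The Microsoft Graph ".default" scope is an assumption for illustration.
result = app.acquire_token_for_client(scopes=["https://graph.microsoft.com/.default"])
if "access_token" not in result:
    raise RuntimeError(result.get("error_description", "token request failed"))
```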
.github/workflows/ci.yml (vendored, 8 changes)

@@ -34,13 +34,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}

@@ -75,13 +75,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}
.github/workflows/cife.yml (vendored, 8 changes)

@@ -34,13 +34,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}

@@ -75,13 +75,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}
.github/workflows/docker-develop-build.yml (vendored, 8 changes)

@@ -32,13 +32,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}

@@ -73,13 +73,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}
(additional workflow file; filename not captured in the mirror)

@@ -36,13 +36,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}

@@ -77,13 +77,13 @@ jobs:
           install: true

       - name: Login to DockerHub
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_PASSWORD }}

       - name: Login to ghcr.io
-        uses: docker/login-action@v4
+        uses: docker/login-action@v3
         with:
           registry: ghcr.io
           username: ${{ github.repository_owner }}
.gitignore (vendored, 1 change)

@@ -71,7 +71,6 @@ instance/

 # Sphinx documentation
 docs/_build/
-docs/public/_pagefind/

 # PyBuilder
 target/
@@ -7,10 +7,6 @@ from bson.objectid import ObjectId

 from application.agents.tools.tool_action_parser import ToolActionParser
 from application.agents.tools.tool_manager import ToolManager
-from application.core.json_schema_utils import (
-    JsonSchemaValidationError,
-    normalize_json_schema_payload,
-)
 from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.llm.handlers.handler_creator import LLMHandlerCreator

@@ -27,7 +23,6 @@ class BaseAgent(ABC):
         llm_name: str,
         model_id: str,
         api_key: str,
-        agent_id: Optional[str] = None,
         user_api_key: Optional[str] = None,
         prompt: str = "",
         chat_history: Optional[List[Dict]] = None,

@@ -45,7 +40,6 @@ class BaseAgent(ABC):
         self.llm_name = llm_name
         self.model_id = model_id
         self.api_key = api_key
-        self.agent_id = agent_id
         self.user_api_key = user_api_key
         self.prompt = prompt
         self.decoded_token = decoded_token or {}

@@ -60,19 +54,13 @@ class BaseAgent(ABC):
             user_api_key=user_api_key,
             decoded_token=decoded_token,
             model_id=model_id,
-            agent_id=agent_id,
         )
         self.retrieved_docs = retrieved_docs or []
         self.llm_handler = LLMHandlerCreator.create_handler(
             llm_name if llm_name else "default"
         )
         self.attachments = attachments or []
-        self.json_schema = None
-        if json_schema is not None:
-            try:
-                self.json_schema = normalize_json_schema_payload(json_schema)
-            except JsonSchemaValidationError as exc:
-                logger.warning("Ignoring invalid JSON schema payload: %s", exc)
+        self.json_schema = json_schema
         self.limited_token_mode = limited_token_mode
         self.token_limit = token_limit
         self.limited_request_mode = limited_request_mode

@@ -275,11 +263,6 @@ class BaseAgent(ABC):
             tool_config=tool_config,
             user_id=self.user,
         )
-        resolved_arguments = (
-            {"query_params": query_params, "headers": headers, "body": body}
-            if tool_data["name"] == "api_tool"
-            else parameters
-        )
         if tool_data["name"] == "api_tool":
             logger.debug(
                 f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"

@@ -309,19 +292,11 @@ class BaseAgent(ABC):
         artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
         if artifact_id:
             tool_call_data["artifact_id"] = artifact_id
-        result_full = str(result)
-        tool_call_data["resolved_arguments"] = resolved_arguments
-        tool_call_data["result_full"] = result_full
         tool_call_data["result"] = (
-            f"{result_full[:50]}..." if len(result_full) > 50 else result_full
+            f"{str(result)[:50]}..." if len(str(result)) > 50 else result
         )

-        stream_tool_call_data = {
-            key: value
-            for key, value in tool_call_data.items()
-            if key not in {"result_full", "resolved_arguments"}
-        }
-        yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
+        yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
         self.tool_calls.append(tool_call_data)

         return result, call_id

@@ -329,11 +304,7 @@ class BaseAgent(ABC):
     def _get_truncated_tool_calls(self):
         return [
             {
-                "tool_name": tool_call.get("tool_name"),
-                "call_id": tool_call.get("call_id"),
-                "action_name": tool_call.get("action_name"),
-                "arguments": tool_call.get("arguments"),
-                "artifact_id": tool_call.get("artifact_id"),
+                **tool_call,
                 "result": (
                     f"{str(tool_call['result'])[:50]}..."
                     if len(str(tool_call["result"])) > 50

@@ -605,9 +576,6 @@ class BaseAgent(ABC):
         self._validate_context_size(messages)

         gen_kwargs = {"model": self.model_id, "messages": messages}
-        if self.attachments:
-            # Usage accounting only; stripped before provider invocation.
-            gen_kwargs["_usage_attachments"] = self.attachments

         if (
             hasattr(self.llm, "_supports_tools")
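The `_get_truncated_tool_calls` change above replaces an explicit key list with a `**tool_call` spread followed by a `result` override. A small self-contained sketch of that dict-merge-then-override behavior (sample data invented):

```python
# Illustrative only: spread copies every key, then "result" is overridden
# with a 50-character preview, mirroring the truncation shown in the diff.
tool_call = {"tool_name": "api_tool", "call_id": "1", "result": "x" * 120}
truncated = {
    **tool_call,  # copy every key as-is
    "result": (
        f"{str(tool_call['result'])[:50]}..."
        if len(str(tool_call["result"])) > 50
        else tool_call["result"]
    ),
}
assert truncated["tool_name"] == "api_tool"
assert len(truncated["result"]) == 53  # 50 chars plus "..."
```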
@@ -235,4 +235,4 @@ class ReActAgent(BaseAgent):
             )
         except Exception as e:
             logger.error(f"Error extracting content: {e}")
-        return "".join(collected)
+        return "".join(collected)
@@ -211,21 +211,8 @@ class WorkflowAgent(BaseAgent):
     def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
         serialized: Dict[str, Any] = {}
         for key, value in state.items():
-            serialized[key] = self._serialize_state_value(value)
+            if isinstance(value, (str, int, float, bool, type(None))):
+                serialized[key] = value
+            else:
+                serialized[key] = str(value)
         return serialized
-
-    def _serialize_state_value(self, value: Any) -> Any:
-        if isinstance(value, dict):
-            return {
-                str(dict_key): self._serialize_state_value(dict_value)
-                for dict_key, dict_value in value.items()
-            }
-        if isinstance(value, list):
-            return [self._serialize_state_value(item) for item in value]
-        if isinstance(value, tuple):
-            return [self._serialize_state_value(item) for item in value]
-        if isinstance(value, datetime):
-            return value.isoformat()
-        if isinstance(value, (str, int, float, bool, type(None))):
-            return value
-        return str(value)
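The head-side `_serialize_state` collapses the recursive `_serialize_state_value` helper into a single rule: primitives pass through, everything else is stringified. A behavior sketch with an invented state dict:

```python
# Behavior sketch of the simplified serialization rule shown above.
state = {"count": 3, "ok": True, "items": [1, 2], "note": None}
serialized = {
    k: v if isinstance(v, (str, int, float, bool, type(None))) else str(v)
    for k, v in state.items()
}
# Lists (and dicts, tuples, datetimes) become their str() form instead of
# being recursively converted as in the removed helper.
assert serialized == {"count": 3, "ok": True, "items": "[1, 2]", "note": None}
```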
@@ -1,64 +0,0 @@
-from typing import Any, Dict
-
-import celpy
-import celpy.celtypes
-
-
-class CelEvaluationError(Exception):
-    pass
-
-
-def _convert_value(value: Any) -> Any:
-    if isinstance(value, bool):
-        return celpy.celtypes.BoolType(value)
-    if isinstance(value, int):
-        return celpy.celtypes.IntType(value)
-    if isinstance(value, float):
-        return celpy.celtypes.DoubleType(value)
-    if isinstance(value, str):
-        return celpy.celtypes.StringType(value)
-    if isinstance(value, list):
-        return celpy.celtypes.ListType([_convert_value(item) for item in value])
-    if isinstance(value, dict):
-        return celpy.celtypes.MapType(
-            {celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
-        )
-    if value is None:
-        return celpy.celtypes.BoolType(False)
-    return celpy.celtypes.StringType(str(value))
-
-
-def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
-    return {k: _convert_value(v) for k, v in state.items()}
-
-
-def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
-    if not expression or not expression.strip():
-        raise CelEvaluationError("Empty expression")
-    try:
-        env = celpy.Environment()
-        ast = env.compile(expression)
-        program = env.program(ast)
-        activation = build_activation(state)
-        result = program.evaluate(activation)
-    except celpy.CELEvalError as exc:
-        raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
-    except Exception as exc:
-        raise CelEvaluationError(f"CEL error: {exc}") from exc
-    return cel_to_python(result)
-
-
-def cel_to_python(value: Any) -> Any:
-    if isinstance(value, celpy.celtypes.BoolType):
-        return bool(value)
-    if isinstance(value, celpy.celtypes.IntType):
-        return int(value)
-    if isinstance(value, celpy.celtypes.DoubleType):
-        return float(value)
-    if isinstance(value, celpy.celtypes.StringType):
-        return str(value)
-    if isinstance(value, celpy.celtypes.ListType):
-        return [cel_to_python(item) for item in value]
-    if isinstance(value, celpy.celtypes.MapType):
-        return {str(k): cel_to_python(v) for k, v in value.items()}
-    return value
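For reference, a hypothetical exercise of the CEL machinery this removed module wraps, assuming the cel-python (celpy) package is installed; it uses only the same celpy calls the file itself uses:

```python
# Hedged sketch: evaluating a CEL expression against converted state values,
# as evaluate_cel does above. The expression and data are invented.
import celpy

env = celpy.Environment()
program = env.program(env.compile("score > 10 && size(tags) == 2"))
activation = {
    "score": celpy.celtypes.IntType(12),
    "tags": celpy.celtypes.ListType(
        [celpy.celtypes.StringType("a"), celpy.celtypes.StringType("b")]
    ),
}
result = program.evaluate(activation)  # a celpy BoolType
assert bool(result) is True
```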
@@ -1,6 +1,6 @@
 from datetime import datetime, timezone
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 from bson import ObjectId
 from pydantic import BaseModel, ConfigDict, Field, field_validator

@@ -12,7 +12,6 @@ class NodeType(str, Enum):
     AGENT = "agent"
     NOTE = "note"
     STATE = "state"
-    CONDITION = "condition"


 class AgentType(str, Enum):

@@ -49,25 +48,6 @@ class AgentNodeConfig(BaseModel):
     json_schema: Optional[Dict[str, Any]] = None
-
-
-class ConditionCase(BaseModel):
-    model_config = ConfigDict(extra="forbid", populate_by_name=True)
-    name: Optional[str] = None
-    expression: str = ""
-    source_handle: str = Field(..., alias="sourceHandle")
-
-
-class ConditionNodeConfig(BaseModel):
-    model_config = ConfigDict(extra="allow")
-    mode: Literal["simple", "advanced"] = "simple"
-    cases: List[ConditionCase] = Field(default_factory=list)
-
-
-class StateOperation(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    expression: str = ""
-    target_variable: str = ""


 class WorkflowEdgeCreate(BaseModel):
     model_config = ConfigDict(populate_by_name=True)
     id: str
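The removed `ConditionCase` model combines `populate_by_name=True` with a `sourceHandle` alias, so payloads validate under either spelling. A standalone restatement (requires pydantic v2; the payload is invented):

```python
# Standalone restatement of the removed model, to show the alias behavior.
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class ConditionCase(BaseModel):
    model_config = ConfigDict(extra="forbid", populate_by_name=True)
    name: Optional[str] = None
    expression: str = ""
    source_handle: str = Field(..., alias="sourceHandle")


# Frontend-style camelCase payload validates via the alias...
case = ConditionCase(**{"expression": "score > 10", "sourceHandle": "case-1"})
assert case.source_handle == "case-1"
# ...and populate_by_name=True also accepts the Python field name directly.
case = ConditionCase(expression="score > 10", source_handle="case-1")
```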
@@ -1,30 +1,16 @@
-import json
 import logging
 from datetime import datetime, timezone
 from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING

-from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
 from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
 from application.agents.workflows.schemas import (
     AgentNodeConfig,
-    ConditionNodeConfig,
     ExecutionStatus,
     NodeExecutionLog,
     NodeType,
     WorkflowGraph,
     WorkflowNode,
 )
-from application.core.json_schema_utils import (
-    JsonSchemaValidationError,
-    normalize_json_schema_payload,
-)
-from application.templates.namespaces import NamespaceManager
-from application.templates.template_engine import TemplateEngine, TemplateRenderError
-
-try:
-    import jsonschema
-except ImportError:  # pragma: no cover - optional dependency in some deployments.
-    jsonschema = None

 if TYPE_CHECKING:
     from application.agents.base import BaseAgent
@@ -32,7 +18,6 @@ logger = logging.getLogger(__name__)

 StateValue = Any
 WorkflowState = Dict[str, StateValue]
-TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}


 class WorkflowEngine:
@@ -43,9 +28,6 @@ class WorkflowEngine:
         self.agent = agent
         self.state: WorkflowState = {}
         self.execution_log: List[Dict[str, Any]] = []
-        self._condition_result: Optional[str] = None
-        self._template_engine = TemplateEngine()
-        self._namespace_manager = NamespaceManager()

     def execute(
         self, initial_inputs: WorkflowState, query: str
@@ -116,10 +98,6 @@ class WorkflowEngine:
             if node.type == NodeType.END:
                 break
             current_node_id = self._get_next_node_id(current_node_id)
-            if current_node_id is None and node.type != NodeType.END:
-                logger.warning(
-                    f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
-                )
             steps += 1
             if steps >= self.MAX_EXECUTION_STEPS:
                 logger.warning(
@@ -143,20 +121,10 @@ class WorkflowEngine:
         }

     def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
-        node = self.graph.get_node_by_id(current_node_id)
         edges = self.graph.get_outgoing_edges(current_node_id)
-        if not edges:
-            return None
-
-        if node and node.type == NodeType.CONDITION and self._condition_result:
-            target_handle = self._condition_result
-            self._condition_result = None
-            for edge in edges:
-                if edge.source_handle == target_handle:
-                    return edge.target_id
-            return None
-
-        return edges[0].target_id
+        if edges:
+            return edges[0].target_id
+        return None

     def _execute_node(
         self, node: WorkflowNode
@@ -168,7 +136,6 @@ class WorkflowEngine:
             NodeType.NOTE: self._execute_note_node,
             NodeType.AGENT: self._execute_agent_node,
             NodeType.STATE: self._execute_state_node,
-            NodeType.CONDITION: self._execute_condition_node,
             NodeType.END: self._execute_end_node,
         }
@@ -189,62 +156,35 @@ class WorkflowEngine:
     def _execute_agent_node(
         self, node: WorkflowNode
     ) -> Generator[Dict[str, str], None, None]:
-        from application.core.model_utils import (
-            get_api_key_for_provider,
-            get_model_capabilities,
-            get_provider_from_model_id,
-        )
+        from application.core.model_utils import get_api_key_for_provider

-        node_config = AgentNodeConfig(**node.config.get("config", node.config))
+        node_config = AgentNodeConfig(**node.config)

         if node_config.prompt_template:
             formatted_prompt = self._format_template(node_config.prompt_template)
         else:
             formatted_prompt = self.state.get("query", "")
-        node_json_schema = self._normalize_node_json_schema(
-            node_config.json_schema, node.title
-        )
-        node_model_id = node_config.model_id or self.agent.model_id
-        node_llm_name = (
-            node_config.llm_name
-            or get_provider_from_model_id(node_model_id or "")
-            or self.agent.llm_name
-        )
+        node_llm_name = node_config.llm_name or self.agent.llm_name
         node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key

-        if node_json_schema and node_model_id:
-            model_capabilities = get_model_capabilities(node_model_id)
-            if model_capabilities and not model_capabilities.get(
-                "supports_structured_output", False
-            ):
-                raise ValueError(
-                    f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
-                )
-
         node_agent = WorkflowNodeAgentFactory.create(
             agent_type=node_config.agent_type,
             endpoint=self.agent.endpoint,
             llm_name=node_llm_name,
-            model_id=node_model_id,
+            model_id=node_config.model_id or self.agent.model_id,
             api_key=node_api_key,
             tool_ids=node_config.tools,
             prompt=node_config.system_prompt,
             chat_history=self.agent.chat_history,
             decoded_token=self.agent.decoded_token,
-            json_schema=node_json_schema,
+            json_schema=node_config.json_schema,
         )

-        full_response_parts: List[str] = []
-        structured_response_parts: List[str] = []
-        has_structured_response = False
+        full_response = ""
         first_chunk = True
         for event in node_agent.gen(formatted_prompt):
             if "answer" in event:
-                chunk = str(event["answer"])
-                full_response_parts.append(chunk)
-                if event.get("structured"):
-                    has_structured_response = True
-                    structured_response_parts.append(chunk)
+                full_response += event["answer"]
                 if node_config.stream_to_user:
                     if first_chunk and hasattr(self, "_has_streamed"):
                         yield {"answer": "\n\n"}
@@ -254,189 +194,72 @@ class WorkflowEngine:
         if node_config.stream_to_user:
             self._has_streamed = True

-        full_response = "".join(full_response_parts).strip()
-        output_value: Any = full_response
-        if has_structured_response:
-            structured_response = "".join(structured_response_parts).strip()
-            response_to_parse = structured_response or full_response
-            parsed_success, parsed_structured = self._parse_structured_output(
-                response_to_parse
-            )
-            output_value = parsed_structured if parsed_success else response_to_parse
-            if node_json_schema:
-                self._validate_structured_output(node_json_schema, output_value)
-        elif node_json_schema:
-            parsed_success, parsed_structured = self._parse_structured_output(
-                full_response
-            )
-            if not parsed_success:
-                raise ValueError(
-                    "Structured output was expected but response was not valid JSON"
-                )
-            output_value = parsed_structured
-            self._validate_structured_output(node_json_schema, output_value)
-
-        default_output_key = f"node_{node.id}_output"
-        self.state[default_output_key] = output_value
-
-        if node_config.output_variable:
-            self.state[node_config.output_variable] = output_value
+        output_key = node_config.output_variable or f"node_{node.id}_output"
+        self.state[output_key] = full_response

     def _execute_state_node(
         self, node: WorkflowNode
     ) -> Generator[Dict[str, str], None, None]:
-        config = node.config.get("config", node.config)
-        for op in config.get("operations", []):
-            expression = op.get("expression", "")
-            target_variable = op.get("target_variable", "")
-            if expression and target_variable:
-                self.state[target_variable] = evaluate_cel(expression, self.state)
-        yield from ()
-
-    def _execute_condition_node(
-        self, node: WorkflowNode
-    ) -> Generator[Dict[str, str], None, None]:
-        config = ConditionNodeConfig(**node.config.get("config", node.config))
-        matched_handle = None
-
-        for case in config.cases:
-            if not case.expression.strip():
-                continue
-            try:
-                if evaluate_cel(case.expression, self.state):
-                    matched_handle = case.source_handle
-                    break
-            except CelEvaluationError:
-                continue
-
-        self._condition_result = matched_handle or "else"
+        config = node.config
+        operations = config.get("operations", [])
+
+        if operations:
+            for op in operations:
+                key = op.get("key")
+                operation = op.get("operation", "set")
+                value = op.get("value")
+
+                if not key:
+                    continue
+                if operation == "set":
+                    formatted_value = (
+                        self._format_template(str(value))
+                        if isinstance(value, str)
+                        else value
+                    )
+                    self.state[key] = formatted_value
+                elif operation == "increment":
+                    current = self.state.get(key, 0)
+                    try:
+                        self.state[key] = int(current) + int(value or 1)
+                    except (ValueError, TypeError):
+                        self.state[key] = 1
+                elif operation == "append":
+                    if key not in self.state:
+                        self.state[key] = []
+                    if isinstance(self.state[key], list):
+                        self.state[key].append(value)
+        else:
+            updates = config.get("updates", {})
+            if not updates:
+                var_name = config.get("variable")
+                var_value = config.get("value")
+                if var_name and isinstance(var_name, str):
+                    updates = {var_name: var_value or ""}
+            if isinstance(updates, dict):
+                for key, value in updates.items():
+                    if isinstance(value, str):
+                        self.state[key] = self._format_template(value)
+                    else:
+                        self.state[key] = value
         yield from ()

     def _execute_end_node(
         self, node: WorkflowNode
     ) -> Generator[Dict[str, str], None, None]:
-        config = node.config.get("config", node.config)
+        config = node.config
         output_template = str(config.get("output_template", ""))
         if output_template:
             formatted_output = self._format_template(output_template)
             yield {"answer": formatted_output}

-    def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
-        normalized_response = raw_response.strip()
-        if not normalized_response:
-            return False, None
-
-        try:
-            return True, json.loads(normalized_response)
-        except json.JSONDecodeError:
-            logger.warning(
-                "Workflow agent returned structured output that was not valid JSON"
-            )
-            return False, None
-
-    def _normalize_node_json_schema(
-        self, schema: Optional[Dict[str, Any]], node_title: str
-    ) -> Optional[Dict[str, Any]]:
-        if schema is None:
-            return None
-        try:
-            return normalize_json_schema_payload(schema)
-        except JsonSchemaValidationError as exc:
-            raise ValueError(
-                f'Invalid JSON schema for node "{node_title}": {exc}'
-            ) from exc
-
-    def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
-        if jsonschema is None:
-            logger.warning(
-                "jsonschema package is not available, skipping structured output validation"
-            )
-            return
-
-        try:
-            normalized_schema = normalize_json_schema_payload(schema)
-        except JsonSchemaValidationError as exc:
-            raise ValueError(f"Invalid JSON schema: {exc}") from exc
-
-        try:
-            jsonschema.validate(instance=output_value, schema=normalized_schema)
-        except jsonschema.exceptions.ValidationError as exc:
-            raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
-        except jsonschema.exceptions.SchemaError as exc:
-            raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
-
     def _format_template(self, template: str) -> str:
-        context = self._build_template_context()
-        try:
-            return self._template_engine.render(template, context)
-        except TemplateRenderError as e:
-            logger.warning(
-                "Workflow template rendering failed, using raw template: %s", str(e)
-            )
-            return template
-
-    def _build_template_context(self) -> Dict[str, Any]:
-        docs, docs_together = self._get_source_template_data()
-        passthrough_data = (
-            self.state.get("passthrough")
-            if isinstance(self.state.get("passthrough"), dict)
-            else None
-        )
-        tools_data = (
-            self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
-        )
-
-        context = self._namespace_manager.build_context(
-            user_id=getattr(self.agent, "user", None),
-            request_id=getattr(self.agent, "request_id", None),
-            passthrough_data=passthrough_data,
-            docs=docs,
-            docs_together=docs_together,
-            tools_data=tools_data,
-        )
-
-        agent_context: Dict[str, Any] = {}
-        for key, value in self.state.items():
-            if not isinstance(key, str):
-                continue
-            normalized_key = key.strip()
-            if not normalized_key:
-                continue
-            agent_context[normalized_key] = value
-
-        context["agent"] = agent_context
-
-        # Keep legacy top-level variables working while namespaced variables are adopted.
-        for key, value in agent_context.items():
-            if key in TEMPLATE_RESERVED_NAMESPACES:
-                context[f"agent_{key}"] = value
-                continue
-            if key not in context:
-                context[key] = value
-
-        return context
-
-    def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
-        docs = getattr(self.agent, "retrieved_docs", None)
-        if not isinstance(docs, list) or len(docs) == 0:
-            return None, None
-
-        docs_together_parts: List[str] = []
-        for doc in docs:
-            if not isinstance(doc, dict):
-                continue
-            text = doc.get("text")
-            if not isinstance(text, str):
-                continue
-
-            filename = doc.get("filename") or doc.get("title") or doc.get("source")
-            if isinstance(filename, str) and filename.strip():
-                docs_together_parts.append(f"{filename}\n{text}")
-            else:
-                docs_together_parts.append(text)
-
-        docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
-        return docs, docs_together
+        formatted = template
+        for key, value in self.state.items():
+            placeholder = f"{{{{{key}}}}}"
+            if placeholder in formatted and value is not None:
+                formatted = formatted.replace(placeholder, str(value))
+        return formatted

     def get_execution_summary(self) -> List[NodeExecutionLog]:
         return [
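The head-side `_format_template` above falls back to plain `{{variable}}` substitution from engine state instead of the namespaced template engine. A behavior sketch (state and template invented):

```python
# Behavior sketch of the simple placeholder substitution shown above.
state = {"city": "Paris", "temp": 21}
template = "Weather in {{city}}: {{temp}} C"

formatted = template
for key, value in state.items():
    placeholder = f"{{{{{key}}}}}"  # builds the literal "{{key}}"
    if placeholder in formatted and value is not None:
        formatted = formatted.replace(placeholder, str(value))

assert formatted == "Weather in Paris: 21 C"
```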
@@ -42,7 +42,6 @@ class AnswerResource(Resource, BaseAnswerResource):
         ),
         "retriever": fields.String(required=False, description="Retriever type"),
         "api_key": fields.String(required=False, description="API key"),
-        "agent_id": fields.String(required=False, description="Agent ID"),
         "active_docs": fields.String(
             required=False, description="Active documents"
         ),

@@ -101,9 +100,6 @@ class AnswerResource(Resource, BaseAnswerResource):
             isNoneDoc=data.get("isNoneDoc"),
             index=None,
             should_save_conversation=data.get("save_conversation", True),
-            agent_id=processor.agent_id,
-            is_shared_usage=processor.is_shared_usage,
-            shared_token=processor.shared_token,
             model_id=processor.model_id,
         )
         stream_result = self.process_response_stream(stream)
@@ -46,27 +46,6 @@ class BaseAnswerResource:
             return missing_fields
         return None

-    @staticmethod
-    def _prepare_tool_calls_for_logging(
-        tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
-    ) -> List[Dict[str, Any]]:
-        if not tool_calls:
-            return []
-
-        prepared = []
-        for tool_call in tool_calls:
-            if not isinstance(tool_call, dict):
-                prepared.append({"result": str(tool_call)[:max_chars]})
-                continue
-
-            item = dict(tool_call)
-            for key in ("result", "result_full"):
-                value = item.get(key)
-                if isinstance(value, str) and len(value) > max_chars:
-                    item[key] = value[:max_chars]
-            prepared.append(item)
-        return prepared
-
     def check_usage(self, agent_config: Dict) -> Optional[Response]:
         """Check if there is a usage limit and if it is exceeded

@@ -267,7 +246,6 @@ class BaseAnswerResource:
             user_api_key=user_api_key,
             decoded_token=decoded_token,
             model_id=model_id,
-            agent_id=agent_id,
         )

         if should_save_conversation:

@@ -314,20 +292,14 @@ class BaseAnswerResource:
                 data = json.dumps(id_data)
                 yield f"data: {data}\n\n"

-            tool_calls_for_logging = self._prepare_tool_calls_for_logging(
-                getattr(agent, "tool_calls", tool_calls) or tool_calls
-            )
-
             log_data = {
                 "action": "stream_answer",
                 "level": "info",
                 "user": decoded_token.get("sub"),
                 "api_key": user_api_key,
-                "agent_id": agent_id,
                 "question": question,
                 "response": response_full,
                 "sources": source_log_docs,
-                "tool_calls": tool_calls_for_logging,
                 "attachments": attachment_ids,
                 "timestamp": datetime.datetime.now(datetime.timezone.utc),
             }

@@ -358,7 +330,6 @@ class BaseAnswerResource:
             api_key=settings.API_KEY,
             user_api_key=user_api_key,
             decoded_token=decoded_token,
-            agent_id=agent_id,
         )
         self.conversation_service.save_conversation(
             conversation_id,
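The removed `_prepare_tool_calls_for_logging` clips oversized result strings before they reach the activity log. A standalone restatement with an invented payload:

```python
# Standalone restatement of the removed logging helper (sample data invented).
def prepare_tool_calls_for_logging(tool_calls, max_chars=10000):
    prepared = []
    for tool_call in tool_calls or []:
        if not isinstance(tool_call, dict):
            # Non-dict entries are coerced into a clipped result record.
            prepared.append({"result": str(tool_call)[:max_chars]})
            continue
        item = dict(tool_call)
        for key in ("result", "result_full"):
            value = item.get(key)
            if isinstance(value, str) and len(value) > max_chars:
                item[key] = value[:max_chars]
        prepared.append(item)
    return prepared


calls = [{"tool_name": "api_tool", "result": "y" * 20000}, "raw-entry"]
prepared = prepare_tool_calls_for_logging(calls)
assert len(prepared[0]["result"]) == 10000
assert prepared[1] == {"result": "raw-entry"}
```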
@@ -42,7 +42,6 @@ class StreamResource(Resource, BaseAnswerResource):
         ),
         "retriever": fields.String(required=False, description="Retriever type"),
         "api_key": fields.String(required=False, description="API key"),
-        "agent_id": fields.String(required=False, description="Agent ID"),
         "active_docs": fields.String(
             required=False, description="Active documents"
         ),

@@ -108,7 +107,7 @@ class StreamResource(Resource, BaseAnswerResource):
             index=data.get("index"),
             should_save_conversation=data.get("save_conversation", True),
             attachment_ids=data.get("attachments", []),
-            agent_id=processor.agent_id,
+            agent_id=data.get("agent_id"),
             is_shared_usage=processor.is_shared_usage,
             shared_token=processor.shared_token,
             model_id=processor.model_id,
@@ -134,7 +134,6 @@ class CompressionOrchestrator:
             user_api_key=None,
             decoded_token=decoded_token,
             model_id=compression_model,
-            agent_id=conversation.get("agent_id"),
         )

         # Create compression service with DB update capability
@@ -90,7 +90,6 @@ class StreamProcessor:
         self.retriever_config = {}
         self.is_shared_usage = False
         self.shared_token = None
-        self.agent_id = self.data.get("agent_id")
         self.model_id: Optional[str] = None
         self.conversation_service = ConversationService()
         self.compression_orchestrator = CompressionOrchestrator(

@@ -356,13 +355,10 @@ class StreamProcessor:
         self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
             agent_id, self.initial_user_id
         )
-        self.agent_id = str(agent_id) if agent_id else None

         api_key = self.data.get("api_key")
         if api_key:
             data_key = self._get_data_from_api_key(api_key)
-            if data_key.get("_id"):
-                self.agent_id = str(data_key.get("_id"))
             self.agent_config.update(
                 {
                     "prompt_id": data_key.get("prompt_id", "default"),

@@ -391,8 +387,6 @@ class StreamProcessor:
                 self.retriever_config["chunks"] = 2
         elif self.agent_key:
             data_key = self._get_data_from_api_key(self.agent_key)
-            if data_key.get("_id"):
-                self.agent_id = str(data_key.get("_id"))
             self.agent_config.update(
                 {
                     "prompt_id": data_key.get("prompt_id", "default"),

@@ -465,7 +459,6 @@ class StreamProcessor:
             doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
             model_id=self.model_id,
             user_api_key=self.agent_config["user_api_key"],
-            agent_id=self.agent_id,
             decoded_token=self.decoded_token,
         )

@@ -761,7 +754,6 @@ class StreamProcessor:
             "llm_name": provider or settings.LLM_PROVIDER,
             "model_id": self.model_id,
             "api_key": system_api_key,
-            "agent_id": self.agent_id,
             "user_api_key": self.agent_config["user_api_key"],
             "prompt": rendered_prompt,
             "chat_history": self.history,
@@ -146,20 +146,19 @@ class ConnectorsCallback(Resource):
         session_token = str(uuid.uuid4())

         try:
-            credentials = auth.create_credentials_from_token_info(token_info)
-            service = auth.build_drive_service(credentials)
-            user_info = service.about().get(fields="user").execute()
-            user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
+            if provider == "google_drive":
+                credentials = auth.create_credentials_from_token_info(token_info)
+                service = auth.build_drive_service(credentials)
+                user_info = service.about().get(fields="user").execute()
+                user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
+            else:
+                user_email = token_info.get('user_info', {}).get('email', 'Connected User')

         except Exception as e:
             current_app.logger.warning(f"Could not get user info: {e}")
             user_email = 'Connected User'

-        sanitized_token_info = {
-            "access_token": token_info.get("access_token"),
-            "refresh_token": token_info.get("refresh_token"),
-            "token_uri": token_info.get("token_uri"),
-            "expiry": token_info.get("expiry")
-        }
+        sanitized_token_info = auth.sanitize_token_info(token_info)

         sessions_collection.find_one_and_update(
             {"_id": ObjectId(state_object_id), "provider": provider},
@@ -201,12 +200,12 @@ class ConnectorsCallback(Resource):
 @connectors_ns.route("/api/connectors/files")
 class ConnectorFiles(Resource):
     @api.expect(api.model("ConnectorFilesModel", {
         "provider": fields.String(required=True),
         "session_token": fields.String(required=True),
         "folder_id": fields.String(required=False),
         "limit": fields.Integer(required=False),
         "page_token": fields.String(required=False),
-        "search_query": fields.String(required=False)
+        "search_query": fields.String(required=False),
     }))
     @api.doc(description="List files from a connector provider (supports pagination and search)")
     def post(self):

@@ -214,11 +213,8 @@ class ConnectorFiles(Resource):
         data = request.get_json()
         provider = data.get('provider')
         session_token = data.get('session_token')
-        folder_id = data.get('folder_id')
-        limit = data.get('limit', 10)
-        page_token = data.get('page_token')
-        search_query = data.get('search_query')

         if not provider or not session_token:
             return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)

@@ -231,15 +227,12 @@ class ConnectorFiles(Resource):
                 return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)

             loader = ConnectorCreator.create_connector(provider, session_token)

+            generic_keys = {'provider', 'session_token'}
             input_config = {
-                'limit': limit,
-                'list_only': True,
-                'session_token': session_token,
-                'folder_id': folder_id,
-                'page_token': page_token
+                k: v for k, v in data.items() if k not in generic_keys
             }
-            if search_query:
-                input_config['search_query'] = search_query
+            input_config['list_only'] = True

             documents = loader.load_data(input_config)
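The rewritten handler forwards every request field except the connector-agnostic ones straight into the loader config, so new providers can accept extra parameters without route changes. A sketch of that filtering (payload invented):

```python
# Sketch of the generic forwarding pattern introduced above.
data = {"provider": "share_point", "session_token": "abc",
        "folder_id": "42", "limit": 10}
generic_keys = {"provider", "session_token"}

# Everything the route does not interpret itself is passed through unchanged.
input_config = {k: v for k, v in data.items() if k not in generic_keys}
input_config["list_only"] = True

assert input_config == {"folder_id": "42", "limit": 10, "list_only": True}
```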
@@ -306,12 +299,7 @@ class ConnectorValidateSession(Resource):
             if is_expired and token_info.get('refresh_token'):
                 try:
                     refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
-                    sanitized_token_info = {
-                        "access_token": refreshed_token_info.get("access_token"),
-                        "refresh_token": refreshed_token_info.get("refresh_token"),
-                        "token_uri": refreshed_token_info.get("token_uri"),
-                        "expiry": refreshed_token_info.get("expiry")
-                    }
+                    sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
                     sessions_collection.update_one(
                         {"session_token": session_token},
                         {"$set": {"token_info": sanitized_token_info}}

@@ -328,12 +316,18 @@ class ConnectorValidateSession(Resource):
                         "error": "Session token has expired. Please reconnect."
                     }), 401)

-            return make_response(jsonify({
+            _base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
+            provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
+
+            response_data = {
                 "success": True,
                 "expired": False,
                 "user_email": session.get('user_email', 'Connected User'),
-                "access_token": token_info.get('access_token')
-            }), 200)
+                "access_token": token_info.get('access_token'),
+                **provider_extras,
+            }
+
+            return make_response(jsonify(response_data), 200)
         except Exception as e:
             current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
             return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
@@ -23,10 +23,6 @@ from application.api.user.base import (
     workflow_nodes_collection,
     workflows_collection,
 )
-from application.core.json_schema_utils import (
-    JsonSchemaValidationError,
-    normalize_json_schema_payload,
-)
 from application.core.settings import settings
 from application.utils import (
     check_required_fields,

@@ -483,15 +479,41 @@ class CreateAgent(Resource):
             data["models"] = []
         print(f"Received data: {data}")

-        # Validate and normalize JSON schema if provided
-        if "json_schema" in data:
+        # Validate JSON schema if provided
+
+        if data.get("json_schema"):
             try:
-                data["json_schema"] = normalize_json_schema_payload(
-                    data.get("json_schema")
-                )
-            except JsonSchemaValidationError as exc:
+                # Basic validation - ensure it's a valid JSON structure
+
+                json_schema = data.get("json_schema")
+                if not isinstance(json_schema, dict):
+                    return make_response(
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "JSON schema must be a valid JSON object",
+                            }
+                        ),
+                        400,
+                    )
+                # Validate that it has either a 'schema' property or is itself a schema
+
+                if "schema" not in json_schema and "type" not in json_schema:
+                    return make_response(
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "JSON schema must contain either a 'schema' property or be a valid JSON schema with 'type' property",
+                            }
+                        ),
+                        400,
+                    )
+            except Exception as e:
+                current_app.logger.error(f"Invalid JSON schema: {e}")
                 return make_response(
-                    jsonify({"success": False, "message": f"JSON schema {exc}"}),
+                    jsonify(
+                        {"success": False, "message": "Invalid JSON schema format"}
+                    ),
                     400,
                 )
         if data.get("status") not in ["draft", "published"]:

@@ -710,8 +732,6 @@ class UpdateAgent(Resource):
                     ),
                     400,
                 )
-            if data.get("json_schema") == "":
-                data["json_schema"] = None
         except Exception as err:
             current_app.logger.error(
                 f"Error parsing request data: {err}", exc_info=True

@@ -872,15 +892,17 @@ class UpdateAgent(Resource):
                 elif field == "json_schema":
                     json_schema = data.get("json_schema")
                     if json_schema is not None:
-                        try:
-                            update_fields[field] = normalize_json_schema_payload(
-                                json_schema
-                            )
-                        except JsonSchemaValidationError as exc:
+                        if not isinstance(json_schema, dict):
                             return make_response(
-                                jsonify({"success": False, "message": f"JSON schema {exc}"}),
+                                jsonify(
+                                    {
+                                        "success": False,
+                                        "message": "JSON schema must be a valid object",
+                                    }
+                                ),
                                 400,
                             )
+                        update_fields[field] = json_schema
                     else:
                         update_fields[field] = None
                 elif field == "limited_token_mode":

@@ -1412,4 +1434,4 @@ class RemoveSharedAgent(Resource):
             current_app.logger.error(f"Error removing shared agent: {err}")
             return make_response(
                 jsonify({"success": False, "message": "Server error"}), 500
-            )
+            )
@@ -1,7 +1,7 @@
 """Workflow management routes."""

 from datetime import datetime, timezone
-from typing import Any, Dict, List, Optional, Set
+from typing import Dict, List

 from flask import current_app, request
 from flask_restx import Namespace, Resource

@@ -11,11 +11,6 @@ from application.api.user.base import (
     workflow_nodes_collection,
     workflows_collection,
 )
-from application.core.json_schema_utils import (
-    JsonSchemaValidationError,
-    normalize_json_schema_payload,
-)
-from application.core.model_utils import get_model_capabilities
 from application.api.user.utils import (
     check_resource_ownership,
     error_response,

@@ -90,50 +85,6 @@ def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
     return docs

-
-def validate_json_schema_payload(
-    json_schema: Any,
-) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
-    """Validate and normalize optional JSON schema payload for structured output."""
-    if json_schema is None:
-        return None, None
-    try:
-        return normalize_json_schema_payload(json_schema), None
-    except JsonSchemaValidationError as exc:
-        return None, str(exc)
-
-
-def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
-    """Normalize agent-node JSON schema payloads before persistence."""
-    normalized_nodes: List[Dict] = []
-    for node in nodes:
-        if not isinstance(node, dict):
-            normalized_nodes.append(node)
-            continue
-
-        normalized_node = dict(node)
-        if normalized_node.get("type") != "agent":
-            normalized_nodes.append(normalized_node)
-            continue
-
-        raw_config = normalized_node.get("data")
-        if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
-            normalized_nodes.append(normalized_node)
-            continue
-
-        normalized_config = dict(raw_config)
-        try:
-            normalized_config["json_schema"] = normalize_json_schema_payload(
-                raw_config.get("json_schema")
-            )
-        except JsonSchemaValidationError:
-            # Validation runs before normalization; keep original on unexpected shape.
-            normalized_config["json_schema"] = raw_config.get("json_schema")
-        normalized_node["data"] = normalized_config
-        normalized_nodes.append(normalized_node)
-
-    return normalized_nodes
-

 def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
     """Validate workflow graph structure."""
     errors = []

@@ -151,9 +102,6 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
         errors.append("Workflow must have at least one end node")

     node_ids = {n.get("id") for n in nodes}
-    node_map = {n.get("id"): n for n in nodes}
-    end_ids = {n.get("id") for n in end_nodes}
-
     for edge in edges:
         source_id = edge.get("source")
         target_id = edge.get("target")

@@ -167,126 +115,6 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
     if not any(e.get("source") == start_id for e in edges):
         errors.append("Start node must have at least one outgoing edge")

-    condition_nodes = [n for n in nodes if n.get("type") == "condition"]
-    for cnode in condition_nodes:
-        cnode_id = cnode.get("id")
-        cnode_title = cnode.get("title", cnode_id)
-        outgoing = [e for e in edges if e.get("source") == cnode_id]
-        if len(outgoing) < 2:
-            errors.append(
-                f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
-            )
-        node_data = cnode.get("data", {}) or {}
-        cases = node_data.get("cases", [])
-        if not isinstance(cases, list):
-            cases = []
-        if not cases or not any(
-            isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
-        ):
-            errors.append(
-                f"Condition node '{cnode_title}' must have at least one case with an expression"
-            )
-
-        case_handles: Set[str] = set()
-        duplicate_case_handles: Set[str] = set()
-        for case in cases:
-            if not isinstance(case, dict):
-                continue
-            raw_handle = case.get("sourceHandle", "")
-            handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
-            if not handle:
-                errors.append(
-                    f"Condition node '{cnode_title}' has a case without a branch handle"
-                )
-                continue
-            if handle in case_handles:
-                duplicate_case_handles.add(handle)
-            case_handles.add(handle)
-
-        for handle in duplicate_case_handles:
-            errors.append(
-                f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
-            )
-
-        outgoing_by_handle: Dict[str, List[Dict]] = {}
-        for out_edge in outgoing:
-            raw_handle = out_edge.get("sourceHandle", "")
-            handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
-            outgoing_by_handle.setdefault(handle, []).append(out_edge)
-
-        for handle, handle_edges in outgoing_by_handle.items():
-            if not handle:
-                errors.append(
-                    f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
-                )
-                continue
-            if handle != "else" and handle not in case_handles:
-                errors.append(
-                    f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
-                )
-            if len(handle_edges) > 1:
-                errors.append(
-                    f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
-                )
-
-        if "else" not in outgoing_by_handle:
-            errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
-
-        for case in cases:
-            if not isinstance(case, dict):
-                continue
-            raw_handle = case.get("sourceHandle", "")
-            handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
-            if not handle:
-                continue
-
-            raw_expression = case.get("expression", "")
-            has_expression = isinstance(raw_expression, str) and bool(
-                raw_expression.strip()
-            )
-            has_outgoing = bool(outgoing_by_handle.get(handle))
-            if has_expression and not has_outgoing:
-                errors.append(
-                    f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
-                )
-            if not has_expression and has_outgoing:
-                errors.append(
-                    f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
-                )
-
-        for handle, handle_edges in outgoing_by_handle.items():
-            if not handle:
-                continue
-            for out_edge in handle_edges:
-                target = out_edge.get("target")
-                if target and not _can_reach_end(target, edges, node_map, end_ids):
-                    errors.append(
-                        f"Branch '{handle}' of condition '{cnode_title}' "
-                        f"must eventually reach an end node"
-                    )
-
-    agent_nodes = [n for n in nodes if n.get("type") == "agent"]
-    for agent_node in agent_nodes:
-        agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
-        raw_config = agent_node.get("data", {}) or {}
-        if not isinstance(raw_config, dict):
-            errors.append(f"Agent node '{agent_title}' has invalid configuration")
-            continue
-        normalized_schema, schema_error = validate_json_schema_payload(
-            raw_config.get("json_schema")
-        )
-        has_json_schema = normalized_schema is not None
-
-        model_id = raw_config.get("model_id")
-        if has_json_schema and isinstance(model_id, str) and model_id.strip():
-            capabilities = get_model_capabilities(model_id.strip())
-            if capabilities and not capabilities.get("supports_structured_output", False):
-                errors.append(
-                    f"Agent node '{agent_title}' selected model does not support structured output"
-                )
-        if schema_error:
-            errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
-
     for node in nodes:
         if not node.get("id"):
             errors.append("All nodes must have an id")

@@ -296,20 +124,6 @@ def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
     return errors

-
-def _can_reach_end(
-    node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
-) -> bool:
-    if visited is None:
-        visited = set()
-    if node_id in end_ids:
-        return True
-    if node_id in visited or node_id not in node_map:
-        return False
-    visited.add(node_id)
-    outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
-    return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
-

 def create_workflow_nodes(
     workflow_id: str, nodes_data: List[Dict], graph_version: int
 ) -> None:

@@ -372,7 +186,6 @@ class WorkflowList(Resource):
                 return error_response(
                     "Workflow validation failed", errors=validation_errors
                 )
-            nodes_data = normalize_agent_node_json_schemas(nodes_data)

             now = datetime.now(timezone.utc)
             workflow_doc = {

@@ -463,7 +276,6 @@ class WorkflowDetail(Resource):
                 return error_response(
                     "Workflow validation failed", errors=validation_errors
                 )
-            nodes_data = normalize_agent_node_json_schemas(nodes_data)

             current_graph_version = get_workflow_graph_version(workflow)
             next_graph_version = current_graph_version + 1
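The removed `_can_reach_end` walks outgoing edges depth-first with a visited set to guard against cycles. A standalone restatement with a tiny invented graph:

```python
# Standalone restatement of the removed reachability check (graph invented).
def can_reach_end(node_id, edges, node_map, end_ids, visited=None):
    if visited is None:
        visited = set()
    if node_id in end_ids:
        return True
    if node_id in visited or node_id not in node_map:
        return False  # cycle guard, or the edge points at an unknown node
    visited.add(node_id)
    outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
    return any(can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)


edges = [{"source": "A", "target": "B"}, {"source": "B", "target": "END"}]
node_map = {"A": {}, "B": {}, "C": {}, "END": {}}
assert can_reach_end("A", edges, node_map, {"END"}) is True   # A -> B -> END
assert can_reach_end("C", edges, node_map, {"END"}) is False  # C has no outgoing edges
```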
@@ -1,34 +0,0 @@
-from typing import Any, Dict, Optional
-
-
-class JsonSchemaValidationError(ValueError):
-    """Raised when a JSON schema payload is invalid."""
-
-
-def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
-    """
-    Normalize accepted JSON schema payload shapes to a plain schema object.
-
-    Accepted inputs:
-    - None
-    - A raw schema object with a top-level "type"
-    - A wrapped payload with a top-level "schema" object
-    """
-    if json_schema is None:
-        return None
-
-    if not isinstance(json_schema, dict):
-        raise JsonSchemaValidationError("must be a valid JSON object")
-
-    wrapped_schema = json_schema.get("schema")
-    if wrapped_schema is not None:
-        if not isinstance(wrapped_schema, dict):
-            raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
-        return wrapped_schema
-
-    if "type" not in json_schema:
-        raise JsonSchemaValidationError(
-            'must include either a "type" or "schema" field'
-        )
-
-    return json_schema
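A usage sketch for the removed helper: a bare schema and a wrapped `{"schema": ...}` payload both normalize to the same plain schema object (the sample schema is invented):

```python
# Standalone restatement of the normalization behavior shown above.
def normalize(json_schema):
    if json_schema is None:
        return None
    if not isinstance(json_schema, dict):
        raise ValueError("must be a valid JSON object")
    wrapped = json_schema.get("schema")
    if wrapped is not None:
        if not isinstance(wrapped, dict):
            raise ValueError('field "schema" must be a valid JSON object')
        return wrapped
    if "type" not in json_schema:
        raise ValueError('must include either a "type" or "schema" field')
    return json_schema


raw = {"type": "object", "properties": {"answer": {"type": "string"}}}
assert normalize(raw) == raw              # bare schema passes through
assert normalize({"schema": raw}) == raw  # wrapped payload is unwrapped
assert normalize(None) is None
```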
@@ -65,8 +65,14 @@ class Settings(BaseSettings):
         "http://127.0.0.1:7091/api/connectors/callback"  ##add redirect url as it is to your provider's console(gcp)
     )

+    # Microsoft Entra ID (Azure AD) integration
+    MICROSOFT_CLIENT_ID: Optional[str] = None  # Azure AD Application (client) ID
+    MICROSOFT_CLIENT_SECRET: Optional[str] = None  # Azure AD Application client secret
+    MICROSOFT_TENANT_ID: Optional[str] = "common"  # Azure AD Tenant ID (or 'common' for multi-tenant)
+    MICROSOFT_AUTHORITY: Optional[str] = None  # e.g., "https://login.microsoftonline.com/{tenant_id}"
+
     # GitHub source
     GITHUB_ACCESS_TOKEN: Optional[str] = None  # PAT token with read repo access

     # LLM Cache
     CACHE_REDIS_URL: str = "redis://localhost:6379/2"
@@ -13,12 +13,10 @@ class BaseLLM(ABC):
     def __init__(
         self,
         decoded_token=None,
-        agent_id=None,
         model_id=None,
         base_url=None,
     ):
         self.decoded_token = decoded_token
-        self.agent_id = str(agent_id) if agent_id else None
         self.model_id = model_id
         self.base_url = base_url
         self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}

@@ -35,10 +33,9 @@ class BaseLLM(ABC):
             self._fallback_llm = LLMCreator.create_llm(
                 settings.FALLBACK_LLM_PROVIDER,
                 api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
-                user_api_key=getattr(self, "user_api_key", None),
+                user_api_key=None,
                 decoded_token=self.decoded_token,
                 model_id=settings.FALLBACK_LLM_NAME,
-                agent_id=self.agent_id,
             )
             logger.info(
                 f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
@@ -13,7 +13,7 @@ class GoogleLLM(BaseLLM):
    def __init__(
        self, api_key=None, user_api_key=None, decoded_token=None, *args, **kwargs
    ):
        super().__init__(decoded_token=decoded_token, *args, **kwargs)
        super().__init__(*args, **kwargs)
        self.api_key = api_key or settings.GOOGLE_API_KEY or settings.API_KEY
        self.user_api_key = user_api_key
@@ -567,7 +567,6 @@ class LLMHandler(ABC):
            getattr(agent, "user_api_key", None),
            getattr(agent, "decoded_token", None),
            model_id=compression_model,
            agent_id=getattr(agent, "agent_id", None),
        )

        # Create service without DB persistence capability
@@ -31,15 +31,7 @@ class LLMCreator:

    @classmethod
    def create_llm(
        cls,
        type,
        api_key,
        user_api_key,
        decoded_token,
        model_id=None,
        agent_id=None,
        *args,
        **kwargs,
        cls, type, api_key, user_api_key, decoded_token, model_id=None, *args, **kwargs
    ):
        from application.core.model_utils import get_base_url_for_model

@@ -57,7 +49,6 @@ class LLMCreator:
            user_api_key,
            decoded_token=decoded_token,
            model_id=model_id,
            agent_id=agent_id,
            base_url=base_url,
            *args,
            **kwargs,
@@ -62,15 +62,26 @@ class BaseConnectorAuth(ABC):
    def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
        """
        Check if a token is expired.

        Args:
            token_info: Token information dictionary

        Returns:
            True if token is expired, False otherwise
        """
        pass

    def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
        """Extract the fields safe to persist in the session store."""
        return {
            "access_token": token_info.get("access_token"),
            "refresh_token": token_info.get("refresh_token"),
            "token_uri": token_info.get("token_uri"),
            "expiry": token_info.get("expiry"),
            **extra_fields,
        }


class BaseConnectorLoader(ABC):
    """
@@ -1,5 +1,7 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.connectors.share_point.loader import SharePointLoader


class ConnectorCreator:
@@ -12,10 +14,12 @@ class ConnectorCreator:

    connectors = {
        "google_drive": GoogleDriveLoader,
        "share_point": SharePointLoader,
    }

    auth_providers = {
        "google_drive": GoogleDriveAuth,
        "share_point": SharePointAuth,
    }

    @classmethod
@@ -232,10 +232,6 @@ class GoogleDriveAuth(BaseConnectorAuth):
            if missing_fields:
                raise ValueError(f"Missing required token fields: {missing_fields}")

            if 'client_id' not in token_info:
                token_info['client_id'] = settings.GOOGLE_CLIENT_ID
            if 'client_secret' not in token_info:
                token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
            if 'token_uri' not in token_info:
                token_info['token_uri'] = 'https://oauth2.googleapis.com/token'
@@ -327,15 +327,10 @@ class GoogleDriveLoader(BaseConnectorLoader):
            content_bytes = file_io.getvalue()

            try:
                content = content_bytes.decode('utf-8')
                return content_bytes.decode('utf-8')
            except UnicodeDecodeError:
                try:
                    content = content_bytes.decode('latin-1')
                except UnicodeDecodeError:
                    logging.error(f"Could not decode file {file_id} as text")
                    return None

                return content
                logging.error(f"Could not decode file {file_id} as text")
                return None

        except HttpError as e:
            logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")
application/parser/connectors/share_point/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""
Share Point connector package for DocsGPT.

This module provides authentication and document loading capabilities for Share Point.
"""

from .auth import SharePointAuth
from .loader import SharePointLoader

__all__ = ['SharePointAuth', 'SharePointLoader']
application/parser/connectors/share_point/auth.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import datetime
import logging
from typing import Optional, Dict, Any

from msal import ConfidentialClientApplication

from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth

logger = logging.getLogger(__name__)


class SharePointAuth(BaseConnectorAuth):
    """
    Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive.

    Note: Files.Read scope allows access to files the user has granted access to,
    similar to Google Drive's drive.file scope.
    """

    SCOPES = [
        "Files.Read",
        "Sites.Read.All",
        "User.Read",
    ]

    def __init__(self):
        self.client_id = settings.MICROSOFT_CLIENT_ID
        self.client_secret = settings.MICROSOFT_CLIENT_SECRET

        if not self.client_id:
            raise ValueError(
                "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings."
            )

        if not self.client_secret:
            raise ValueError(
                "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings."
            )

        self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
        self.tenant_id = settings.MICROSOFT_TENANT_ID
        self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}")

        self.auth_app = ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority
        )

    def get_authorization_url(self, state: Optional[str] = None) -> str:
        return self.auth_app.get_authorization_request_url(
            scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
        )

    def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
        result = self.auth_app.acquire_token_by_authorization_code(
            code=authorization_code,
            scopes=self.SCOPES,
            redirect_uri=self.redirect_uri
        )

        if "error" in result:
            logger.error("Token exchange failed: %s", result.get("error_description"))
            raise ValueError(f"Error acquiring token: {result.get('error_description')}")

        return self.map_token_response(result)

    def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
        result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)

        if "error" in result:
            logger.error("Token refresh failed: %s", result.get("error_description"))
            raise ValueError(f"Error refreshing token: {result.get('error_description')}")

        return self.map_token_response(result)

    def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
        try:
            from application.core.mongo_db import MongoDB
            from application.core.settings import settings

            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]

            sessions_collection = db["connector_sessions"]
            session = sessions_collection.find_one({"session_token": session_token})

            if not session:
                raise ValueError(f"Invalid session token: {session_token}")

            if "token_info" not in session:
                raise ValueError("Session missing token information")

            token_info = session["token_info"]
            if not token_info:
                raise ValueError("Invalid token information")

            required_fields = ["access_token", "refresh_token"]
            missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
            if missing_fields:
                raise ValueError(f"Missing required token fields: {missing_fields}")

            if 'token_uri' not in token_info:
                token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token"

            return token_info

        except Exception as e:
            logger.error("Failed to retrieve token from session: %s", e)
            raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}")

    def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
        if not token_info:
            return True

        expiry_timestamp = token_info.get("expiry")

        if expiry_timestamp is None:
            return True

        current_timestamp = int(datetime.datetime.now().timestamp())
        return (expiry_timestamp - current_timestamp) < 60

    def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
        return super().sanitize_token_info(
            token_info,
            allows_shared_content=token_info.get("allows_shared_content", False),
            **extra_fields,
        )

    PERSONAL_ACCOUNT_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad"

    def _allows_shared_content(self, id_token_claims: Dict[str, Any]) -> bool:
        """Return True when the account is a work/school tenant that can access SharePoint shared content."""
        tid = id_token_claims.get("tid", "")
        return bool(tid) and tid != self.PERSONAL_ACCOUNT_TENANT_ID

    def map_token_response(self, result) -> Dict[str, Any]:
        claims = result.get("id_token_claims", {})
        return {
            "access_token": result.get("access_token"),
            "refresh_token": result.get("refresh_token"),
            "token_uri": claims.get("iss"),
            "scopes": result.get("scope"),
            "expiry": claims.get("exp"),
            "allows_shared_content": self._allows_shared_content(claims),
            "user_info": {
                "name": claims.get("name"),
                "email": claims.get("preferred_username"),
            },
        }
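
Taken together, the methods above implement the usual three-legged OAuth flow; a minimal sketch (redirect handling and the incoming `code` parameter are assumed to be wired up by the connector routes):

```python
# Sketch only: assumes MICROSOFT_CLIENT_ID / MICROSOFT_CLIENT_SECRET are set and
# that the authorization code arrives on the configured redirect URI.
from application.parser.connectors.share_point.auth import SharePointAuth

auth = SharePointAuth()

# 1. Send the user to Microsoft's consent page.
print(auth.get_authorization_url(state="opaque-csrf-state"))

# 2. On the callback, exchange ?code=... for tokens.
token_info = auth.exchange_code_for_tokens(authorization_code="code-from-callback")

# 3. Persist only the fields that are safe for the session store.
safe = auth.sanitize_token_info(token_info)
print(safe["expiry"], safe["allows_shared_content"])
```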
application/parser/connectors/share_point/loader.py (new file, 649 lines)
@@ -0,0 +1,649 @@
"""
SharePoint/OneDrive loader for DocsGPT.
Loads documents from SharePoint/OneDrive using Microsoft Graph API.
"""

import functools
import logging
import os
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import quote

import requests

from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.schema.base import Document


def _retry_on_auth_failure(func):
    """Retry once after refreshing the access token on 401/403 responses."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except requests.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code in (401, 403):
                logging.info(f"Auth failure in {func.__name__}, refreshing token and retrying")
                try:
                    new_token_info = self.auth.refresh_access_token(self.refresh_token)
                    self.access_token = new_token_info.get('access_token')
                except Exception as refresh_error:
                    raise ValueError(
                        f"Authentication failed and could not be refreshed: {refresh_error}"
                    ) from e
                return func(self, *args, **kwargs)
            raise
    return wrapper


class SharePointLoader(BaseConnectorLoader):

    SUPPORTED_MIME_TYPES = {
        'application/pdf': '.pdf',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
        'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
        'application/msword': '.doc',
        'application/vnd.ms-powerpoint': '.ppt',
        'application/vnd.ms-excel': '.xls',
        'text/plain': '.txt',
        'text/csv': '.csv',
        'text/html': '.html',
        'text/markdown': '.md',
        'text/x-rst': '.rst',
        'application/json': '.json',
        'application/epub+zip': '.epub',
        'application/rtf': '.rtf',
        'image/jpeg': '.jpg',
        'image/png': '.png',
    }

    EXTENSION_TO_MIME = {v: k for k, v in SUPPORTED_MIME_TYPES.items()}

    GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"

    def __init__(self, session_token: str):
        self.auth = SharePointAuth()
        self.session_token = session_token

        token_info = self.auth.get_token_info_from_session(session_token)
        self.access_token = token_info.get('access_token')
        self.refresh_token = token_info.get('refresh_token')
        self.allows_shared_content = token_info.get('allows_shared_content', False)

        if not self.access_token:
            raise ValueError("No access token found in session")

        self.next_page_token = None

    def _get_headers(self) -> Dict[str, str]:
        return {
            'Authorization': f'Bearer {self.access_token}',
            'Accept': 'application/json'
        }

    def _ensure_valid_token(self):
        if not self.access_token:
            raise ValueError("No access token available")

        token_info = {'access_token': self.access_token, 'expiry': None}
        if self.auth.is_token_expired(token_info):
            logging.info("Token expired, attempting refresh")
            try:
                new_token_info = self.auth.refresh_access_token(self.refresh_token)
                self.access_token = new_token_info.get('access_token')
            except Exception:
                raise ValueError("Failed to refresh access token")

    def _get_item_url(self, item_ref: str) -> str:
        if ':' in item_ref:
            drive_id, item_id = item_ref.split(':', 1)
            return f"{self.GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}"
        return f"{self.GRAPH_API_BASE}/me/drive/items/{item_ref}"

    def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
        try:
            drive_item_id = file_metadata.get('id')
            file_name = file_metadata.get('name', 'Unknown')
            file_data = file_metadata.get('file', {})
            mime_type = file_data.get('mimeType', 'application/octet-stream')

            if mime_type not in self.SUPPORTED_MIME_TYPES:
                logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
                return None

            doc_metadata = {
                'file_name': file_name,
                'mime_type': mime_type,
                'size': file_metadata.get('size'),
                'created_time': file_metadata.get('createdDateTime'),
                'modified_time': file_metadata.get('lastModifiedDateTime'),
                'source': 'share_point'
            }

            if not load_content:
                return Document(
                    text="",
                    doc_id=drive_item_id,
                    extra_info=doc_metadata
                )

            content = self._download_file_content(drive_item_id)
            if content is None:
                logging.warning(f"Could not load content for file {file_name} ({drive_item_id})")
                return None

            return Document(
                text=content,
                doc_id=drive_item_id,
                extra_info=doc_metadata
            )

        except Exception as e:
            logging.error(f"Error processing file: {e}")
            return None

    def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
        try:
            documents: List[Document] = []

            folder_id = inputs.get('folder_id')
            file_ids = inputs.get('file_ids', [])
            limit = inputs.get('limit', 100)
            list_only = inputs.get('list_only', False)
            load_content = not list_only
            page_token = inputs.get('page_token')
            search_query = inputs.get('search_query')
            self.next_page_token = None

            shared = inputs.get('shared', False)

            if file_ids:
                for file_id in file_ids:
                    try:
                        doc = self._load_file_by_id(file_id, load_content=load_content)
                        if doc:
                            if not search_query or (
                                search_query.lower() in doc.extra_info.get('file_name', '').lower()
                            ):
                                documents.append(doc)
                    except Exception as e:
                        logging.error(f"Error loading file {file_id}: {e}")
                        continue
            elif shared:
                if not self.allows_shared_content:
                    logging.warning("Shared content is only available for work/school Microsoft accounts")
                    return []
                documents = self._list_shared_items(
                    limit=limit,
                    load_content=load_content,
                    page_token=page_token,
                    search_query=search_query
                )
            else:
                parent_id = folder_id if folder_id else 'root'
                documents = self._list_items_in_parent(
                    parent_id,
                    limit=limit,
                    load_content=load_content,
                    page_token=page_token,
                    search_query=search_query
                )

            logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive")
            return documents

        except Exception as e:
            logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True)
            raise

    @_retry_on_auth_failure
    def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
        self._ensure_valid_token()

        try:
            url = self._get_item_url(file_id)
            params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
            response = requests.get(url, headers=self._get_headers(), params=params)
            response.raise_for_status()

            file_metadata = response.json()
            return self._process_file(file_metadata, load_content=load_content)

        except requests.exceptions.HTTPError:
            raise
        except Exception as e:
            logging.error(f"Error loading file {file_id}: {e}")
            return None

    @_retry_on_auth_failure
    def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
        self._ensure_valid_token()

        documents: List[Document] = []

        try:
            url = f"{self._get_item_url(parent_id)}/children"
            params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'}
            if page_token:
                params['$skipToken'] = page_token

            if search_query:
                encoded_query = quote(search_query, safe='')
                if ':' in parent_id:
                    drive_id = parent_id.split(':', 1)[0]
                    search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
                else:
                    search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
                response = requests.get(search_url, headers=self._get_headers(), params=params)
            else:
                response = requests.get(url, headers=self._get_headers(), params=params)

            response.raise_for_status()

            results = response.json()

            items = results.get('value', [])
            for item in items:
                if 'folder' in item:
                    doc_metadata = {
                        'file_name': item.get('name', 'Unknown'),
                        'mime_type': 'folder',
                        'size': item.get('size'),
                        'created_time': item.get('createdDateTime'),
                        'modified_time': item.get('lastModifiedDateTime'),
                        'source': 'share_point',
                        'is_folder': True
                    }
                    documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
                else:
                    doc = self._process_file(item, load_content=load_content)
                    if doc:
                        documents.append(doc)

                if limit and len(documents) >= limit:
                    break

            next_link = results.get('@odata.nextLink')
            if next_link:
                from urllib.parse import urlparse, parse_qs
                parsed = urlparse(next_link)
                query_params = parse_qs(parsed.query)
                skiptoken_list = query_params.get('$skiptoken')
                if skiptoken_list:
                    self.next_page_token = skiptoken_list[0]
                else:
                    self.next_page_token = None
            else:
                self.next_page_token = None
            return documents

        except Exception as e:
            logging.error(f"Error listing items under parent {parent_id}: {e}")
            return documents

    def _resolve_mime_type(self, resource: Dict[str, Any]) -> Tuple[str, bool]:
        """Resolve mime type from resource, falling back to file extension."""
        file_data = resource.get('file', {})
        mime_type = file_data.get('mimeType') if file_data else None

        if mime_type and mime_type in self.SUPPORTED_MIME_TYPES:
            return mime_type, True

        name = resource.get('name', '')
        ext = os.path.splitext(name)[1].lower()
        if ext in self.EXTENSION_TO_MIME:
            return self.EXTENSION_TO_MIME[ext], True

        return mime_type or 'application/octet-stream', False

    def _get_user_drive_web_url(self) -> Optional[str]:
        """Fetch the current user's OneDrive web URL for KQL path exclusion."""
        try:
            response = requests.get(
                f"{self.GRAPH_API_BASE}/me/drive",
                headers=self._get_headers(),
                params={'$select': 'webUrl'}
            )
            response.raise_for_status()
            return response.json().get('webUrl')
        except Exception as e:
            logging.warning(f"Could not fetch user drive web URL: {e}")
            return None

    def _build_shared_kql_query(self, search_query: Optional[str], user_drive_url: Optional[str]) -> str:
        """Build KQL query string that excludes the user's own drive items."""
        base_query = search_query if search_query else "*"
        if user_drive_url:
            return f'{base_query} AND -path:"{user_drive_url}"'
        return base_query

    def _list_shared_items(self, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
        """Fetch shared drive items using Microsoft Graph Search API with local offset paging.

        We always fetch up to a fixed maximum number of hits from Graph (single request),
        then page through that array locally using `page_token` as a simple integer offset.
        This avoids relying on buggy or inconsistent remote `from`/`size` semantics.
        """
        self._ensure_valid_token()
        documents: List[Document] = []

        try:
            user_drive_url = self._get_user_drive_web_url()
            query_text = self._build_shared_kql_query(search_query, user_drive_url)

            url = f"{self.GRAPH_API_BASE}/search/query"
            page_size = 500  # maximum number of hits we care about for selection

            body = {
                "requests": [
                    {
                        "entityTypes": ["driveItem"],
                        "query": {"queryString": query_text},
                        "from": 0,
                        "size": page_size,
                    }
                ]
            }

            headers = self._get_headers()
            headers["Content-Type"] = "application/json"
            response = requests.post(url, headers=headers, json=body)
            response.raise_for_status()
            results = response.json()

            search_response = results.get("value", [])
            if not search_response:
                logging.warning("Search API returned empty value array")
                self.next_page_token = None
                return documents

            hits_containers = search_response[0].get("hitsContainers", [])
            if not hits_containers:
                logging.warning("Search API returned no hitsContainers")
                self.next_page_token = None
                return documents

            container = hits_containers[0]
            total = container.get("total", 0)
            raw_hits = container.get("hits", [])

            # Deduplicate by effective item ID (driveId:itemId) to avoid the same
            # resource appearing multiple times across the result set.
            deduped_hits = []
            seen_ids = set()
            for hit in raw_hits:
                resource = hit.get("resource", {})
                item_id = resource.get("id")
                drive_id = resource.get("parentReference", {}).get("driveId")
                effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
                if not effective_id or effective_id in seen_ids:
                    continue
                seen_ids.add(effective_id)
                deduped_hits.append(hit)

            hits = deduped_hits
            logging.info(
                f"Search API returned {total} total results, {len(raw_hits)} raw hits, {len(hits)} unique hits in this batch"
            )
            try:
                offset = int(page_token) if page_token is not None else 0
            except (TypeError, ValueError):
                logging.warning(
                    f"Invalid page_token '{page_token}' for shared items search, defaulting to 0"
                )
                offset = 0

            if offset < 0:
                offset = 0
            if offset >= len(hits):
                self.next_page_token = None
                return documents

            end_index = offset + limit if limit else len(hits)
            end_index = min(end_index, len(hits))

            for hit in hits[offset:end_index]:
                resource = hit.get("resource", {})
                item_name = resource.get("name", "Unknown")
                item_id = resource.get("id")
                drive_id = resource.get("parentReference", {}).get("driveId")

                effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id

                is_folder = "folder" in resource

                if is_folder:
                    doc_metadata = {
                        "file_name": item_name,
                        "mime_type": "folder",
                        "size": resource.get("size"),
                        "created_time": resource.get("createdDateTime"),
                        "modified_time": resource.get("lastModifiedDateTime"),
                        "source": "share_point",
                        "is_folder": True,
                    }
                    documents.append(
                        Document(text="", doc_id=effective_id, extra_info=doc_metadata)
                    )
                else:
                    mime_type, supported = self._resolve_mime_type(resource)
                    if not supported:
                        logging.info(
                            f"Skipping unsupported shared file: {item_name} (mime: {mime_type})"
                        )
                        continue

                    doc_metadata = {
                        "file_name": item_name,
                        "mime_type": mime_type,
                        "size": resource.get("size"),
                        "created_time": resource.get("createdDateTime"),
                        "modified_time": resource.get("lastModifiedDateTime"),
                        "source": "share_point",
                    }

                    content = ""
                    if load_content:
                        content = self._download_file_content(effective_id) or ""

                    documents.append(
                        Document(text=content, doc_id=effective_id, extra_info=doc_metadata)
                    )

            if limit and end_index < len(hits):
                self.next_page_token = str(end_index)
            else:
                self.next_page_token = None

            return documents

        except Exception as e:
            logging.error(f"Error listing shared items via search API: {e}", exc_info=True)
            return documents

    @_retry_on_auth_failure
    def _download_file_content(self, file_id: str) -> Optional[str]:
        self._ensure_valid_token()

        try:
            url = f"{self._get_item_url(file_id)}/content"
            response = requests.get(url, headers=self._get_headers())
            response.raise_for_status()

            try:
                return response.content.decode('utf-8')
            except UnicodeDecodeError:
                logging.error(f"Could not decode file {file_id} as text")
                return None

        except requests.exceptions.HTTPError:
            raise
        except Exception as e:
            logging.error(f"Error downloading file {file_id}: {e}")
            return None

    def _download_single_file(self, file_id: str, local_dir: str) -> bool:
        try:
            url = self._get_item_url(file_id)
            params = {'$select': 'id,name,file'}
            response = requests.get(url, headers=self._get_headers(), params=params)
            response.raise_for_status()

            metadata = response.json()
            file_name = metadata.get('name', 'unknown')
            file_data = metadata.get('file', {})
            mime_type = file_data.get('mimeType', 'application/octet-stream')

            if mime_type not in self.SUPPORTED_MIME_TYPES:
                logging.info(f"Skipping unsupported file type: {mime_type}")
                return False

            os.makedirs(local_dir, exist_ok=True)
            full_path = os.path.join(local_dir, file_name)

            download_url = f"{self._get_item_url(file_id)}/content"
            download_response = requests.get(download_url, headers=self._get_headers())
            download_response.raise_for_status()

            with open(full_path, 'wb') as f:
                f.write(download_response.content)

            return True
        except Exception as e:
            logging.error(f"Error in _download_single_file: {e}")
            return False

    def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
        files_downloaded = 0
        try:
            os.makedirs(local_dir, exist_ok=True)

            url = f"{self._get_item_url(folder_id)}/children"
            params = {'$top': 1000}

            while url:
                response = requests.get(url, headers=self._get_headers(), params=params)
                response.raise_for_status()

                results = response.json()
                items = results.get('value', [])
                logging.info(f"Found {len(items)} items in folder {folder_id}")

                for item in items:
                    item_name = item.get('name', 'unknown')
                    item_id = item.get('id')

                    if 'folder' in item:
                        if recursive:
                            subfolder_path = os.path.join(local_dir, item_name)
                            os.makedirs(subfolder_path, exist_ok=True)
                            subfolder_files = self._download_folder_recursive(
                                item_id,
                                subfolder_path,
                                recursive
                            )
                            files_downloaded += subfolder_files
                            logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
                    else:
                        success = self._download_single_file(item_id, local_dir)
                        if success:
                            files_downloaded += 1
                            logging.info(f"Downloaded file: {item_name}")
                        else:
                            logging.warning(f"Failed to download file: {item_name}")

                url = results.get('@odata.nextLink')

            return files_downloaded

        except Exception as e:
            logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
            return files_downloaded

    def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
        try:
            self._ensure_valid_token()
            return self._download_folder_recursive(folder_id, local_dir, recursive)
        except Exception as e:
            logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
            return 0

    def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
        try:
            self._ensure_valid_token()
            return self._download_single_file(file_id, local_dir)
        except Exception as e:
            logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
            return False

    def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
        if source_config is None:
            source_config = {}

        config = source_config if source_config else getattr(self, 'config', {})
        files_downloaded = 0

        try:
            folder_ids = config.get('folder_ids', [])
            file_ids = config.get('file_ids', [])
            recursive = config.get('recursive', True)

            if file_ids:
                if isinstance(file_ids, str):
                    file_ids = [file_ids]

                for file_id in file_ids:
                    if self._download_file_to_directory(file_id, local_dir):
                        files_downloaded += 1

            if folder_ids:
                if isinstance(folder_ids, str):
                    folder_ids = [folder_ids]

                for folder_id in folder_ids:
                    try:
                        url = self._get_item_url(folder_id)
                        params = {'$select': 'id,name'}
                        response = requests.get(url, headers=self._get_headers(), params=params)
                        response.raise_for_status()

                        folder_metadata = response.json()
                        folder_name = folder_metadata.get('name', '')
                        folder_path = os.path.join(local_dir, folder_name)
                        os.makedirs(folder_path, exist_ok=True)

                        folder_files = self._download_folder_recursive(
                            folder_id,
                            folder_path,
                            recursive
                        )
                        files_downloaded += folder_files
                        logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
                    except Exception as e:
                        logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)

            if not file_ids and not folder_ids:
                raise ValueError("No folder_ids or file_ids provided for download")

            return {
                "files_downloaded": files_downloaded,
                "directory_path": local_dir,
                "empty_result": files_downloaded == 0,
                "source_type": "share_point",
                "config_used": config
            }

        except Exception as e:
            return {
                "files_downloaded": files_downloaded,
                "directory_path": local_dir,
                "empty_result": True,
                "source_type": "share_point",
                "config_used": config,
                "error": str(e)
            }
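
A hedged usage sketch of the loader above (it assumes a connector session already exists in `connector_sessions`, e.g. created by the OAuth callback flow shown earlier; the token value is a placeholder):

```python
# Sketch only: the session token value is a placeholder.
from application.parser.connectors.share_point.loader import SharePointLoader

loader = SharePointLoader(session_token="session-token-from-oauth-callback")

# List (metadata only) up to 25 items in the drive root.
docs = loader.load_data({"limit": 25, "list_only": True})
for doc in docs:
    info = doc.extra_info
    print(info["file_name"], info["mime_type"], info.get("is_folder", False))

# Fetch the next page, if any, using the loader's page token.
if loader.next_page_token:
    more = loader.load_data(
        {"limit": 25, "list_only": True, "page_token": loader.next_page_token}
    )
```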
@@ -1,7 +1,6 @@
anthropic==0.75.0
boto3==1.42.17
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.0
cryptography==46.0.3
dataclasses-json==0.6.7
@@ -47,6 +46,7 @@ markupsafe==3.0.3
marshmallow>=3.18.0,<5.0.0
mpmath==1.3.0
multidict==6.7.0
msal==1.34.0
mypy-extensions==1.1.0
networkx==3.6.1
numpy==2.4.0
@@ -95,4 +95,4 @@ werkzeug>=3.1.0
yarl==1.22.0
markdownify==1.2.2
tldextract==5.3.0
websockets==15.0.1
@@ -18,7 +18,6 @@ class ClassicRAG(BaseRetriever):
        doc_token_limit=50000,
        model_id="docsgpt-local",
        user_api_key=None,
        agent_id=None,
        llm_name=settings.LLM_PROVIDER,
        api_key=settings.API_KEY,
        decoded_token=None,
@@ -44,7 +43,6 @@ class ClassicRAG(BaseRetriever):
        self.model_id = model_id
        self.doc_token_limit = doc_token_limit
        self.user_api_key = user_api_key
        self.agent_id = agent_id
        self.llm_name = llm_name
        self.api_key = api_key
        self.llm = LLMCreator.create_llm(
@@ -52,7 +50,6 @@ class ClassicRAG(BaseRetriever):
            api_key=self.api_key,
            user_api_key=self.user_api_key,
            decoded_token=decoded_token,
            agent_id=self.agent_id,
        )

        if "active_docs" in source and source["active_docs"] is not None:
@@ -1,104 +1,22 @@
import sys
import logging
from datetime import datetime

from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.utils import num_tokens_from_object_or_list, num_tokens_from_string

logger = logging.getLogger(__name__)

mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
usage_collection = db["token_usage"]


def _serialize_for_token_count(value):
    """Normalize payloads into token-countable primitives."""
    if isinstance(value, str):
        # Avoid counting large binary payloads in data URLs as text tokens.
        if value.startswith("data:") and ";base64," in value:
            return ""
        return value

    if value is None:
        return ""

    if isinstance(value, list):
        return [_serialize_for_token_count(item) for item in value]

    if isinstance(value, dict):
        serialized = {}
        for key, raw in value.items():
            key_lower = str(key).lower()

            # Skip raw binary-like fields; keep textual tool-call fields.
            if key_lower in {"data", "base64", "image_data"} and isinstance(raw, str):
                continue
            if key_lower == "url" and isinstance(raw, str) and ";base64," in raw:
                continue

            serialized[key] = _serialize_for_token_count(raw)
        return serialized

    if hasattr(value, "model_dump") and callable(getattr(value, "model_dump")):
        return _serialize_for_token_count(value.model_dump())
    if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
        return _serialize_for_token_count(value.to_dict())
    if hasattr(value, "__dict__"):
        return _serialize_for_token_count(vars(value))

    return str(value)


def _count_tokens(value):
    serialized = _serialize_for_token_count(value)
    if isinstance(serialized, str):
        return num_tokens_from_string(serialized)
    return num_tokens_from_object_or_list(serialized)


def _count_prompt_tokens(messages, tools=None, usage_attachments=None, **kwargs):
    prompt_tokens = 0

    for message in messages or []:
        if not isinstance(message, dict):
            prompt_tokens += _count_tokens(message)
            continue

        prompt_tokens += _count_tokens(message.get("content"))

        # Include tool-related message fields for providers that use OpenAI-native format.
        prompt_tokens += _count_tokens(message.get("tool_calls"))
        prompt_tokens += _count_tokens(message.get("tool_call_id"))
        prompt_tokens += _count_tokens(message.get("function_call"))
        prompt_tokens += _count_tokens(message.get("function_response"))

    # Count tool schema payload passed to the model.
    prompt_tokens += _count_tokens(tools)

    # Count structured-output/schema payloads when provided.
    prompt_tokens += _count_tokens(kwargs.get("response_format"))
    prompt_tokens += _count_tokens(kwargs.get("response_schema"))

    # Optional usage-only attachment context (not forwarded to provider).
    prompt_tokens += _count_tokens(usage_attachments)

    return prompt_tokens


def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
def update_token_usage(decoded_token, user_api_key, token_usage):
    if "pytest" in sys.modules:
        return
    user_id = decoded_token.get("sub") if isinstance(decoded_token, dict) else None
    normalized_agent_id = str(agent_id) if agent_id else None

    if not user_id and not user_api_key and not normalized_agent_id:
        logger.warning(
            "Skipping token usage insert: missing user_id, api_key, and agent_id"
        )
        return

    if decoded_token:
        user_id = decoded_token["sub"]
    else:
        user_id = None
    usage_data = {
        "user_id": user_id,
        "api_key": user_api_key,
@@ -106,31 +24,24 @@ def update_token_usage(decoded_token, user_api_key, token_usage, agent_id=None):
        "generated_tokens": token_usage["generated_tokens"],
        "timestamp": datetime.now(),
    }
    if normalized_agent_id:
        usage_data["agent_id"] = normalized_agent_id
    usage_collection.insert_one(usage_data)


def gen_token_usage(func):
    def wrapper(self, model, messages, stream, tools, **kwargs):
        usage_attachments = kwargs.pop("_usage_attachments", None)
        call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
        call_usage["prompt_tokens"] += _count_prompt_tokens(
            messages,
            tools=tools,
            usage_attachments=usage_attachments,
            **kwargs,
        )
        for message in messages:
            if message["content"]:
                self.token_usage["prompt_tokens"] += num_tokens_from_string(
                    message["content"]
                )
        result = func(self, model, messages, stream, tools, **kwargs)
        call_usage["generated_tokens"] += _count_tokens(result)
        self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
        self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
        update_token_usage(
            self.decoded_token,
            self.user_api_key,
            call_usage,
            getattr(self, "agent_id", None),
        )
        if isinstance(result, str):
            self.token_usage["generated_tokens"] += num_tokens_from_string(result)
        else:
            self.token_usage["generated_tokens"] += num_tokens_from_object_or_list(
                result
            )
        update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)
        return result

    return wrapper
@@ -138,28 +49,17 @@ def gen_token_usage(func):

def stream_token_usage(func):
    def wrapper(self, model, messages, stream, tools, **kwargs):
        usage_attachments = kwargs.pop("_usage_attachments", None)
        call_usage = {"prompt_tokens": 0, "generated_tokens": 0}
        call_usage["prompt_tokens"] += _count_prompt_tokens(
            messages,
            tools=tools,
            usage_attachments=usage_attachments,
            **kwargs,
        )
        for message in messages:
            self.token_usage["prompt_tokens"] += num_tokens_from_string(
                message["content"]
            )
        batch = []
        result = func(self, model, messages, stream, tools, **kwargs)
        for r in result:
            batch.append(r)
            yield r
        for line in batch:
            call_usage["generated_tokens"] += _count_tokens(line)
        self.token_usage["prompt_tokens"] += call_usage["prompt_tokens"]
        self.token_usage["generated_tokens"] += call_usage["generated_tokens"]
        update_token_usage(
            self.decoded_token,
            self.user_api_key,
            call_usage,
            getattr(self, "agent_id", None),
        )
            self.token_usage["generated_tokens"] += num_tokens_from_string(line)
        update_token_usage(self.decoded_token, self.user_api_key, self.token_usage)

    return wrapper
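
As a rough illustration of how `_serialize_for_token_count` treats attachment-style payloads (run with the function above in scope; the message shape here is an assumption, not a fixed schema):

```python
message = {
    "content": [
        {"type": "text", "text": "describe this image"},
        # The base64 data URL is dropped so it is not billed as prompt text.
        {"type": "image_url", "url": "data:image/png;base64,iVBORw0KGgo..."},
    ]
}
print(_serialize_for_token_count(message))
# -> {'content': [{'type': 'text', 'text': 'describe this image'}, {'type': 'image_url'}]}
```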
@@ -322,7 +322,6 @@ def run_agent_logic(agent_config, input_data):
    chunks = int(agent_config.get("chunks", 2))
    prompt_id = agent_config.get("prompt_id", "default")
    user_api_key = agent_config["key"]
    agent_id = str(agent_config.get("_id")) if agent_config.get("_id") else None
    agent_type = agent_config.get("agent_type", "classic")
    decoded_token = {"sub": agent_config.get("user")}
    json_schema = agent_config.get("json_schema")
@@ -353,7 +352,6 @@ def run_agent_logic(agent_config, input_data):
        doc_token_limit=doc_token_limit,
        model_id=model_id,
        user_api_key=user_api_key,
        agent_id=agent_id,
        decoded_token=decoded_token,
    )

@@ -372,7 +370,6 @@ def run_agent_logic(agent_config, input_data):
        llm_name=provider or settings.LLM_PROVIDER,
        model_id=model_id,
        api_key=system_api_key,
        agent_id=agent_id,
        user_api_key=user_api_key,
        prompt=prompt,
        chat_history=[],
@@ -1,25 +0,0 @@
import { generateStaticParamsFor, importPage } from 'nextra/pages';

import { useMDXComponents } from '../../mdx-components';

export const generateStaticParams = generateStaticParamsFor('mdxPath');

export async function generateMetadata(props) {
  const params = await props.params;
  const { metadata } = await importPage(params?.mdxPath);
  return metadata;
}

const Wrapper = useMDXComponents().wrapper;

export default async function Page(props) {
  const params = await props.params;
  const result = await importPage(params?.mdxPath);
  const { default: MDXContent, metadata, sourceCode, toc } = result;

  return (
    <Wrapper metadata={metadata} sourceCode={sourceCode} toc={toc}>
      <MDXContent {...props} params={params} />
    </Wrapper>
  );
}
@@ -1,86 +0,0 @@
import Image from 'next/image';
import { Analytics } from '@vercel/analytics/react';
import { Banner, Head } from 'nextra/components';
import { getPageMap } from 'nextra/page-map';
import { Footer, Layout, Navbar } from 'nextra-theme-docs';
import 'nextra-theme-docs/style.css';

import CuteLogo from '../public/cute-docsgpt.png';
import themeConfig from '../theme.config';

const github = 'https://github.com/arc53/DocsGPT';

export const metadata = {
  title: {
    default: 'DocsGPT Documentation',
    template: '%s - DocsGPT Documentation',
  },
  description:
    'Use DocsGPT to chat with your data. DocsGPT is a GPT-powered chatbot that can answer questions about your data.',
};

const navbar = (
  <Navbar
    logo={
      <div style={{ alignItems: 'center', display: 'flex', gap: '8px' }}>
        <Image src={CuteLogo} alt="DocsGPT logo" width={28} height={28} />
        <span style={{ fontWeight: 'bold', fontSize: 18 }}>DocsGPT Docs</span>
      </div>
    }
    projectLink={github}
    chatLink="https://discord.com/invite/n5BX8dh8rU"
  />
);

const footer = (
  <Footer>
    <span>MIT {new Date().getFullYear()} © </span>
    <a href="https://www.docsgpt.cloud/" target="_blank" rel="noreferrer">
      DocsGPT
    </a>
    {' | '}
    <a href="https://github.com/arc53/DocsGPT" target="_blank" rel="noreferrer">
      GitHub
    </a>
    {' | '}
    <a href="https://blog.docsgpt.cloud/" target="_blank" rel="noreferrer">
      Blog
    </a>
  </Footer>
);

export default async function RootLayout({ children }) {
  return (
    <html lang="en" dir="ltr" suppressHydrationWarning>
      <Head>
        <link
          rel="apple-touch-icon"
          sizes="180x180"
          href="/favicons/apple-touch-icon.png"
        />
        <link rel="icon" type="image/png" sizes="32x32" href="/favicons/favicon-32x32.png" />
        <link rel="icon" type="image/png" sizes="16x16" href="/favicons/favicon-16x16.png" />
        <link rel="manifest" href="/favicons/site.webmanifest" />
        <meta httpEquiv="Content-Language" content="en" />
      </Head>
      <body>
        <Layout
          banner={
            <Banner storageKey="docs-launch">
              <div className="flex justify-center items-center gap-2">
                Welcome to the new DocsGPT docs!
              </div>
            </Banner>
          }
          navbar={navbar}
          footer={footer}
          pageMap={await getPageMap()}
          {...themeConfig}
        >
          {children}
        </Layout>
        <Analytics />
      </body>
    </html>
  );
}
@@ -1,5 +1,3 @@
'use client';

import Image from 'next/image';

const iconMap = {
@@ -119,4 +117,4 @@ export function DeploymentCards({ items }) {
      `}</style>
    </>
  );
}
}
@@ -1,5 +1,3 @@
'use client';

import Image from 'next/image';

const iconMap = {
@@ -116,4 +114,4 @@ export function ToolCards({ items }) {
      `}</style>
    </>
  );
}
}
@@ -1,342 +0,0 @@
---
title: Interacting with Agents via API
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
---

import { Callout, Tabs } from 'nextra/components';

# Interacting with Agents via API

DocsGPT Agents can be accessed programmatically through API endpoints. This page covers:

- Non-streaming answers (`/api/answer`)
- Streaming answers over SSE (`/stream`)
- File/image attachments (`/api/store_attachment` + `/api/task_status` + `/stream`)

When you use an agent `api_key`, DocsGPT loads that agent's configuration automatically (prompt, tools, sources, default model). You usually only need to send `question` and `api_key`.

## Base URL

<Callout type="info">
  For DocsGPT Cloud, use `https://gptcloud.arc53.com` as the base URL.
</Callout>

- Local: `http://localhost:7091`
- Cloud: `https://gptcloud.arc53.com`

## How Request Resolution Works

DocsGPT resolves your request in this order:

1. If `api_key` is provided, DocsGPT loads the mapped agent and executes with that config.
2. If `agent_id` is provided (typically with JWT auth), DocsGPT loads that agent if allowed.
3. If neither is provided, DocsGPT uses request-level fields (`prompt_id`, `active_docs`, `retriever`, etc.).

Authentication:

- Agent API-key flow: include `api_key` in the JSON/form payload.
- JWT flow (if auth enabled): include `Authorization: Bearer <token>` (see the sketch after this list).
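
A minimal sketch of the JWT flow (it assumes auth is enabled and that you already hold a valid token; the token and agent ID values are placeholders):

```python
import requests

# JWT flow: the bearer token identifies the user, and agent_id selects the agent.
resp = requests.post(
    "http://localhost:7091/api/answer",
    headers={"Authorization": "Bearer <your-jwt>"},
    json={"question": "your question here", "agent_id": "<agent-id>"},
)
print(resp.json())
```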
## Endpoints

- `POST /api/answer` (non-streaming)
- `POST /stream` (SSE streaming)
- `POST /api/store_attachment` (multipart upload)
- `GET /api/task_status?task_id=...` (Celery task polling)

## Request Parameters

Common request body fields:

| Field | Type | Required | Applies to | Notes |
| --- | --- | --- | --- | --- |
| `question` | `string` | Yes | `/api/answer`, `/stream` | User query. |
| `api_key` | `string` | Usually | `/api/answer`, `/stream` | Recommended for agent API use. Loads agent config from key. |
| `conversation_id` | `string` | No | `/api/answer`, `/stream` | Continue an existing conversation. |
| `history` | `string` (JSON-encoded array) | No | `/api/answer`, `/stream` | Used for new conversations. Format: `[{"prompt":"...","response":"..."}]` (example below). |
| `model_id` | `string` | No | `/api/answer`, `/stream` | Override model for this request. |
| `save_conversation` | `boolean` | No | `/api/answer`, `/stream` | Default `true`. If `false`, no conversation is persisted. |
| `passthrough` | `object` | No | `/api/answer`, `/stream` | Dynamic values injected into prompt templates. |
| `prompt_id` | `string` | No | `/api/answer`, `/stream` | Ignored when `api_key` already defines the prompt. |
| `active_docs` | `string` or `string[]` | No | `/api/answer`, `/stream` | Overrides active docs when not using key-owned source config. |
| `retriever` | `string` | No | `/api/answer`, `/stream` | Retriever type (for example `classic`). |
| `chunks` | `number` | No | `/api/answer`, `/stream` | Retrieval chunk count, default `2`. |
| `isNoneDoc` | `boolean` | No | `/api/answer`, `/stream` | Skip document retrieval. |
| `agent_id` | `string` | No | `/api/answer`, `/stream` | Alternative to `api_key` when using authenticated user context. |

Streaming-only fields:

| Field | Type | Required | Notes |
| --- | --- | --- | --- |
| `attachments` | `string[]` | No | List of attachment IDs from the `/api/task_status` success result. |
| `index` | `number` | No | Update an existing query index. If provided, `conversation_id` is required. |
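
Because `history` is a JSON-encoded string rather than a raw array, it must be serialized before it goes into the payload; a minimal sketch (field values are placeholders):

```python
import json
import requests

payload = {
    "question": "And how do I deploy it?",
    "api_key": "your_agent_api_key",
    # history must be a JSON string, not a list.
    "history": json.dumps(
        [{"prompt": "What is DocsGPT?", "response": "DocsGPT is ..."}]
    ),
}
response = requests.post("http://localhost:7091/api/answer", json=payload)
print(response.json()["answer"])
```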
## Non-Streaming API (`/api/answer`)

`/api/answer` waits for completion and returns one JSON response.

<Callout type="info">
  `attachments` are currently handled through `/stream`. For file/image-attached queries, use the streaming endpoint.
</Callout>

Response fields:

- `conversation_id`
- `answer`
- `sources`
- `tool_calls`
- `thought`
- Optional structured output metadata (`structured`, `schema`) when enabled

### Examples

<Tabs items={['cURL', 'Python', 'JavaScript']}>
  <Tabs.Tab>
    ```bash
    curl -X POST http://localhost:7091/api/answer \
      -H "Content-Type: application/json" \
      -d '{"question":"your question here","api_key":"your_agent_api_key"}'
    ```
  </Tabs.Tab>
  <Tabs.Tab>
    ```python
    import requests

    API_URL = "http://localhost:7091/api/answer"
    API_KEY = "your_agent_api_key"
    QUESTION = "your question here"

    response = requests.post(
        API_URL,
        json={"question": QUESTION, "api_key": API_KEY}
    )

    if response.status_code == 200:
        print(response.json())
    else:
        print(f"Error: {response.status_code}")
        print(response.text)
    ```
  </Tabs.Tab>
  <Tabs.Tab>
    ```javascript
    const apiUrl = 'http://localhost:7091/api/answer';
    const apiKey = 'your_agent_api_key';
    const question = 'your question here';

    async function getAnswer() {
      try {
        const response = await fetch(apiUrl, {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({ question, api_key: apiKey }),
        });

        if (!response.ok) {
          throw new Error(`HTTP error! Status: ${response.status}`);
        }

        const data = await response.json();
        console.log(data);
      } catch (error) {
        console.error("Failed to fetch answer:", error);
      }
    }

    getAnswer();
    ```
  </Tabs.Tab>
</Tabs>

---

## Streaming API (`/stream`)

`/stream` returns a Server-Sent Events (SSE) stream so you can render output token-by-token.

### SSE Event Types

Each `data:` frame is JSON with a `type` field (an illustrative stream follows this list):

- `answer`: incremental answer chunk
- `source`: source list/chunks
- `tool_calls`: tool invocation results/metadata
- `thought`: reasoning/thought chunk (agent dependent)
- `structured_answer`: final structured payload (when schema mode is active)
- `id`: final conversation ID
- `error`: error message
- `end`: stream is complete
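
For illustration only, a short stream might interleave frames like the ones below; the `type` values are as documented above, but the other field names are assumptions, not confirmed payload shapes:

```text
data: {"type": "source", "source": [...]}
data: {"type": "answer", "answer": "Docs"}
data: {"type": "answer", "answer": "GPT is"}
data: {"type": "id", "id": "6650f1..."}
data: {"type": "end"}
```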
### Examples

<Tabs items={['cURL', 'Python', 'JavaScript']}>
  <Tabs.Tab>
    ```bash
    curl -X POST http://localhost:7091/stream \
      -H "Content-Type: application/json" \
      -H "Accept: text/event-stream" \
      -d '{"question":"your question here","api_key":"your_agent_api_key"}'
    ```
  </Tabs.Tab>
  <Tabs.Tab>
    ```python
    import requests
    import json

    API_URL = "http://localhost:7091/stream"
    payload = {
        "question": "your question here",
        "api_key": "your_agent_api_key"
    }

    with requests.post(API_URL, json=payload, stream=True) as r:
        for line in r.iter_lines():
            if line:
                decoded_line = line.decode('utf-8')
                if decoded_line.startswith('data: '):
                    try:
                        data = json.loads(decoded_line[6:])
                        print(data)
                    except json.JSONDecodeError:
                        pass
    ```
  </Tabs.Tab>
  <Tabs.Tab>
    ```javascript
    const apiUrl = 'http://localhost:7091/stream';
    const apiKey = 'your_agent_api_key';
    const question = 'your question here';

    async function getStream() {
      try {
        const response = await fetch(apiUrl, {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            'Accept': 'text/event-stream'
          },
          body: JSON.stringify({ question, api_key: apiKey }),
        });

        if (!response.ok) {
          throw new Error(`HTTP error! Status: ${response.status}`);
        }

        const reader = response.body.getReader();
        const decoder = new TextDecoder();

        while (true) {
          const { done, value } = await reader.read();
          if (done) break;

          const chunk = decoder.decode(value, { stream: true });
          // Note: This parsing method assumes each chunk contains whole lines.
          // For a more robust production implementation, buffer the chunks
          // and process them line by line.
          const lines = chunk.split('\n');

          for (const line of lines) {
            if (line.startsWith('data: ')) {
              try {
                const data = JSON.parse(line.substring(6));
                console.log(data);
              } catch (e) {
                console.error("Failed to parse JSON from SSE event:", e);
              }
            }
          }
        }
      } catch (error) {
        console.error("Failed to fetch stream:", error);
      }
    }

    getStream();
    ```
  </Tabs.Tab>
</Tabs>

---

## Attachments API (Including Images)

To attach an image (or other file) to a query:

1. Upload file(s) to `/api/store_attachment` (multipart/form-data).
2. Poll `/api/task_status` until `status=SUCCESS`.
3. Read `result.attachment_id` from the task result.
4. Send that ID in `/stream` as `attachments: ["..."]`.

<Callout type="warning">
  Attachments are processed asynchronously. Do not call `/stream` with an attachment until its task has finished with `SUCCESS`.
</Callout>

### Step 1: Upload Attachment

`POST /api/store_attachment`

- Content type: `multipart/form-data`
- Form fields:
  - `file` (required, can be repeated for multi-file upload)
  - `api_key` (optional if JWT is present; useful for API-key-only flows)

Example upload (single image):

```bash
curl -X POST http://localhost:7091/api/store_attachment \
  -F "file=@/absolute/path/to/image.png" \
  -F "api_key=your_agent_api_key"
```

Possible response (single-file upload):

```json
{
  "success": true,
  "task_id": "34f1cb56-7c7f-4d5f-a973-4ea7e65f7a10",
  "message": "File uploaded successfully. Processing started."
}
```

### Step 2: Poll Task Status

```bash
curl "http://localhost:7091/api/task_status?task_id=34f1cb56-7c7f-4d5f-a973-4ea7e65f7a10"
```

When complete:

```json
{
  "status": "SUCCESS",
  "result": {
    "attachment_id": "67b4f8f2618dc9f19384a9e1",
    "filename": "image.png",
    "mime_type": "image/png"
  }
}
```
### Step 3: Attach to `/stream` Request
|
||||
|
||||
Use the `attachment_id` in `attachments`.
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:7091/stream \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Accept: text/event-stream" \
|
||||
-d '{
|
||||
"question": "Describe this image",
|
||||
"api_key": "your_agent_api_key",
|
||||
"attachments": ["67b4f8f2618dc9f19384a9e1"]
|
||||
}'
|
||||
```
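
The three steps can be chained in a single script. Below is a minimal end-to-end sketch in Python; the endpoints and field names are the ones documented above, while the file path, polling interval, and lack of error handling are simplifications:

```python
import json
import time

import requests

BASE_URL = "http://localhost:7091"
API_KEY = "your_agent_api_key"

# Step 1: upload the file.
with open("image.png", "rb") as f:
    upload = requests.post(
        f"{BASE_URL}/api/store_attachment",
        files={"file": f},
        data={"api_key": API_KEY},
    ).json()

# Step 2: poll until processing finishes.
# (A production client should also handle FAILURE states and time out.)
while True:
    task = requests.get(
        f"{BASE_URL}/api/task_status",
        params={"task_id": upload["task_id"]},
    ).json()
    if task["status"] == "SUCCESS":
        break
    time.sleep(1)

# Step 3: reference the attachment in a streaming query.
payload = {
    "question": "Describe this image",
    "api_key": API_KEY,
    "attachments": [task["result"]["attachment_id"]],
}
with requests.post(f"{BASE_URL}/stream", json=payload, stream=True) as r:
    for line in r.iter_lines():
        if line and line.startswith(b"data: "):
            print(json.loads(line[6:]))
```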

### Image/Attachment Behavior Notes

- Typical image MIME types supported for native vision flows: `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif`.
- If the selected model/provider does not support a file type natively, DocsGPT falls back to parsed text content.
- For providers that support images but not native PDF file attachments, DocsGPT can convert PDF pages to images (synthetic PDF support).
- Attachments are user-scoped. Upload and query must be done under the same user context (same API key owner or same JWT user).
@@ -1,65 +0,0 @@
# Workflow Nodes

DocsGPT workflows are composed of **Nodes** that are connected to form a processing graph. These nodes interact with a **Shared State**—a global dictionary of variables that persists throughout the execution of the workflow.

## The Shared State

Every workflow run maintains a state object (a JSON-like dictionary).
- **Initial State**: Contains the user's input query (`{{query}}`) and chat history (`{{chat_history}}`).
- **Accessing Variables**: You can access any variable in the state using the double-curly braces syntax: `{{variable_name}}`.
- **Modifying State**: Nodes read from this state and write their outputs back to it.
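
For example, after a node with an output variable named `search_results` has run, the shared state might look like this (hypothetical values):

```json
{
  "query": "What changed in 0.15?",
  "chat_history": [],
  "search_results": "DocsGPT 0.15 adds long-term memory and tooling..."
}
```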

---

## AI Agent Node

The **AI Agent Node** is the core processing unit. It uses a Large Language Model (LLM) to generate text, answer questions, or perform tasks using tools.

### Inputs (Template Variables)

The primary input is the **Prompt Template**. This field supports variable substitution.

- **Prompt Template**: The text sent to the model.
  - *Example*: `"Summarize the following text: {{user_input_text}}"`
  - If left empty, it defaults to the initial user query (`{{query}}`).
- **System Prompt**: Instructions that define the agent's persona and constraints.
- **Tools**: A list of tools the agent can use (e.g., search, calculator).
- **LLM Settings**: Specific provider, model name, and parameters.

### Outputs (Emissions)

When the agent completes its task, it stores the result in the shared state.

- **Output Variable**: The name of the variable where the result will be saved.
  - *Default*: If not specified, it is saved as `node_{node_id}_output`.
  - *Custom*: You can set this to something meaningful, like `summary` or `translated_text`.
- **Streaming**: If "Stream to user" is enabled, the output is sent to the user in real time as it is generated, in addition to being saved to the state.

---

## Set State Node

The **Set State Node** allows you to manipulate variables within the shared state directly, without calling an LLM. This is useful for initialization, formatting, or control-flow logic.

### Operations

You can define multiple operations in a single node. Each operation targets a specific **Key** (variable name); a sketch of the semantics follows the usage examples below.

1. **Set**: Assigns a specific value to a variable.
   - *Value*: Can be a static string or a template using variables.
   - *Example*: Set `current_step` to `1`.
   - *Example*: Set `formatted_response` to `Analysis: {{analysis_result}}`.

2. **Increment**: Increases the value of a numeric variable.
   - *Value*: The amount to add (default is 1).
   - *Example*: Increment `retry_count` by `1`.

3. **Append**: Adds a value to a list variable.
   - *Value*: The item to add to the list.
   - *Example*: Append `{{last_result}}` to `history_list`.

### Usage Examples

- **Loop Counters**: Use a *Set State* node to initialize a counter (`i = 0`) before a loop, and another to increment it inside the loop.
- **Accumulators**: Use *Append* to collect results from multiple parallel branches into a single list.
- **Renaming**: Copy the output of a previous node to a more generic name (e.g., set `context` to `{{search_results}}`) so subsequent nodes can use a standard variable name.
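
The semantics of the three operations can be sketched in a few lines of Python, treating the shared state as a plain dictionary (an illustration of the behavior described above, not DocsGPT's internal implementation):

```python
def set_value(state: dict, key: str, value) -> None:
    # Set: overwrite (or create) the variable.
    state[key] = value


def increment(state: dict, key: str, amount=1) -> None:
    # Increment: add to a numeric variable, treating a missing key as 0.
    state[key] = state.get(key, 0) + amount


def append(state: dict, key: str, item) -> None:
    # Append: push onto a list variable, creating the list if needed.
    state.setdefault(key, []).append(item)


state = {"query": "hello"}
set_value(state, "current_step", 1)
increment(state, "retry_count")           # state["retry_count"] == 1
append(state, "history_list", "step 1")   # state["history_list"] == ["step 1"]
```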
@@ -1,8 +0,0 @@
import { useMDXComponents as getThemeComponents } from 'nextra-theme-docs';

export function useMDXComponents(components) {
  return {
    ...getThemeComponents(),
    ...components,
  };
}
@@ -1,9 +1,9 @@
const nextra = require('nextra').default;
const withNextra = require('nextra')({
  theme: 'nextra-theme-docs',
  themeConfig: './theme.config.jsx'
})

const withNextra = nextra({
  defaultShowCopyCode: true,
});

module.exports = withNextra({
  reactStrictMode: true,
});
module.exports = withNextra()

// If you have other Next.js configurations, you can pass them as the parameter:
// module.exports = withNextra({ /* other next.js config */ })

5286 docs/package-lock.json generated
File diff suppressed because it is too large
@@ -2,7 +2,6 @@
  "scripts": {
    "dev": "next dev",
    "build": "next build",
    "postbuild": "pagefind --site .next/server/app --output-path public/_pagefind",
    "start": "next start"
  },
  "license": "MIT",
@@ -10,13 +9,9 @@
    "@vercel/analytics": "^1.1.1",
    "docsgpt-react": "^0.5.1",
    "next": "^15.5.9",
    "nextra": "^4.6.1",
    "nextra-theme-docs": "^4.6.1",
    "nextra": "^2.13.2",
    "nextra-theme-docs": "^2.13.2",
    "react": "^18.2.0",
    "react-dom": "^18.2.0"
  },
  "devDependencies": {
    "pagefind": "^1.3.0",
    "typescript": "^5.9.3"
  }
}
@@ -1,4 +1,4 @@
export default {
{
  "basics": {
    "title": "🤖 Agent Basics",
    "href": "/Agents/basics"
@@ -10,9 +10,5 @@ export default {
  "webhooks": {
    "title": "🪝 Agent Webhooks",
    "href": "/Agents/webhooks"
  },
  "nodes": {
    "title": "🧩 Workflow Nodes",
    "href": "/Agents/nodes"
  }
}
227 docs/pages/Agents/api.mdx Normal file
@@ -0,0 +1,227 @@
---
title: Interacting with Agents via API
description: Learn how to programmatically interact with DocsGPT Agents using the streaming and non-streaming API endpoints.
---

import { Callout, Tabs } from 'nextra/components';

# Interacting with Agents via API

DocsGPT Agents can be accessed programmatically through a dedicated API, allowing you to integrate their specialized capabilities into your own applications, scripts, and workflows. This guide covers the two primary methods for interacting with an agent: the streaming API for real-time responses and the non-streaming API for a single, consolidated answer.

When you use an API key generated for a specific agent, you do not need to pass `prompt`, `tools`, etc. The agent's configuration (including its prompt, selected tools, and knowledge sources) is already associated with its unique API key.

### API Endpoints

- **Non-Streaming:** `http://localhost:7091/api/answer`
- **Streaming:** `http://localhost:7091/stream`

<Callout type="info">
For DocsGPT Cloud, use `https://gptcloud.arc53.com/` as the base URL.
</Callout>

For more technical details, you can explore the Swagger API documentation available for the cloud version or your local instance.

---

## Non-Streaming API (`/api/answer`)

This is a standard synchronous endpoint. It waits for the agent to fully process the request and returns a single JSON object with the complete answer. This is the simplest method and is ideal for backend processes where a real-time feed is not required.

### Request

- **Endpoint:** `/api/answer`
- **Method:** `POST`
- **Payload:**
  - `question` (string, required): The user's query or input for the agent.
  - `api_key` (string, required): The unique API key for the agent you wish to interact with.
  - `history` (string, optional): A JSON string representing the conversation history, e.g., `[{\"prompt\": \"first question\", \"answer\": \"first answer\"}]`.

### Response

A single JSON object containing:
- `answer`: The complete, final answer from the agent.
- `sources`: A list of sources the agent consulted.
- `conversation_id`: The unique ID for the interaction.
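
An illustrative response body (placeholder values; the exact shape of each `sources` entry depends on the agent's knowledge sources):

```json
{
  "answer": "DocsGPT is an open-source platform for building AI agents...",
  "sources": [
    { "title": "README.md", "text": "..." }
  ],
  "conversation_id": "66f0c1a2b3d4e5f601234567"
}
```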

### Examples

<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/api/answer \
  -H "Content-Type: application/json" \
  -d '{
    "question": "your question here",
    "api_key": "your_agent_api_key"
  }'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests

API_URL = "http://localhost:7091/api/answer"
API_KEY = "your_agent_api_key"
QUESTION = "your question here"

response = requests.post(
    API_URL,
    json={"question": QUESTION, "api_key": API_KEY}
)

if response.status_code == 200:
    print(response.json())
else:
    print(f"Error: {response.status_code}")
    print(response.text)
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/api/answer';
const apiKey = 'your_agent_api_key';
const question = 'your question here';

async function getAnswer() {
  try {
    const response = await fetch(apiUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({ question, api_key: apiKey }),
    });

    if (!response.ok) {
      throw new Error(`HTTP error! Status: ${response.status}`);
    }

    const data = await response.json();
    console.log(data);
  } catch (error) {
    console.error("Failed to fetch answer:", error);
  }
}

getAnswer();
```
</Tabs.Tab>
</Tabs>

---

## Streaming API (`/stream`)

The `/stream` endpoint uses Server-Sent Events (SSE) to push data in real time. This is ideal for applications where you want to display the response as it's being generated, such as in a live chatbot interface.

### Request

- **Endpoint:** `/stream`
- **Method:** `POST`
- **Payload:** Same as the non-streaming API.

### Response (SSE Stream)

The stream consists of multiple `data:` events, each containing a JSON object. Your client should listen for these events and process them based on their `type`.

**Event Types:**
- `answer`: A chunk of the agent's final answer.
- `source`: A document or source used by the agent.
- `thought`: A reasoning step from the agent (for ReAct agents).
- `id`: The unique `conversation_id` for the interaction.
- `error`: An error message.
- `end`: A final message indicating the stream has concluded.

### Examples

<Tabs items={['cURL', 'Python', 'JavaScript']}>
<Tabs.Tab>
```bash
curl -X POST http://localhost:7091/stream \
  -H "Content-Type: application/json" \
  -H "Accept: text/event-stream" \
  -d '{
    "question": "your question here",
    "api_key": "your_agent_api_key"
  }'
```
</Tabs.Tab>
<Tabs.Tab>
```python
import requests
import json

API_URL = "http://localhost:7091/stream"
payload = {
    "question": "your question here",
    "api_key": "your_agent_api_key"
}

with requests.post(API_URL, json=payload, stream=True) as r:
    for line in r.iter_lines():
        if line:
            decoded_line = line.decode('utf-8')
            if decoded_line.startswith('data: '):
                try:
                    data = json.loads(decoded_line[6:])
                    print(data)
                except json.JSONDecodeError:
                    pass
```
</Tabs.Tab>
<Tabs.Tab>
```javascript
const apiUrl = 'http://localhost:7091/stream';
const apiKey = 'your_agent_api_key';
const question = 'your question here';

async function getStream() {
  try {
    const response = await fetch(apiUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        'Accept': 'text/event-stream'
      },
      body: JSON.stringify({ question, api_key: apiKey }),
    });

    if (!response.ok) {
      throw new Error(`HTTP error! Status: ${response.status}`);
    }

    const reader = response.body.getReader();
    const decoder = new TextDecoder();

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;

      const chunk = decoder.decode(value, { stream: true });
      // Note: This parsing method assumes each chunk contains whole lines.
      // For a more robust production implementation, buffer the chunks
      // and process them line by line.
      const lines = chunk.split('\n');

      for (const line of lines) {
        if (line.startsWith('data: ')) {
          try {
            const data = JSON.parse(line.substring(6));
            console.log(data);
          } catch (e) {
            console.error("Failed to parse JSON from SSE event:", e);
          }
        }
      }
    }
  } catch (error) {
    console.error("Failed to fetch stream:", error);
  }
}

getStream();
```
</Tabs.Tab>
</Tabs>
@@ -1,4 +1,4 @@
export default {
{
  "DocsGPT-Settings": {
    "title": "⚙️ App Configuration",
    "href": "/Deploying/DocsGPT-Settings"
@@ -29,4 +29,4 @@ export default {
    "href": "/Deploying/Railway",
    "display": "hidden"
  }
}
}
@@ -1,4 +1,4 @@
export default {
{
  "api-key-guide": {
    "title": "🔑 Getting API key",
    "href": "/Extensions/api-key-guide"
@@ -19,4 +19,4 @@ export default {
    "title": "🗣️ Chatwoot Extension",
    "href": "/Extensions/Chatwoot-extension"
  }
}
}
@@ -1,4 +1,4 @@
export default {
{
  "google-drive-connector": {
    "title": "🔗 Google Drive",
    "href": "/Guides/Integrations/google-drive-connector"
@@ -1,4 +1,4 @@
export default {
{
  "Customising-prompts": {
    "title": "️💻 Customising Prompts",
    "href": "/Guides/Customising-prompts"
@@ -1,4 +1,4 @@
export default {
{
  "cloud-providers": {
    "title": "☁️ Cloud Providers",
    "href": "/Models/cloud-providers"
@@ -11,4 +11,4 @@ export default {
    "title": "📝 Embeddings",
    "href": "/Models/embeddings"
  }
}
}
@@ -1,4 +1,4 @@
export default {
{
  "basics": {
    "title": "🔧 Tools Basics",
    "href": "/Tools/basics"
@@ -11,4 +11,4 @@ export default {
    "title": "🛠️ Creating a Custom Tool",
    "href": "/Tools/creating-a-tool"
  }
}
}
10 docs/pages/_app.mdx Normal file
@@ -0,0 +1,10 @@
import { DocsGPTWidget } from "docsgpt-react";

export default function MyApp({ Component, pageProps }) {
  return (
    <>
      <Component {...pageProps} />
      <DocsGPTWidget showSources={true} apiKey="5d8270cb-735f-484e-9dc9-5b407f24e652" theme="dark" size="medium" />
    </>
  )
}
@@ -1,4 +1,5 @@
export default {

{
  "index": "Home",
  "quickstart": "Quickstart",
  "Deploying": "Deploying",
@@ -8,11 +9,12 @@ export default {
  "Extensions": "Extensions",
  "https://gptcloud.arc53.com/": {
    "title": "API",
    "href": "https://gptcloud.arc53.com/"
    "href": "https://gptcloud.arc53.com/",
    "newWindow": true
  },
  "Guides": "Guides",
  "changelog": {
    "title": "Changelog",
    "display": "hidden"
  }
}
}
@@ -2,7 +2,7 @@
title: 'Home'
description: Documentation of DocsGPT - quickstart, deployment guides, model configuration, and widget integration documentation.
---
import { Cards } from 'nextra/components'
import { Cards, Card } from 'nextra/components'
import Image from 'next/image'

export const allGuides = {
@@ -85,7 +85,7 @@ Try it yourself: [https://www.docsgpt.cloud/](https://www.docsgpt.cloud/)
<Cards
  num={3}
  children={Object.keys(allGuides).map((key, i) => (
    <Cards.Card
    <Card
      key={i}
      title={allGuides[key].title}
      href={allGuides[key].href}
@@ -1,63 +0,0 @@
# DocsGPT

> DocsGPT is an open-source platform for building AI agents and assistants with document retrieval, tools, and multi-model support.

This file is a curated map of DocsGPT documentation for LLM and agent use.
Prioritize Core, Deploying, and Agents for implementation tasks.

## Core

- [Docs Home](https://docs.docsgpt.cloud/): Main documentation landing page.
- [Quickstart](https://docs.docsgpt.cloud/quickstart): Fastest path to run DocsGPT locally.
- [Architecture](https://docs.docsgpt.cloud/Guides/Architecture): High-level system architecture.
- [Development Environment](https://docs.docsgpt.cloud/Deploying/Development-Environment): Backend and frontend local setup.
- [DocsGPT Settings](https://docs.docsgpt.cloud/Deploying/DocsGPT-Settings): Environment variables and core app configuration.

## Deploying

- [Docker Deployment](https://docs.docsgpt.cloud/Deploying/Docker-Deploying): Run DocsGPT with Docker and Docker Compose.
- [Kubernetes Deployment](https://docs.docsgpt.cloud/Deploying/Kubernetes-Deploying): Deploy DocsGPT on Kubernetes clusters.
- [Hosting DocsGPT](https://docs.docsgpt.cloud/Deploying/Hosting-the-app): Hosting overview with cloud options.

## Agents

- [Agent Basics](https://docs.docsgpt.cloud/Agents/basics): Core concepts for building and managing agents.
- [Workflow Nodes](https://docs.docsgpt.cloud/Agents/nodes): Node types and behavior in agent workflows.
- [Agent API](https://docs.docsgpt.cloud/Agents/api): Programmatic agent interaction (streaming and non-streaming).
- [Agent Webhooks](https://docs.docsgpt.cloud/Agents/webhooks): Trigger and automate agents with webhooks.

## Tools

- [Tools Basics](https://docs.docsgpt.cloud/Tools/basics): How tools extend agent capabilities.
- [Generic API Tool](https://docs.docsgpt.cloud/Tools/api-tool): Configure API calls without custom code.
- [Creating a Custom Tool](https://docs.docsgpt.cloud/Tools/creating-a-tool): Build custom Python tools for DocsGPT.

## Models

- [Cloud LLM Providers](https://docs.docsgpt.cloud/Models/cloud-providers): Configure hosted model providers.
- [Local Inference](https://docs.docsgpt.cloud/Models/local-inference): Connect DocsGPT to local inference backends.
- [Embeddings](https://docs.docsgpt.cloud/Models/embeddings): Select and configure embedding models.

## Extensions

- [API Keys for Integrations](https://docs.docsgpt.cloud/Extensions/api-key-guide): Generate and use DocsGPT API keys.
- [Chat Widget](https://docs.docsgpt.cloud/Extensions/chat-widget): Embed the DocsGPT chat widget.
- [Search Widget](https://docs.docsgpt.cloud/Extensions/search-widget): Embed the DocsGPT search widget.
- [Chrome Extension](https://docs.docsgpt.cloud/Extensions/Chrome-extension): Install and use the browser extension.
- [Chatwoot Extension](https://docs.docsgpt.cloud/Extensions/Chatwoot-extension): Integrate DocsGPT with Chatwoot.

## Integrations

- [Google Drive Connector](https://docs.docsgpt.cloud/Guides/Integrations/google-drive-connector): Ingest and sync files from Google Drive.

## Optional

- [Customizing Prompts](https://docs.docsgpt.cloud/Guides/Customising-prompts): Template-based prompt customization.
- [How to Train on Other Documentation](https://docs.docsgpt.cloud/Guides/How-to-train-on-other-documentation): Add additional documentation sources.
- [Context Compression](https://docs.docsgpt.cloud/Guides/compression): Reduce context while preserving key information.
- [OCR for Sources and Attachments](https://docs.docsgpt.cloud/Guides/ocr): OCR behavior for ingestion and chat uploads.
- [How to Use Different LLMs](https://docs.docsgpt.cloud/Guides/How-to-use-different-LLM): Additional model-selection guidance.
- [Avoiding Hallucinations](https://docs.docsgpt.cloud/Guides/My-AI-answers-questions-using-external-knowledge): Improve answer grounding with external knowledge.
- [Amazon Lightsail Deployment](https://docs.docsgpt.cloud/Deploying/Amazon-Lightsail): Deploy DocsGPT on AWS Lightsail.
- [Railway Deployment](https://docs.docsgpt.cloud/Deploying/Railway): Deploy DocsGPT on Railway.
- [Changelog](https://docs.docsgpt.cloud/changelog): Project release history.
@@ -1,20 +1,161 @@
import Image from 'next/image'
import { Analytics } from '@vercel/analytics/react';

const github = 'https://github.com/arc53/DocsGPT';
const isDevelopment = process.env.NODE_ENV === 'development';

import { useConfig, useTheme } from 'nextra-theme-docs';
import CuteLogo from './public/cute-docsgpt.png';
const Logo = ({ height, width }) => {
  const { theme } = useTheme();
  return (
    <div style={{ alignItems: 'center', display: 'flex', gap: '8px' }}>
      <Image src={CuteLogo} alt="DocsGPT logo" width={width} height={height} />
      <span style={{ fontWeight: 'bold', fontSize: 18 }}>DocsGPT Docs</span>
    </div>
  );
};

const config = {
  docsRepositoryBase: `${github}/blob/main/docs`,
  darkMode: true,
  search: isDevelopment ? null : undefined,
  nextThemes: {
    defaultTheme: 'dark',
  chat: {
    link: 'https://discord.com/invite/n5BX8dh8rU',
  },
  sidebar: {
    defaultMenuCollapseLevel: 1,
  banner: {
    key: 'docs-launch',
    text: (
      <div className="flex justify-center items-center gap-2">
        Welcome to the new DocsGPT 🦖 docs! 👋
      </div>
    ),
  },
  toc: {
    float: true,
  },
  editLink: 'Edit this page on GitHub',
  project: {
    link: github,
  },
  darkMode: true,
  nextThemes: {
    defaultTheme: 'dark',
  },
  primaryHue: {
    dark: 207,
    light: 212,
  },
  footer: {
    text: (
      <div>
        <span>MIT {new Date().getFullYear()} © </span>
        <a href="https://www.docsgpt.cloud/" target="_blank">
          DocsGPT
        </a>
        {' | '}
        <a href="https://github.com/arc53/DocsGPT" target="_blank">
          GitHub
        </a>
        {' | '}
        <a href="https://blog.docsgpt.cloud/" target="_blank">
          Blog
        </a>
      </div>
    ),
  },
  editLink: {
    content: 'Edit this page on GitHub',
  },
  logo() {
    return (
      <div className="flex items-center gap-2">
        <Logo width={28} height={28} />
      </div>
    );
  },
  useNextSeoProps() {
    return {
      titleTemplate: `%s - DocsGPT Documentation`,
    };
  },

  head() {
    const { frontMatter } = useConfig();
    const { theme } = useTheme();
    const title = frontMatter?.title || 'Chat with your data with DocsGPT';
    const description =
      frontMatter?.description ||
      'Use DocsGPT to chat with your data. DocsGPT is a GPT powered chatbot that can answer questions about your data.'
    const image = '/cute-docsgpt.png';

    const composedTitle = `${title} – DocsGPT Documentation`;

    return (
      <>
        <link
          rel="apple-touch-icon"
          sizes="180x180"
          href={`/favicons/apple-touch-icon.png`}
        />
        <link
          rel="icon"
          type="image/png"
          sizes="32x32"
          href={`/favicons/favicon-32x32.png`}
        />
        <link
          rel="icon"
          type="image/png"
          sizes="16x16"
          href={`/favicons/favicon-16x16.png`}
        />
        <meta name="theme-color" content="#ffffff" />
        <meta name="msapplication-TileColor" content="#00a300" />
        <link rel="manifest" href={`/favicons/site.webmanifest`} />
        <meta httpEquiv="Content-Language" content="en" />
        <meta name="title" content={composedTitle} />
        <meta name="description" content={description} />

        <meta name="twitter:card" content="summary_large_image" />
        <meta name="twitter:site" content="@ATushynski" />
        <meta name="twitter:image" content={image} />

        <meta property="og:description" content={description} />
        <meta property="og:title" content={composedTitle} />
        <meta property="og:image" content={image} />
        <meta property="og:type" content="website" />
        <meta
          name="apple-mobile-web-app-title"
          content="DocsGPT Documentation"
        />
      </>
    );
  },
  sidebar: {
    defaultMenuCollapseLevel: 1,
    titleComponent: ({ title, type }) =>
      type === 'separator' ? (
        <div className="flex items-center gap-2">
          <Logo height={10} width={10} />
          {title}
          <Analytics />
        </div>
      ) : (
        <>{title}
          <Analytics />
        </>
      ),
  },

  gitTimestamp: ({ timestamp }) => (
    <>Last updated on {timestamp.toLocaleDateString()}</>
  ),
};

export default config;
@@ -3,5 +3,4 @@ VITE_BASE_URL=http://localhost:5173
VITE_API_HOST=http://127.0.0.1:7091
VITE_API_STREAMING=true
VITE_NOTIFICATION_TEXT="What's new in 0.15.0 — Changelog"
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-15-masters-long-term-memory-and-tooling/"
VITE_GOOGLE_CLIENT_ID=896376503572-u46l78n8ctgtdr4dlei4u06jv6rbpqc5.apps.googleusercontent.com
VITE_NOTIFICATION_LINK="https://blog.docsgpt.cloud/docsgpt-0-15-masters-long-term-memory-and-tooling/"
1 frontend/package-lock.json generated
@@ -8103,6 +8103,7 @@
      "https://github.com/sponsors/katex"
    ],
    "license": "MIT",
    "license": "MIT",
    "dependencies": {
      "commander": "^8.3.0"
    },

@@ -320,4 +320,4 @@ export default function AgentCard({
      />
    </div>
  );
}
}
@@ -603,4 +603,4 @@ function AgentSection({
      </div>
    </div>
  );
}
}
@@ -439,24 +439,10 @@ export default function NewAgent({ mode }: { mode: 'new' | 'edit' | 'draft' }) {
      const data = await response.json();
      const transformed = modelService.transformModels(data.models || []);
      setAvailableModels(transformed);

      if (mode === 'new' && transformed.length > 0) {
        const preferredDefaultModelId =
          transformed.find((model) => model.id === data.default_model_id)?.id ||
          transformed[0].id;

        if (preferredDefaultModelId) {
          setSelectedModelIds((prevSelectedModelIds) =>
            prevSelectedModelIds.size > 0
              ? prevSelectedModelIds
              : new Set([preferredDefaultModelId]),
          );
        }
      }
    };
    getTools();
    getModels();
  }, [token, mode]);
  }, [token]);

  // Validate folder_id from URL against user's folders
  useEffect(() => {
@@ -1405,4 +1391,4 @@ function AddPromptModal({
      handleAddPrompt={handleAddPrompt}
    />
  );
}
}
@@ -41,4 +41,4 @@ export const agentSectionsConfig = [
    selectData: selectSharedAgents,
    updateAction: setSharedAgents,
  },
];
];
@@ -18,4 +18,4 @@ export default function Agents() {
      <Route path="/workflow/edit/:agentId" element={<WorkflowBuilder />} />
    </Routes>
  );
}
}
@@ -1,20 +1,4 @@
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state' | 'condition';

export interface ConditionCase {
  name?: string;
  expression: string;
  sourceHandle: string;
}

export interface ConditionNodeConfig {
  mode: 'simple' | 'advanced';
  cases: ConditionCase[];
}

export interface StateOperationConfig {
  expression: string;
  target_variable: string;
}
export type NodeType = 'start' | 'end' | 'agent' | 'note' | 'state';

export interface WorkflowEdge {
  id: string;

File diff suppressed because it is too large
@@ -4,7 +4,6 @@ import {
  Circle,
  Database,
  Flag,
  GitBranch,
  Loader2,
  MessageSquare,
  Play,
@@ -54,7 +53,6 @@ const NODE_ICONS: Record<string, React.ReactNode> = {
  end: <Flag className="h-3 w-3" />,
  note: <StickyNote className="h-3 w-3" />,
  state: <Database className="h-3 w-3" />,
  condition: <GitBranch className="h-3 w-3" />,
};

const NODE_COLORS: Record<string, string> = {
@@ -63,7 +61,6 @@ const NODE_COLORS: Record<string, string> = {
  end: 'text-gray-600 dark:text-gray-400',
  note: 'text-yellow-600 dark:text-yellow-400',
  state: 'text-blue-600 dark:text-blue-400',
  condition: 'text-orange-600 dark:text-orange-400',
};

function ExecutionDetails({
@@ -87,9 +84,7 @@ function ExecutionDetails({

  const formatValue = (value: unknown): string => {
    if (typeof value === 'string') return value;
    if (value === undefined) return '';
    const formatted = JSON.stringify(value, null, 2);
    return formatted ?? String(value);
    return JSON.stringify(value, null, 2);
  };

  return (
@@ -138,11 +133,6 @@ function ExecutionDetails({
    if (text.length <= maxLength) return text;
    return text.slice(0, maxLength) + '...';
  };
  const hasOutput =
    step.output !== undefined &&
    step.output !== null &&
    formatValue(step.output) !== '';
  const formattedOutput = hasOutput ? formatValue(step.output) : '';

  return (
    <div
@@ -178,15 +168,15 @@ function ExecutionDetails({
        )}
      </div>
    </div>
    {(hasOutput || step.error || stateVars.length > 0) && (
    {(step.output || step.error || stateVars.length > 0) && (
      <div className="mt-3 space-y-2 text-sm">
        {hasOutput && (
        {step.output && (
          <div className="rounded-lg bg-white p-2 dark:bg-[#2A2A2A]">
            <span className="font-medium text-gray-600 dark:text-gray-400">
              Output:{' '}
            </span>
            <span className="wrap-break-word whitespace-pre-wrap text-gray-900 dark:text-gray-100">
              {truncateText(formattedOutput, 300)}
              {truncateText(step.output, 300)}
            </span>
          </div>
        )}
@@ -261,7 +251,7 @@ function WorkflowMiniMap({
    return step?.status || 'pending';
  };

  const getStatusColor = (nodeId: string) => {
  const getStatusColor = (nodeId: string, nodeType: string) => {
    const status = getNodeStatus(nodeId);
    const isActive = nodeId === activeNodeId;

@@ -277,33 +267,26 @@ function WorkflowMiniMap({
      case 'failed':
        return 'bg-red-100 dark:bg-red-900/30 border-red-300 dark:border-red-700';
      default:
        if (nodeType === 'start') {
          return 'bg-green-50 dark:bg-green-900/20 border-green-200 dark:border-green-800';
        }
        if (nodeType === 'agent') {
          return 'bg-purple-50 dark:bg-purple-900/20 border-purple-200 dark:border-purple-800';
        }
        if (nodeType === 'end') {
          return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
        }
        return 'bg-gray-50 dark:bg-gray-800 border-gray-200 dark:border-gray-700';
    }
  };

  const executedOrder = new Map(executionSteps.map((s, i) => [s.nodeId, i]));
  const startNode = nodes.find((node) => node.type === 'start');
  const visibleNodeIds = new Set(executionSteps.map((step) => step.nodeId));
  if (activeNodeId) {
    visibleNodeIds.add(activeNodeId);
  }
  if (startNode) {
    visibleNodeIds.add(startNode.id);
  }

  const sortedNodes = nodes
    .filter((node) => visibleNodeIds.has(node.id))
    .sort((a, b) => {
      if (a.type === 'start') return -1;
      if (b.type === 'start') return 1;

      const aIdx = executedOrder.get(a.id);
      const bIdx = executedOrder.get(b.id);
      if (aIdx !== undefined && bIdx !== undefined) return aIdx - bIdx;
      if (aIdx !== undefined) return -1;
      if (bIdx !== undefined) return 1;
      return (a.position?.y || 0) - (b.position?.y || 0);
    });
  const sortedNodes = [...nodes].sort((a, b) => {
    if (a.type === 'start') return -1;
    if (b.type === 'start') return 1;
    if (a.type === 'end') return 1;
    if (b.type === 'end') return -1;
    return (a.position?.y || 0) - (b.position?.y || 0);
  });

  const hasStepData = (nodeId: string) => {
    const step = executionSteps.find((s) => s.nodeId === nodeId);
@@ -323,7 +306,7 @@ function WorkflowMiniMap({
      disabled={!hasStepData(node.id)}
      className={cn(
        'flex h-12 w-full items-center gap-2 rounded-lg border px-3 text-xs transition-all',
        getStatusColor(node.id),
        getStatusColor(node.id, node.type),
        hasStepData(node.id) && 'cursor-pointer hover:opacity-80',
      )}
    >
@@ -550,10 +533,6 @@ export default function WorkflowPreview({
      const querySteps = query.executionSteps || [];
      const hasResponse = !!(query.response || query.error);
      const isLastQuery = index === queries.length - 1;
      const isStreamingLastQuery =
        status === 'loading' && isLastQuery;
      const shouldShowThought =
        !isStreamingLastQuery && Boolean(query.thought);
      const isOpen =
        openDetailsIndex === index ||
        (!hasResponse && isLastQuery && querySteps.length > 0);
@@ -588,19 +567,17 @@ export default function WorkflowPreview({

      {/* Response bubble */}
      {(query.response ||
        shouldShowThought ||
        query.thought ||
        query.tool_calls) && (
        <ConversationBubble
          className={isLastQuery ? 'mb-32' : 'mb-7'}
          message={query.response}
          type="ANSWER"
          thought={
            shouldShowThought ? query.thought : undefined
          }
          thought={query.thought}
          sources={query.sources}
          toolCalls={query.tool_calls}
          feedback={query.feedback}
          isStreaming={isStreamingLastQuery}
          isStreaming={status === 'loading' && isLastQuery}
        />
      )}
Some files were not shown because too many files have changed in this diff.