Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
4c7a6a78aa chore(deps): bump docker/setup-qemu-action from 3 to 4
Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 3 to 4.
- [Release notes](https://github.com/docker/setup-qemu-action/releases)
- [Commits](https://github.com/docker/setup-qemu-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: docker/setup-qemu-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-04 20:54:17 +00:00
264 changed files with 10612 additions and 40575 deletions

View File

@@ -3,14 +3,6 @@ LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
@@ -20,17 +12,4 @@ EMBEDDINGS_KEY=
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -27,7 +27,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -1,114 +0,0 @@
name: Publish npm libraries
on:
workflow_dispatch:
inputs:
version:
description: >
Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
Applies to both docsgpt and docsgpt-react.
required: true
default: patch
permissions:
contents: write
pull-requests: write
jobs:
publish:
runs-on: ubuntu-latest
environment: npm-release
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
registry-url: https://registry.npmjs.org
- name: Install dependencies
run: npm ci
# ── docsgpt (HTML embedding bundle) ──────────────────────────────────
# Uses the `build` script (parcel build src/browser.tsx) and keeps
# the `targets` field so Parcel produces browser-optimised bundles.
- name: Set package name → docsgpt
run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt)
id: version_docsgpt
run: |
VERSION="${{ github.event.inputs.version }}"
NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"
- name: Build docsgpt
run: npm run build
- name: Publish docsgpt
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── docsgpt-react (React library bundle) ─────────────────────────────
# Uses `build:react` script (parcel build src/index.ts) and strips
# the `targets` field so Parcel treats the output as a plain library
# without browser-specific target resolution, producing a smaller bundle.
- name: Reset package.json from source control
run: git checkout -- package.json
- name: Set package name → docsgpt-react
run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Remove targets field (react library build)
run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt-react) to match docsgpt
run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version
- name: Clean dist before react build
run: rm -rf dist
- name: Build docsgpt-react
run: npm run build:react
- name: Publish docsgpt-react
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── Commit the bumped version back to the repository ─────────────────
- name: Reset package.json and write final version
run: |
git checkout -- package.json
jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
package.json > _tmp.json && mv _tmp.json package.json
npm install --package-lock-only
- name: Commit version bump and create PR
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
git checkout -b "$BRANCH"
git add package.json package-lock.json
git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
git push origin "$BRANCH"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Create PR
run: |
gh pr create \
--title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
--body "Automated version bump after npm publish." \
--base main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -1,34 +0,0 @@
name: React Widget Build
on:
push:
paths:
- 'extensions/react-widget/**'
pull_request:
paths:
- 'extensions/react-widget/**'
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-latest
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
cache: npm
cache-dependency-path: extensions/react-widget/package-lock.json
- name: Install dependencies
run: npm ci
- name: Build
run: npm run build

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
__pycache__/
*.py[cod]
*$py.class
results.txt
experiments/
experiments

134
AGENTS.md
View File

@@ -1,134 +0,0 @@
# AGENTS.md
- Read `CONTRIBUTING.md` before making non-trivial changes.
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users configure `.env` usually.
- Try to follow red/green TDD
### Check existing dev prerequisites first
For feature work, do **not** assume the environment needs to be recreated.
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether MongoDB is already running.
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
## Normal local development commands
Use these commands once the dev prerequisites above are satisfied.
### Backend
```bash
source .venv/bin/activate # macOS/Linux
uv pip install -r application/requirements.txt # or: pip install -r application/requirements.txt
```
Run the Flask API (if needed):
```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
Run the Celery worker in a separate terminal (if needed):
```bash
celery -A application.app.celery worker -l INFO
```
On macOS, prefer the solo pool for Celery:
```bash
python -m celery -A application.app.celery worker -l INFO --pool=solo
```
### Frontend
Install dependencies only when needed, then run the dev server:
```bash
cd frontend
npm install --include=dev
npm run dev
```
### Docs site
```bash
cd docs
npm install
```
### Python / backend changes validation
```bash
ruff check .
python -m pytest
```
### Frontend changes
```bash
cd frontend && npm run lint
cd frontend && npm run build
```
### Documentation changes
```bash
cd docs && npm run build
```
If Vale is installed locally and you edited prose, also run:
```bash
vale .
```
## Repository map
- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
- `tests/`: backend unit/integration tests and test-only Python dependencies.
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
- `deployment/`: Docker Compose variants and Kubernetes manifests.
## Coding rules
### Backend
- Follow PEP 8 and keep Python line length at or under 120 characters.
- Use type hints for function arguments and return values.
- Add Google-style docstrings to new or substantially changed functions and classes.
- Add or update tests under `tests/` for backend behavior changes.
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.
### Backend Abstractions
- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
- Vector stores are abstracted in `application/vectorstore/`.
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
### Frontend
- Follow the existing ESLint + Prettier setup.
- Prefer small, reusable functional components and hooks.
- If shared state must be added, use Redux rather than introducing a new global state library.
- Avoid broad UI refactors unless the task explicitly asks for them.
- Do not re-create components if we already have some in the app.
## PR readiness
Before opening a PR:
- run the relevant validation commands above
- confirm backend changes still work end-to-end after ingesting sample data when applicable
- clearly summarize user-visible behavior changes
- mention any config, dependency, or deployment implications
- Ask your user to attach a screenshot or a video to it

View File

@@ -22,11 +22,6 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!
- We have a frontend built on React (Vite) and a backend in Python.
> **Required for every PR:** Please attach screenshots or a short screen
> recording that shows the working version of your changes. This makes the
> requirement visible to reviewers and helps them quickly verify what you are
> submitting.
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart) the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
@@ -130,7 +125,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
```
9. **Submit a Pull Request (PR):**
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
10. **Collaborate:**
- Be responsive to comments and feedback on your PR.

View File

@@ -7,7 +7,7 @@
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
</p>
<div align="center">
@@ -29,14 +29,13 @@
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
</div>
<h3 align="left">
<strong>Key Features:</strong>
</h3>
<ul align="left">
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
@@ -159,3 +158,4 @@ The source code license is [MIT](https://opensource.org/license/mit/), as descri
</a>
</p>

View File

@@ -1,8 +1,7 @@
import logging
from application.agents.agentic_agent import AgenticAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
@@ -11,9 +10,7 @@ logger = logging.getLogger(__name__)
class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ClassicAgent, # backwards compat: react falls back to classic
"agentic": AgenticAgent,
"research": ResearchAgent,
"react": ReActAgent,
"workflow": WorkflowAgent,
}

View File

@@ -1,63 +0,0 @@
import logging
from typing import Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.logging import LogContext
logger = logging.getLogger(__name__)
class AgenticAgent(BaseAgent):
"""Agent where the LLM controls retrieval via tools.
Unlike ClassicAgent which pre-fetches docs into the prompt,
AgenticAgent gives the LLM an internal_search tool so it can
decide when, what, and whether to search.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
self._prepare_tools(tools_dict)
# 4. Build messages (prompt has NO pre-fetched docs)
messages = self._build_messages(self.prompt, query)
# 5. Call LLM — the handler manages the tool loop
llm_response = self._llm_gen(messages, log_context)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# 6. Collect sources from internal search tool results
self._collect_internal_sources()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
def _collect_internal_sources(self):
"""Collect retrieved docs from the cached InternalSearchTool instance."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
self.retrieved_docs = tool.retrieved_docs

View File

@@ -3,11 +3,15 @@ import uuid
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Optional
from application.agents.tool_executor import ToolExecutor
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
@@ -36,10 +40,6 @@ class BaseAgent(ABC):
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
llm=None,
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -50,42 +50,22 @@ class BaseAgent(ABC):
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
self.tool_config: Dict = {}
self.tools: List[Dict] = []
self.tool_calls: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
)
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
self.retrieved_docs = retrieved_docs or []
if llm_handler is not None:
self.llm_handler = llm_handler
else:
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
# Tool executor — injected or created
if tool_executor is not None:
self.tool_executor = tool_executor
else:
self.tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.user,
decoded_token=decoded_token,
)
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = None
if json_schema is not None:
@@ -113,72 +93,327 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property
def tool_calls(self) -> List[Dict]:
return self.tool_executor.tool_calls
@tool_calls.setter
def tool_calls(self, value: List[Dict]):
self.tool_executor.tool_calls = value
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
return tools_by_id
def _get_user_tools(self, user="local"):
return self.tool_executor._get_user_tools(user)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def _build_tool_parameters(self, action):
return self.tool_executor._build_tool_parameters(action)
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
}
if v.get("required", False):
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _execute_tool_action(self, tools_dict, call):
return self.tool_executor.execute(
tools_dict, call, self.llm.__class__.__name__
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
# Return error result
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
def _get_truncated_tool_calls(self):
return self.tool_executor.get_truncated_tool_calls()
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
# ---- Context / token management ----
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
tm = ToolManager(config={})
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": headers,
"query_params": query_params,
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if hasattr(self, "conversation_id") and self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return [
{
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
"""
Calculate total tokens in current context (messages).
Args:
messages: List of message dicts
Returns:
Total token count
"""
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
"""
Check if we're approaching context limit (80%).
Args:
messages: Current message list
Returns:
True if at or above 80% of context limit
"""
from application.core.model_utils import get_token_limit
from application.core.settings import settings
try:
# Calculate current tokens
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
# Get context limit for model
context_limit = get_token_limit(self.model_id)
# Calculate threshold (80%)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
# Check if we've reached the limit
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _validate_context_size(self, messages: List[Dict]) -> None:
"""
Pre-flight validation before calling LLM. Logs warnings but never raises errors.
Args:
messages: Messages to be sent to LLM
"""
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
# Log based on usage level
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
@@ -193,31 +428,43 @@ class BaseAgent(ABC):
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
"""
Truncate text by removing content from the middle, preserving start and end.
Args:
text: Text to truncate
max_tokens: Maximum tokens allowed
Returns:
Truncated text with middle removed if needed
"""
from application.utils import num_tokens_from_string
current_tokens = num_tokens_from_string(text)
if current_tokens <= max_tokens:
return text
# Estimate chars per token (roughly 4 chars per token for English)
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
target_chars = int(max_tokens * chars_per_token * 0.95)
target_chars = int(max_tokens * chars_per_token * 0.95) # 5% safety margin
if target_chars <= 0:
return ""
# Split: keep 40% from start, 40% from end, remove middle
start_chars = int(target_chars * 0.4)
end_chars = int(target_chars * 0.4)
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
logger.info(
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
f"(removed middle section)"
)
return truncated
# ---- Message building ----
return truncated
def _build_messages(
self,
@@ -228,6 +475,7 @@ class BaseAgent(ABC):
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
@@ -241,18 +489,23 @@ class BaseAgent(ABC):
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
# Reserve 10% for response/tools
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
# Max tokens for query: 80% of available space (leave room for history)
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
# Truncate query from middle if it exceeds 80% of available context
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
# Calculate remaining budget for chat history
available_for_history = max(available_after_system - query_tokens, 0)
# Truncate chat history to fit within available budget
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
@@ -297,6 +550,16 @@ class BaseAgent(ABC):
history: List[Dict],
max_tokens: int,
) -> List[Dict]:
"""
Truncate chat history to fit within token budget, keeping most recent messages.
Args:
history: Full chat history
max_tokens: Maximum tokens allowed for history
Returns:
Truncated history (most recent messages that fit)
"""
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
@@ -305,6 +568,7 @@ class BaseAgent(ABC):
truncated = []
current_tokens = 0
# Iterate from newest to oldest
for message in reversed(history):
message_tokens = 0
@@ -324,7 +588,7 @@ class BaseAgent(ABC):
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message)
truncated.insert(0, message) # Maintain chronological order
else:
break
@@ -336,13 +600,13 @@ class BaseAgent(ABC):
return truncated
# ---- LLM generation ----
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
# Pre-flight context validation - fail fast if over limit
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (

View File

@@ -15,7 +15,11 @@ class ClassicAgent(BaseAgent):
) -> Generator[Dict, None, None]:
"""Core generator function for ClassicAgent execution flow"""
tools_dict = self.tool_executor.get_tools()
tools_dict = (
self._get_user_tools(self.user)
if not self.user_api_key
else self._get_tools(self.user_api_key)
)
self._prepare_tools(tools_dict)
messages = self._build_messages(self.prompt, query)

View File

@@ -0,0 +1,238 @@
import logging
import os
from typing import Any, Dict, Generator, List
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
logger = logging.getLogger(__name__)
MAX_ITERATIONS_REASONING = 10
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
PLANNING_PROMPT_TEMPLATE = f.read()
with open(
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
) as f:
FINAL_PROMPT_TEMPLATE = f.read()
class ReActAgent(BaseAgent):
"""
Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.
Implements a think-act-observe loop for complex problem-solving:
1. Creates a strategic plan based on the query
2. Executes tools and gathers observations
3. Iteratively refines approach until satisfied
4. Synthesizes final answer from all observations
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plan: str = ""
self.observations: List[str] = []
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Execute ReAct reasoning loop with planning, action, and observation cycles"""
self._reset_state()
tools_dict = (
self._get_tools(self.user_api_key)
if self.user_api_key
else self._get_user_tools(self.user)
)
self._prepare_tools(tools_dict)
for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}
yield from self._planning_phase(query, log_context)
if not self.plan:
logger.warning(
f"ReActAgent: No plan generated in iteration {iteration}"
)
break
self.observations.append(f"Plan (iteration {iteration}): {self.plan}")
satisfied = yield from self._execution_phase(query, tools_dict, log_context)
if satisfied:
logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
break
yield from self._synthesis_phase(query, log_context)
def _reset_state(self):
"""Reset agent state for new query"""
self.plan = ""
self.observations = []
def _planning_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Generate strategic plan for query"""
logger.info("ReActAgent: Creating plan...")
plan_prompt = self._build_planning_prompt(query)
messages = [{"role": "user", "content": plan_prompt}]
plan_stream = self.llm.gen_stream(
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
if log_context:
log_context.stacks.append(
{"component": "planning_llm", "data": build_stack_data(self.llm)}
)
plan_parts = []
for chunk in plan_stream:
content = self._extract_content(chunk)
if content:
plan_parts.append(content)
yield {"thought": content}
self.plan = "".join(plan_parts)
def _execution_phase(
self, query: str, tools_dict: Dict, log_context: LogContext
) -> Generator[bool, None, None]:
"""Execute plan with tool calls and observations"""
execution_prompt = self._build_execution_prompt(query)
messages = self._build_messages(execution_prompt, query)
llm_response = self._llm_gen(messages, log_context)
initial_content = self._extract_content(llm_response)
if initial_content:
self.observations.append(f"Initial response: {initial_content}")
processed_response = self._llm_handler(
llm_response, tools_dict, messages, log_context
)
for tool_call in self.tool_calls:
observation = (
f"Executed: {tool_call.get('tool_name', 'Unknown')} "
f"with args {tool_call.get('arguments', {})}. "
f"Result: {str(tool_call.get('result', ''))[:200]}"
)
self.observations.append(observation)
final_content = self._extract_content(processed_response)
if final_content:
self.observations.append(f"Response after tools: {final_content}")
if log_context:
log_context.stacks.append(
{
"component": "agent_tool_calls",
"data": {"tool_calls": self.tool_calls.copy()},
}
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
return "SATISFIED" in (final_content or "")
def _synthesis_phase(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Synthesize final answer from all observations"""
logger.info("ReActAgent: Generating final answer...")
final_prompt = self._build_final_answer_prompt(query)
messages = [{"role": "user", "content": final_prompt}]
final_stream = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
)
if log_context:
log_context.stacks.append(
{"component": "final_answer_llm", "data": build_stack_data(self.llm)}
)
for chunk in final_stream:
content = self._extract_content(chunk)
if content:
yield {"answer": content}
def _build_planning_prompt(self, query: str) -> str:
"""Build planning phase prompt"""
prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
prompt = prompt.replace("{prompt}", self.prompt or "")
prompt = prompt.replace("{summaries}", "")
prompt = prompt.replace("{observations}", "\n".join(self.observations))
return prompt
def _build_execution_prompt(self, query: str) -> str:
"""Build execution phase prompt with plan and observations"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 20000:
observations_str = observations_str[:20000] + "\n...[truncated]"
return (
f"{self.prompt or ''}\n\n"
f"Follow this plan:\n{self.plan}\n\n"
f"Observations:\n{observations_str}\n\n"
f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
f"Otherwise, continue executing the plan."
)
def _build_final_answer_prompt(self, query: str) -> str:
"""Build final synthesis prompt"""
observations_str = "\n".join(self.observations)
if len(observations_str) > 10000:
observations_str = observations_str[:10000] + "\n...[truncated]"
logger.warning("ReActAgent: Observations truncated for final answer")
return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)
def _extract_content(self, response: Any) -> str:
"""Extract text content from various LLM response formats"""
if not response:
return ""
collected = []
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
if response.message.content:
return response.message.content
if hasattr(response, "choices") and response.choices:
if hasattr(response.choices[0], "message"):
content = response.choices[0].message.content
if content:
return content
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text
try:
for chunk in response:
content_piece = ""
if hasattr(chunk, "choices") and chunk.choices:
if hasattr(chunk.choices[0], "delta"):
delta_content = chunk.choices[0].delta.content
if delta_content:
content_piece = delta_content
elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
content_piece = chunk.delta.text
elif isinstance(chunk, str):
content_piece = chunk
if content_piece:
collected.append(content_piece)
except (TypeError, AttributeError):
logger.debug(
f"Response not iterable or unexpected format: {type(response)}"
)
except Exception as e:
logger.error(f"Error extracting content: {e}")
return "".join(collected)

View File

@@ -1,692 +0,0 @@
import json
import logging
import os
import time
from typing import Dict, Generator, List, Optional
from application.agents.base import BaseAgent
from application.agents.tool_executor import ToolExecutor
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
from application.logging import LogContext
logger = logging.getLogger(__name__)
# Defaults (can be overridden via constructor)
DEFAULT_MAX_STEPS = 6
DEFAULT_MAX_SUB_ITERATIONS = 5
DEFAULT_TIMEOUT_SECONDS = 300 # 5 minutes
DEFAULT_TOKEN_BUDGET = 100_000
DEFAULT_PARALLEL_WORKERS = 3
# Adaptive depth caps per complexity level
COMPLEXITY_CAPS = {
"simple": 2,
"moderate": 4,
"complex": 6,
}
_PROMPTS_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"prompts",
"research",
)
def _load_prompt(name: str) -> str:
with open(os.path.join(_PROMPTS_DIR, name), "r") as f:
return f.read()
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
PLANNING_PROMPT = _load_prompt("planning.txt")
STEP_PROMPT = _load_prompt("step.txt")
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
# ---------------------------------------------------------------------------
# CitationManager
# ---------------------------------------------------------------------------
class CitationManager:
"""Tracks and deduplicates citations across research steps."""
def __init__(self):
self.citations: Dict[int, Dict] = {}
self._counter = 0
def add(self, doc: Dict) -> int:
"""Register a source, return its citation number. Deduplicates by source."""
source = doc.get("source", "")
title = doc.get("title", "")
for num, existing in self.citations.items():
if existing.get("source") == source and existing.get("title") == title:
return num
self._counter += 1
self.citations[self._counter] = doc
return self._counter
def add_docs(self, docs: List[Dict]) -> str:
"""Register multiple docs, return formatted citation mapping text."""
mapping_lines = []
for doc in docs:
num = self.add(doc)
title = doc.get("title", "Untitled")
mapping_lines.append(f"[{num}] {title}")
return "\n".join(mapping_lines)
def format_references(self) -> str:
"""Generate [N] -> source mapping for report footer."""
if not self.citations:
return "No sources found."
lines = []
for num, doc in sorted(self.citations.items()):
title = doc.get("title", "Untitled")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
display = filename or title
lines.append(f"[{num}] {display}{source}")
return "\n".join(lines)
def get_all_docs(self) -> List[Dict]:
return list(self.citations.values())
# ---------------------------------------------------------------------------
# ResearchAgent
# ---------------------------------------------------------------------------
class ResearchAgent(BaseAgent):
"""Multi-step research agent with parallel execution and budget controls.
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
max_steps: int = DEFAULT_MAX_STEPS,
max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
token_budget: int = DEFAULT_TOKEN_BUDGET,
parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
self.max_steps = max_steps
self.max_sub_iterations = max_sub_iterations
self.timeout_seconds = timeout_seconds
self.token_budget = token_budget
self.parallel_workers = parallel_workers
self.citations = CitationManager()
self._start_time: float = 0
self._tokens_used: int = 0
self._last_token_snapshot: int = 0
# ------------------------------------------------------------------
# Budget & timeout helpers
# ------------------------------------------------------------------
def _is_timed_out(self) -> bool:
return (time.monotonic() - self._start_time) >= self.timeout_seconds
def _elapsed(self) -> float:
return round(time.monotonic() - self._start_time, 1)
def _track_tokens(self, count: int):
self._tokens_used += count
def _budget_remaining(self) -> int:
return max(self.token_budget - self._tokens_used, 0)
def _is_over_budget(self) -> bool:
return self._tokens_used >= self.token_budget
def _snapshot_llm_tokens(self) -> int:
"""Read current token usage from LLM and return delta since last snapshot."""
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
delta = current - self._last_token_snapshot
self._last_token_snapshot = current
return delta
# ------------------------------------------------------------------
# Main orchestration
# ------------------------------------------------------------------
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
self._start_time = time.monotonic()
tools_dict = self._setup_tools()
# Phase 0: Clarification (skip if user is responding to a prior clarification)
if not self._is_follow_up():
clarification = self._clarification_phase(query)
if clarification:
yield {"metadata": {"is_clarification": True}}
yield {"answer": clarification}
yield {"sources": []}
yield {"tool_calls": []}
log_context.stacks.append(
{"component": "agent", "data": {"clarification": True}}
)
return
# Phase 1: Planning (with adaptive depth)
yield {"type": "research_progress", "data": {"status": "planning"}}
plan, complexity = self._planning_phase(query)
if not plan:
logger.warning("ResearchAgent: Planning produced no steps, falling back")
plan = [{"query": query, "rationale": "Direct investigation"}]
complexity = "simple"
yield {
"type": "research_plan",
"data": {"steps": plan, "complexity": complexity},
}
# Phase 2: Research each step (yields progress events in real-time)
intermediate_reports = []
for i, step in enumerate(plan):
step_num = i + 1
step_query = step.get("query", query)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
f"({self._elapsed()}s)"
)
break
if self._is_over_budget():
logger.warning(
f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
)
break
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "researching",
},
}
report = self._research_step(step_query, tools_dict)
intermediate_reports.append({"step": step, "content": report})
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "complete",
},
}
# Phase 3: Synthesis (streaming)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
f"synthesizing with {len(intermediate_reports)} reports"
)
yield {
"type": "research_progress",
"data": {
"status": "synthesizing",
"elapsed_seconds": self._elapsed(),
"tokens_used": self._tokens_used,
},
}
yield from self._synthesis_phase(
query, plan, intermediate_reports, tools_dict, log_context
)
# Sources and tool calls
self.retrieved_docs = self.citations.get_all_docs()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
logger.info(
f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
f"{self._elapsed()}s, ~{self._tokens_used} tokens"
)
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
# ------------------------------------------------------------------
# Tool setup
# ------------------------------------------------------------------
def _setup_tools(self) -> Dict:
"""Build tools_dict with user tools + internal search + think."""
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
think_entry = dict(THINK_TOOL_ENTRY)
think_entry["config"] = {}
tools_dict[THINK_TOOL_ID] = think_entry
self._prepare_tools(tools_dict)
return tools_dict
# ------------------------------------------------------------------
# Phase 0: Clarification
# ------------------------------------------------------------------
def _is_follow_up(self) -> bool:
"""Check if the user is responding to a prior clarification.
Uses the metadata flag stored in the conversation DB — no string matching.
Only skip clarification when the last query was explicitly flagged
as a clarification by this agent.
"""
if not self.chat_history:
return False
last = self.chat_history[-1]
meta = last.get("metadata", {})
return bool(meta.get("is_clarification"))
def _clarification_phase(self, question: str) -> Optional[str]:
"""Ask the LLM whether the question needs clarification.
Returns formatted clarification text if needed, or None to proceed.
Uses response_format to force valid JSON output.
"""
messages = [
{"role": "system", "content": CLARIFICATION_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent clarification response: {text[:300]}")
data = self._parse_clarification_json(text)
if not data or not data.get("needs_clarification"):
return None
questions = data.get("questions", [])
if not questions:
return None
# Format as a friendly response
lines = [
"Before I begin researching, I'd like to clarify a few things:\n"
]
for i, q in enumerate(questions[:3], 1):
lines.append(f"{i}. {q}")
lines.append(
"\nPlease provide these details and I'll start the research."
)
return "\n".join(lines)
except Exception as e:
logger.error(f"Clarification phase failed: {e}", exc_info=True)
return None # proceed with research on failure
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
"""Parse clarification JSON from LLM response."""
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try extracting from code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
return json.loads(text[start:end].strip())
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
return json.loads(text[i : j + 1])
except json.JSONDecodeError:
continue
break
return None
# ------------------------------------------------------------------
# Phase 1: Planning (with adaptive depth)
# ------------------------------------------------------------------
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
"""Decompose the question into research steps via LLM.
Returns (steps, complexity) where complexity is simple/moderate/complex.
"""
messages = [
{"role": "system", "content": PLANNING_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent planning LLM response: {text[:500]}")
plan_data = self._parse_plan_json(text)
if isinstance(plan_data, dict):
complexity = plan_data.get("complexity", "moderate")
steps = plan_data.get("steps", [])
else:
complexity = "moderate"
steps = plan_data
# Adaptive depth: cap steps based on assessed complexity
cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
cap = min(cap, self.max_steps)
steps = steps[:cap]
logger.info(
f"ResearchAgent plan: complexity={complexity}, "
f"steps={len(steps)} (cap={cap})"
)
return steps, complexity
except Exception as e:
logger.error(f"Planning phase failed: {e}", exc_info=True)
return (
[{"query": question, "rationale": "Direct investigation (planning failed)"}],
"simple",
)
def _parse_plan_json(self, text: str):
"""Extract JSON plan from LLM response. Returns dict or list."""
# Try direct parse
try:
data = json.loads(text)
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except json.JSONDecodeError:
pass
# Try extracting from markdown code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
data = json.loads(text[start:end].strip())
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object in text
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
data = json.loads(text[i : j + 1])
if isinstance(data, dict) and "steps" in data:
return data
except json.JSONDecodeError:
continue
break
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
return []
# ------------------------------------------------------------------
# Phase 2: Research step (core loop)
# ------------------------------------------------------------------
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
"""Run a focused research loop for one sub-question (sequential path)."""
report = self._research_step_with_executor(
step_query, tools_dict, self.tool_executor
)
self._collect_step_sources()
return report
def _research_step_with_executor(
self, step_query: str, tools_dict: Dict, executor: ToolExecutor
) -> str:
"""Core research loop. Works with any ToolExecutor instance."""
system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": step_query},
]
last_search_empty = False
for iteration in range(self.max_sub_iterations):
# Check timeout and budget
if self._is_timed_out():
logger.info(
f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
)
break
if self._is_over_budget():
logger.info(
f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
)
break
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
self._track_tokens(self._snapshot_llm_tokens())
except Exception as e:
logger.error(
f"Research step LLM call failed (iteration {iteration}): {e}",
exc_info=True,
)
break
parsed = self.llm_handler.parse_response(response)
if not parsed.requires_tool_call:
return parsed.content or "No findings for this step."
# Execute tool calls
messages, last_search_empty = self._execute_step_tools_with_refinement(
parsed.tool_calls, tools_dict, messages, executor, last_search_empty
)
# Max iterations / timeout / budget — ask for summary
messages.append(
{
"role": "user",
"content": "Please summarize your findings so far based on the information gathered.",
}
)
try:
response = self.llm.gen(
model=self.model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
return text or "Research step completed."
except Exception:
return "Research step completed."
def _execute_step_tools_with_refinement(
self,
tool_calls,
tools_dict: Dict,
messages: List[Dict],
executor: ToolExecutor,
last_search_empty: bool,
) -> tuple[List[Dict], bool]:
"""Execute tool calls with query refinement on empty results.
Returns (updated_messages, was_last_search_empty).
"""
search_returned_empty = False
for call in tool_calls:
gen = executor.execute(
tools_dict, call, self.llm.__class__.__name__
)
result = None
call_id = None
while True:
try:
event = next(gen)
# Log tool_call status events instead of discarding them
if isinstance(event, dict) and event.get("type") == "tool_call":
logger.debug(
"Tool %s status: %s",
event.get("data", {}).get("action_name", ""),
event.get("data", {}).get("status", ""),
)
except StopIteration as e:
result, call_id = e.value
break
# Detect empty search results for refinement
is_search = "search" in (call.name or "").lower()
result_str = str(result) if result else ""
if is_search and "No documents found" in result_str:
search_returned_empty = True
if last_search_empty:
# Two consecutive empty searches — inject refinement hint
result_str += (
"\n\nHint: Previous search also returned no results. "
"Try a very different query with different keywords, "
"or broaden your search terms."
)
result = result_str
function_call_content = {
"function_call": {
"name": call.name,
"args": call.arguments,
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_content]}
)
tool_message = self.llm_handler.create_tool_message(call, result)
messages.append(tool_message)
return messages, search_returned_empty
def _collect_step_sources(self):
"""Collect sources from InternalSearchTool and register with CitationManager."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs"):
for doc in tool.retrieved_docs:
self.citations.add(doc)
# ------------------------------------------------------------------
# Phase 3: Synthesis
# ------------------------------------------------------------------
def _synthesis_phase(
self,
question: str,
plan: List[Dict],
intermediate_reports: List[Dict],
tools_dict: Dict,
log_context: LogContext,
) -> Generator[Dict, None, None]:
"""Compile all findings into a final cited report (streaming)."""
plan_lines = []
for i, step in enumerate(plan, 1):
plan_lines.append(
f"{i}. {step.get('query', 'Unknown')}{step.get('rationale', '')}"
)
plan_summary = "\n".join(plan_lines)
findings_parts = []
for i, report in enumerate(intermediate_reports, 1):
step_query = report["step"].get("query", "Unknown")
content = report["content"]
findings_parts.append(
f"--- Step {i}: {step_query} ---\n{content}"
)
findings = "\n\n".join(findings_parts)
references = self.citations.format_references()
synthesis_prompt = SYNTHESIS_PROMPT.replace("{question}", question)
synthesis_prompt = synthesis_prompt.replace("{plan_summary}", plan_summary)
synthesis_prompt = synthesis_prompt.replace("{findings}", findings)
synthesis_prompt = synthesis_prompt.replace("{references}", references)
messages = [
{"role": "system", "content": synthesis_prompt},
{"role": "user", "content": f"Please write the research report for: {question}"},
]
llm_response = self.llm.gen_stream(
model=self.model_id, messages=messages, tools=None
)
if log_context:
from application.logging import build_stack_data
log_context.stacks.append(
{"component": "synthesis_llm", "data": build_stack_data(self.llm)}
)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _extract_text(self, response) -> str:
"""Extract text content from a non-streaming LLM response."""
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
return response.message.content or ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and hasattr(choice.message, "content"):
return choice.message.content or ""
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text or ""
return str(response) if response else ""

View File

@@ -1,313 +0,0 @@
import logging
import uuid
from typing import Dict, List, Optional
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
class ToolExecutor:
"""Handles tool discovery, preparation, and execution.
Extracted from BaseAgent to separate concerns and enable tool caching.
"""
def __init__(
self,
user_api_key: Optional[str] = None,
user: Optional[str] = None,
decoded_token: Optional[Dict] = None,
):
self.user_api_key = user_api_key
self.user = user
self.decoded_token = decoded_token
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
def get_tools(self) -> Dict[str, Dict]:
"""Load tool configs from DB based on user context."""
if self.user_api_key:
return self._get_tools_by_api_key(self.user_api_key)
return self._get_user_tools(self.user or "local")
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
return {str(tool["_id"]): tool for tool in tools} if tools else {}
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas."""
return [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _build_tool_parameters(self, action: Dict) -> Dict:
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
}
if v.get("required", False):
params["required"].append(k)
return params
def execute(self, tools_dict: Dict, call, llm_class_name: str):
"""Execute a tool call. Yields status events, returns (result, call_id)."""
parser = ToolActionParser(llm_class_name)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
# Load tool (with caching)
tool = self._get_or_load_tool(
tool_data, tool_id, action_name,
headers=headers, query_params=query_params,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_or_load_tool(
self, tool_data: Dict, tool_id: str, action_name: str,
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
):
"""Load a tool, using cache when possible."""
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
if cache_key in self._loaded_tools:
return self._loaded_tools[cache_key]
tm = ToolManager(config={})
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": headers or {},
"query_params": query_params or {},
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
if tool_config.get("encrypted_credentials") and self.user:
decrypted = decrypt_credentials(
tool_config["encrypted_credentials"], self.user
)
tool_config.update(decrypted)
tool_config["auth_credentials"] = decrypted
tool_config.pop("encrypted_credentials", None)
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
if tool_data["name"] == "mcp_tool":
tool_config["query_mode"] = True
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
)
# Don't cache api_tool since config varies by action
if tool_data["name"] != "api_tool":
self._loaded_tools[cache_key] = tool
return tool
def get_truncated_tool_calls(self) -> List[Dict]:
return [
{
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]

View File

@@ -1,11 +1,6 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class BraveSearchTool(Tool):
"""
@@ -46,7 +41,7 @@ class BraveSearchTool(Tool):
"""
Performs a web search using the Brave Search API.
"""
logger.debug("Performing Brave web search for: %s", query)
print(f"Performing Brave web search for: {query}")
url = f"{self.base_url}/web/search"
@@ -99,7 +94,7 @@ class BraveSearchTool(Tool):
"""
Performs an image search using the Brave Search API.
"""
logger.debug("Performing Brave image search for: %s", query)
print(f"Performing Brave image search for: {query}")
url = f"{self.base_url}/images/search"
@@ -182,10 +177,6 @@ class BraveSearchTool(Tool):
return {
"token": {
"type": "string",
"label": "API Key",
"description": "Brave Search API key for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -1,14 +1,5 @@
import logging
import time
from typing import Any, Dict, Optional
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
RETRY_DELAY = 2.0
DEFAULT_TIMEOUT = 15
from duckduckgo_search import DDGS
class DuckDuckGoSearchTool(Tool):
@@ -19,123 +10,71 @@ class DuckDuckGoSearchTool(Tool):
def __init__(self, config):
self.config = config
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
def _get_ddgs_client(self):
from ddgs import DDGS
return DDGS(timeout=self.timeout)
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
results = operation()
return {
"status_code": 200,
"results": list(results) if results else [],
"message": f"{operation_name} completed successfully.",
}
except Exception as e:
last_error = e
error_str = str(e).lower()
if "ratelimit" in error_str or "429" in error_str:
if attempt < MAX_RETRIES:
delay = RETRY_DELAY * attempt
logger.warning(
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
)
time.sleep(delay)
continue
logger.error(f"{operation_name} failed: {e}")
break
return {
"status_code": 500,
"results": [],
"message": f"{operation_name} failed: {str(last_error)}",
}
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
"ddg_news_search": self._news_search,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _web_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo web search: {query}")
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
def operation():
client = self._get_ddgs_client()
return client.text(
try:
results = DDGS().text(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
max_results=max_results,
)
return self._execute_with_retry(operation, "Web search")
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
def _image_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo image search: {query}")
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
def operation():
client = self._get_ddgs_client()
return client.images(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 50),
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
)
return self._execute_with_retry(operation, "Image search")
def _news_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo news search: {query}")
def operation():
client = self._get_ddgs_client()
return client.news(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return self._execute_with_retry(operation, "News search")
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
"description": "Perform a web search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
@@ -145,15 +84,7 @@ class DuckDuckGoSearchTool(Tool):
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month), y (year)",
"description": "Number of results to return (default: 5)",
},
},
"required": ["query"],
@@ -161,43 +92,17 @@ class DuckDuckGoSearchTool(Tool):
},
{
"name": "ddg_image_search",
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
"description": "Perform an image search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Image search query",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 50)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_news_search",
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "News search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month)",
"description": "Number of results to return (default: 5, max: 50)",
},
},
"required": ["query"],

View File

@@ -1,436 +0,0 @@
import json
import logging
from typing import Dict, List, Optional
from application.agents.tools.base import Tool
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
logger = logging.getLogger(__name__)
class InternalSearchTool(Tool):
"""Wraps the ClassicRAG retriever as an LLM-callable tool.
Instead of pre-fetching docs into the prompt, the LLM decides
when and what to search. Supports multiple searches per session.
Optional capabilities (enabled when sources have directory_structure):
- path_filter on search: restrict results to a specific file/folder
- list_files action: browse the file/folder structure
"""
def __init__(self, config: Dict):
self.config = config
self.retrieved_docs: List[Dict] = []
self._retriever = None
self._directory_structure: Optional[Dict] = None
self._dir_structure_loaded = False
def _get_retriever(self):
if self._retriever is None:
self._retriever = RetrieverCreator.create_retriever(
self.config.get("retriever_name", "classic"),
source=self.config.get("source", {}),
chat_history=[],
prompt="",
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
api_key=self.config.get("api_key", settings.API_KEY),
decoded_token=self.config.get("decoded_token"),
)
return self._retriever
def _get_directory_structure(self) -> Optional[Dict]:
"""Load directory structure from MongoDB for the configured sources."""
if self._dir_structure_loaded:
return self._directory_structure
self._dir_structure_loaded = True
source = self.config.get("source", {})
active_docs = source.get("active_docs", [])
if not active_docs:
return None
try:
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
merged_structure = {}
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)}
)
if not source_doc:
continue
dir_str = source_doc.get("directory_structure")
if dir_str:
if isinstance(dir_str, str):
dir_str = json.loads(dir_str)
source_name = source_doc.get("name", doc_id)
if len(active_docs) > 1:
merged_structure[source_name] = dir_str
else:
merged_structure = dir_str
except Exception as e:
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
self._directory_structure = merged_structure if merged_structure else None
except Exception as e:
logger.debug(f"Failed to load directory structures: {e}")
return self._directory_structure
def execute_action(self, action_name: str, **kwargs):
if action_name == "search":
return self._execute_search(**kwargs)
elif action_name == "list_files":
return self._execute_list_files(**kwargs)
return f"Unknown action: {action_name}"
def _execute_search(self, **kwargs) -> str:
query = kwargs.get("query", "")
path_filter = kwargs.get("path_filter", "")
if not query:
return "Error: 'query' parameter is required."
try:
retriever = self._get_retriever()
docs = retriever.search(query)
except Exception as e:
logger.error(f"Internal search failed: {e}", exc_info=True)
return "Search failed: an internal error occurred."
if not docs:
return "No documents found matching your query."
# Apply path filter if specified
if path_filter:
path_lower = path_filter.lower()
docs = [
d
for d in docs
if path_lower in d.get("source", "").lower()
or path_lower in d.get("filename", "").lower()
or path_lower in d.get("title", "").lower()
]
if not docs:
return f"No documents found matching query '{query}' in path '{path_filter}'."
# Accumulate for source tracking
for doc in docs:
if doc not in self.retrieved_docs:
self.retrieved_docs.append(doc)
# Format results for the LLM
formatted = []
for i, doc in enumerate(docs, 1):
title = doc.get("title", "Untitled")
text = doc.get("text", "")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
header = filename or title
formatted.append(f"[{i}] {header} (source: {source})\n{text}")
return "\n\n---\n\n".join(formatted)
def _execute_list_files(self, **kwargs) -> str:
path = kwargs.get("path", "")
dir_structure = self._get_directory_structure()
if not dir_structure:
return "No file structure available for the current sources."
# Navigate to the requested path
current = dir_structure
if path:
for part in path.strip("/").split("/"):
if not part:
continue
if isinstance(current, dict) and part in current:
current = current[part]
else:
return f"Path '{path}' not found in the file structure."
# Format the structure for the LLM
return self._format_structure(current, path or "/")
def _format_structure(self, node: Dict, current_path: str) -> str:
if not isinstance(node, dict):
return f"'{current_path}' is a file, not a directory."
lines = [f"File structure at '{current_path}':\n"]
folders = []
files = []
for name, value in sorted(node.items()):
if isinstance(value, dict):
# Check if it's a file metadata dict or a folder
if "type" in value or "size_bytes" in value or "token_count" in value:
# It's a file with metadata
size = value.get("token_count", "")
ftype = value.get("type", "")
info_parts = []
if ftype:
info_parts.append(ftype)
if size:
info_parts.append(f"{size} tokens")
info = f" ({', '.join(info_parts)})" if info_parts else ""
files.append(f" {name}{info}")
else:
# It's a folder
count = self._count_files(value)
folders.append(f" {name}/ ({count} items)")
else:
files.append(f" {name}")
if folders:
lines.append("Folders:")
lines.extend(folders)
if files:
lines.append("Files:")
lines.extend(files)
if not folders and not files:
lines.append(" (empty)")
return "\n".join(lines)
def _count_files(self, node: Dict) -> int:
count = 0
for value in node.values():
if isinstance(value, dict):
if "type" in value or "size_bytes" in value or "token_count" in value:
count += 1
else:
count += self._count_files(value)
else:
count += 1
return count
def get_actions_metadata(self):
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"parameters": {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
},
}
},
}
]
# Add path_filter and list_files only if directory structure exists
has_structure = self.config.get("has_directory_structure", False)
if has_structure:
actions[0]["parameters"]["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return actions
def get_config_requirements(self):
return {}
# Constants for building synthetic tools_dict entries
INTERNAL_TOOL_ID = "internal"
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
"""Build the tools_dict entry for InternalSearchTool.
Dynamically includes list_files and path_filter based on
whether the sources have directory structure.
"""
search_params = {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
}
}
}
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"active": True,
"parameters": search_params,
}
]
if has_directory_structure:
search_params["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"active": True,
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return {"name": "internal_search", "actions": actions}
# Keep backward compat
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
def sources_have_directory_structure(source: Dict) -> bool:
"""Check if any of the active sources have directory_structure in MongoDB."""
active_docs = source.get("active_docs", [])
if not active_docs:
return False
try:
from bson.objectid import ObjectId
from application.core.mongo_db import MongoDB
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sources_collection = db["sources"]
if isinstance(active_docs, str):
active_docs = [active_docs]
for doc_id in active_docs:
try:
source_doc = sources_collection.find_one(
{"_id": ObjectId(doc_id)},
{"directory_structure": 1},
)
if source_doc and source_doc.get("directory_structure"):
return True
except Exception:
continue
except Exception as e:
logger.debug(f"Could not check directory structure: {e}")
return False
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
"""Add the internal search tool to tools_dict if sources are configured.
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
Mutates tools_dict in place.
"""
source = retriever_config.get("source", {})
has_sources = bool(source.get("active_docs"))
if not retriever_config or not has_sources:
return
has_dir = sources_have_directory_structure(source)
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
internal_entry["config"] = build_internal_tool_config(
**retriever_config,
has_directory_structure=has_dir,
)
tools_dict[INTERNAL_TOOL_ID] = internal_entry
def build_internal_tool_config(
source: Dict,
retriever_name: str = "classic",
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
api_key: str = None,
decoded_token: Optional[Dict] = None,
has_directory_structure: bool = False,
) -> Dict:
"""Build the config dict for InternalSearchTool."""
return {
"source": source,
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,
"api_key": api_key or settings.API_KEY,
"decoded_token": decoded_token,
"has_directory_structure": has_directory_structure,
}

View File

@@ -1,12 +1,20 @@
import asyncio
import base64
import concurrent.futures
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
@@ -16,18 +24,10 @@ from fastmcp.client.transports import (
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
@@ -56,7 +56,6 @@ class MCPTool(Tool):
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
- query_mode: If True, use non-interactive OAuth (fail-fast on 401)
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
@@ -77,40 +76,23 @@ class MCPTool(Tool):
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
self.query_mode = config.get("query_mode", False)
# Only validate and setup if server_url is provided and not OAuth
if self.server_url and self.auth_type != "oauth":
self._setup_client()
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
if explicit:
return explicit.rstrip("/")
connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
if connector_base:
parsed = urlparse(connector_base)
if parsed.scheme and parsed.netloc:
return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback"
return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback"
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
auth_key = (
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
@@ -127,10 +109,11 @@ class MCPTool(Tool):
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 300:
if time.time() - cached_data["created_at"] < 1800:
self._client = cached_data["client"]
return
else:
@@ -140,25 +123,15 @@ class MCPTool(Tool):
if self.auth_type == "oauth":
redis_client = get_redis_instance()
if self.query_mode:
auth = NonInteractiveOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
db=db,
user_id=self.user_id,
)
else:
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
@@ -260,53 +233,38 @@ class MCPTool(Tool):
else:
raise Exception(f"Unknown operation: {operation}")
_ERROR_MAP = [
(concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"),
(ConnectionRefusedError, lambda *_: "Connection refused"),
]
_ERROR_PATTERNS = {
("403", "Forbidden"): "Access denied (403 Forbidden)",
("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
("ECONNREFUSED",): "Connection refused",
("SSL", "certificate"): "SSL/TLS error",
}
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
asyncio.get_running_loop()
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
self._run_in_new_loop, operation, *args, **kwargs
)
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
except RuntimeError:
return self._run_in_new_loop(operation, *args, **kwargs)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
except Exception as e:
raise self._map_error(operation, e) from e
raise self._map_error(operation, e) from e
def _run_in_new_loop(self, operation, *args, **kwargs):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
def _map_error(self, operation: str, exc: Exception) -> Exception:
for exc_type, msg_fn in self._ERROR_MAP:
if isinstance(exc, exc_type):
return Exception(msg_fn(operation, self.timeout, exc))
error_msg = str(exc)
for patterns, friendly in self._ERROR_PATTERNS.items():
if any(p.lower() in error_msg.lower() for p in patterns):
return Exception(friendly)
logger.error("MCP %s failed: %s", operation, exc)
return exc
print(f"Error occurred while running async operation: {e}")
raise
def discover_tools(self) -> List[Dict]:
"""
@@ -327,6 +285,16 @@ class MCPTool(Tool):
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
@@ -342,37 +310,7 @@ class MCPTool(Tool):
)
return self._format_result(result)
except Exception as e:
error_msg = str(e)
lower_msg = error_msg.lower()
is_auth_error = (
"401" in error_msg
or "unauthorized" in lower_msg
or "session expired" in lower_msg
or "re-authorize" in lower_msg
)
if is_auth_error:
if self.auth_type == "oauth":
raise Exception(
f"Action '{action_name}' failed: OAuth session expired. "
"Please re-authorize this MCP server in tool settings."
) from e
global _mcp_clients_cache
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as retry_e:
raise Exception(
f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
"Your credentials may have expired — please re-authorize in tool settings."
) from retry_e
raise Exception(
f"Failed to execute action '{action_name}': {error_msg}"
) from e
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
@@ -395,35 +333,23 @@ class MCPTool(Tool):
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No server URL configured",
"tools_count": 0,
}
try:
parsed = urlparse(self.server_url)
if parsed.scheme not in ("http", "https"):
return {
"success": False,
"message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
"tools_count": 0,
}
except Exception:
return {
"success": False,
"message": "Invalid URL format",
"message": "No MCP server URL configured",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
try:
self._setup_client()
except Exception as e:
return {
"success": False,
"message": f"Client init failed: {str(e)}",
"tools_count": 0,
}
self._setup_client()
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
@@ -434,93 +360,55 @@ class MCPTool(Tool):
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
ping_ok = False
ping_error = None
"""Test connection for non-OAuth auth types."""
try:
self._run_async_operation("ping")
ping_ok = True
except Exception as e:
ping_error = str(e)
try:
tools = self.discover_tools()
except Exception as e:
return {
"success": False,
"message": f"Connection failed: {ping_error or str(e)}",
"tools_count": 0,
}
if not tools and not ping_ok:
return {
"success": False,
"message": f"Connection failed: {ping_error or 'No tools found'}",
"tools_count": 0,
}
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"message": message,
"tools_count": len(tools),
"tools": [
{
"name": tool.get("name", "unknown"),
"description": tool.get("description", ""),
}
for tool in tools
],
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
storage = DBTokenStorage(
server_url=self.server_url, user_id=self.user_id, db_client=db
)
loop = asyncio.new_event_loop()
"""Test connection for OAuth auth type with proper async handling."""
try:
tokens = loop.run_until_complete(storage.get_tokens())
finally:
loop.close()
if tokens and tokens.access_token:
self.query_mode = True
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
tools = self.discover_tools()
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in tools
],
}
except Exception as e:
logger.warning("OAuth token validation failed: %s", e)
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
return self._start_oauth_task()
def _start_oauth_task(self) -> Dict:
task_config = self.config.copy()
task_config.pop("query_mode", None)
result = mcp_oauth_task.delay(task_config, self.user_id)
return {
"success": False,
"requires_oauth": True,
"task_id": result.id,
"message": "OAuth authorization required.",
"tools_count": 0,
}
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def get_actions_metadata(self) -> List[Dict]:
"""
@@ -567,88 +455,110 @@ class MCPTool(Tool):
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
transport_enum = ["auto", "sse", "http"]
transport_help = {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
}
return {
"server_url": {
"type": "string",
"label": "Server URL",
"description": "URL of the remote MCP server",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"required": True,
"secret": False,
"order": 1,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": transport_enum,
"default": "auto",
"required": False,
"help": {
**transport_help,
},
},
"auth_type": {
"type": "string",
"label": "Authentication Type",
"description": "Authentication method for the MCP server",
"description": "Authentication type",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"secret": False,
"order": 2,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
},
"api_key": {
"type": "string",
"label": "API Key",
"description": "API key for authentication",
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
"secret": True,
"order": 3,
"depends_on": {"auth_type": "api_key"},
},
"api_key_header": {
"type": "string",
"label": "API Key Header",
"description": "Header name for API key (default: X-API-Key)",
"default": "X-API-Key",
"required": False,
"secret": False,
"order": 4,
"depends_on": {"auth_type": "api_key"},
},
"bearer_token": {
"type": "string",
"label": "Bearer Token",
"description": "Bearer token for authentication",
"required": False,
"secret": True,
"order": 3,
"depends_on": {"auth_type": "bearer"},
},
"username": {
"type": "string",
"label": "Username",
"description": "Username for basic authentication",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "basic"},
},
"password": {
"type": "string",
"label": "Password",
"description": "Password for basic authentication",
"required": False,
"secret": True,
"order": 4,
"depends_on": {"auth_type": "basic"},
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
},
"oauth_scopes": {
"type": "string",
"label": "OAuth Scopes",
"description": "Comma-separated OAuth scopes to request",
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "oauth"},
},
"timeout": {
"type": "number",
"label": "Timeout (seconds)",
"description": "Request timeout in seconds (1-300)",
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
"secret": False,
"order": 10,
},
}
@@ -670,8 +580,23 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
@@ -694,10 +619,7 @@ class DocsGPTOAuth(OAuthClientProvider):
)
storage = DBTokenStorage(
server_url=self.server_base_url,
user_id=self.user_id,
db_client=self.db,
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
)
super().__init__(
@@ -729,20 +651,22 @@ class DocsGPTOAuth(OAuthClientProvider):
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logger.info("Processed auth_url: %s, state: %s", auth_url, state)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logger.info("Stored auth_url in Redis: %s", key)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "Authorization required",
"message": "OAuth authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -762,7 +686,7 @@ class DocsGPTOAuth(OAuthClientProvider):
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for authorization...",
"message": "Waiting for OAuth callback...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -787,7 +711,7 @@ class DocsGPTOAuth(OAuthClientProvider):
if self.task_id:
status_data = {
"status": "callback_received",
"message": "Completing authentication...",
"message": "OAuth callback received, completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
@@ -807,44 +731,14 @@ class DocsGPTOAuth(OAuthClientProvider):
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth timeout: no code received within 5 minutes")
class NonInteractiveOAuth(DocsGPTOAuth):
"""OAuth provider that fails fast on 401 instead of starting interactive auth.
Used during query execution to prevent the streaming response from blocking
while waiting for user authorization that will never come.
"""
def __init__(self, **kwargs):
kwargs.setdefault("task_id", None)
kwargs["skip_redirect_validation"] = True
super().__init__(**kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
async def callback_handler(self) -> tuple[str, str | None]:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
raise Exception("OAuth callback timeout: no code received within 5 minutes")
class DBTokenStorage(TokenStorage):
def __init__(
self,
server_url: str,
user_id: str,
db_client,
expected_redirect_uri: Optional[str] = None,
):
def __init__(self, server_url: str, user_id: str, db_client):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.expected_redirect_uri = expected_redirect_uri
self.collection = db_client["connector_sessions"]
@staticmethod
@@ -863,9 +757,10 @@ class DBTokenStorage(TokenStorage):
if not doc or "tokens" not in doc:
return None
try:
return OAuthToken.model_validate(doc["tokens"])
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
except ValidationError as e:
logger.error("Could not load tokens: %s", e)
logging.error(f"Could not load tokens: {e}")
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
@@ -875,38 +770,28 @@ class DBTokenStorage(TokenStorage):
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
logger.debug(
"No client_info in DB for %s", self.get_base_url(self.server_url)
)
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
if self.expected_redirect_uri:
stored_uris = [
str(uri).rstrip("/") for uri in client_info.redirect_uris
]
expected_uri = self.expected_redirect_uri.rstrip("/")
if expected_uri not in stored_uris:
logger.warning(
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
self.get_base_url(self.server_url),
expected_uri,
stored_uris,
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": "", "tokens": ""}},
)
return None
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
return client_info
except ValidationError as e:
logger.error("Could not load client info: %s", e)
logging.error(f"Could not load client info: {e}")
return None
def _serialize_client_info(self, info: dict) -> dict:
@@ -922,17 +807,17 @@ class DBTokenStorage(TokenStorage):
{"$set": {"client_info": serialized_info}},
True,
)
logger.info("Saved client info for %s", self.get_base_url(self.server_url))
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logger.info("Cleared all OAuth client cache data.")
logging.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
@@ -971,7 +856,7 @@ class MCPOAuthManager:
return True
except Exception as e:
logger.error("Error handling OAuth callback: %s", e)
logging.error(f"Error handling OAuth callback: {e}")
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:

View File

@@ -116,13 +116,12 @@ class NtfyTool(Tool):
]
def get_config_requirements(self):
"""
Specify the configuration requirements.
Returns:
dict: Dictionary describing required config parameters.
"""
return {
"token": {
"type": "string",
"label": "Access Token",
"description": "Ntfy access token for authentication",
"required": True,
"secret": True,
"order": 1,
},
"token": {"type": "string", "description": "Access token for authentication"},
}

View File

@@ -1,12 +1,6 @@
import logging
import psycopg2
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class PostgresTool(Tool):
"""
PostgreSQL Database Tool
@@ -23,15 +17,17 @@ class PostgresTool(Tool):
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
@@ -39,9 +35,7 @@ class PostgresTool(Tool):
conn.commit()
if sql_query.strip().lower().startswith("select"):
column_names = (
[desc[0] for desc in cur.description] if cur.description else []
)
column_names = [desc[0] for desc in cur.description] if cur.description else []
results = []
rows = cur.fetchall()
for row in rows:
@@ -49,9 +43,7 @@ class PostgresTool(Tool):
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {
"message": f"Query executed successfully, {row_count} rows affected."
}
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
cur.close()
return {
@@ -62,27 +54,26 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL execute_sql error: %s", e)
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute(
"""
cur.execute("""
SELECT
table_name,
column_name,
@@ -96,22 +87,19 @@ class PostgresTool(Tool):
ORDER BY
table_name,
ordinal_position;
"""
)
""")
schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append(
{
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable,
}
)
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})
cur.close()
return {
@@ -122,14 +110,14 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL get_schema error: %s", e)
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def get_actions_metadata(self):
@@ -170,10 +158,6 @@ class PostgresTool(Tool):
return {
"token": {
"type": "string",
"label": "Connection String",
"description": "PostgreSQL database connection string",
"required": True,
"secret": True,
"order": 1,
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
},
}
}

View File

@@ -1,11 +1,6 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class TelegramTool(Tool):
"""
@@ -23,19 +18,21 @@ class TelegramTool(Tool):
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _send_message(self, text, chat_id):
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
@@ -85,12 +82,5 @@ class TelegramTool(Tool):
def get_config_requirements(self):
return {
"token": {
"type": "string",
"label": "Bot Token",
"description": "Telegram bot token for authentication",
"required": True,
"secret": True,
"order": 1,
},
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -1,68 +0,0 @@
from application.agents.tools.base import Tool
THINK_TOOL_ID = "think"
THINK_TOOL_ENTRY = {
"name": "think",
"actions": [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"active": True,
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
],
}
class ThinkTool(Tool):
"""Pseudo-tool that captures chain-of-thought reasoning.
Returns a short acknowledgment so the LLM can continue.
The reasoning content is captured in tool_call data for transparency.
"""
def __init__(self, config=None):
pass
def execute_action(self, action_name: str, **kwargs):
return "Continue."
def get_actions_metadata(self):
return [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
]
def get_config_requirements(self):
return {}

View File

@@ -36,7 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -2,10 +2,9 @@
from typing import Any, Dict, List, Optional, Type
from application.agents.agentic_agent import AgenticAgent
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflows.schemas import AgentType
@@ -37,8 +36,7 @@ class ToolFilterMixin:
return filtered_tools
class _WorkflowNodeMixin:
"""Common __init__ for all workflow node agents."""
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
def __init__(
self,
@@ -59,25 +57,32 @@ class _WorkflowNodeMixin:
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
pass
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
pass
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
pass
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
AgentType.RESEARCH: WorkflowNodeResearchAgent,
AgentType.REACT: WorkflowNodeReActAgent,
}
@classmethod

View File

@@ -18,8 +18,6 @@ class NodeType(str, Enum):
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
AGENTIC = "agentic"
RESEARCH = "research"
class ExecutionStatus(str, Enum):

View File

@@ -7,7 +7,6 @@ from application.agents.workflows.cel_evaluator import CelEvaluationError, evalu
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
@@ -19,7 +18,6 @@ from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.error import sanitize_api_error
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
@@ -101,7 +99,6 @@ class WorkflowEngine:
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
user_friendly_error = sanitize_api_error(e)
yield {
"type": "workflow_step",
"node_id": node.id,
@@ -109,9 +106,9 @@ class WorkflowEngine:
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": user_friendly_error,
"error": str(e),
}
yield {"type": "error", "error": user_friendly_error}
yield {"type": "error", "error": str(e)}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
@@ -224,32 +221,18 @@ class WorkflowEngine:
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
factory_kwargs = {
"agent_type": node_config.agent_type,
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,
"chat_history": self.agent.chat_history,
"decoded_token": self.agent.decoded_token,
"json_schema": node_json_schema,
}
# Agentic/research agents need retriever_config for on-demand search
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
factory_kwargs["retriever_config"] = {
"source": {"active_docs": node_config.sources} if node_config.sources else {},
"retriever_name": node_config.retriever or "classic",
"chunks": int(node_config.chunks) if node_config.chunks else 2,
"model_id": node_model_id,
"llm_name": node_llm_name,
"api_key": node_api_key,
"decoded_token": self.agent.decoded_token,
}
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_json_schema,
)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []

View File

@@ -74,10 +74,21 @@ class AnswerResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
agent = processor.build_agent(data.get("question", ""))
processor.initialize()
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
docs_together, docs_list = processor.pre_fetch_docs(
data.get("question", "")
)
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
if error := self.check_usage(processor.agent_config):
return error

View File

@@ -15,7 +15,6 @@ from application.core.model_utils import (
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields
@@ -206,12 +205,9 @@ class BaseAnswerResource:
is_structured = False
schema_info = None
structured_chunks = []
query_metadata = {}
for line in agent.gen(query=question):
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
if "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
@@ -244,14 +240,7 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
else:
data = json.dumps(line)
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
@@ -298,7 +287,6 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
@@ -388,7 +376,6 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)

View File

@@ -80,7 +80,7 @@ class StreamResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
agent = processor.build_agent(data["question"])
processor.initialize()
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
@@ -88,6 +88,13 @@ class StreamResource(Resource, BaseAnswerResource):
mimetype="text/event-stream",
)
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together, docs=docs_list, tools_data=tools_data
)
if error := self.check_usage(processor.agent_config):
return error
return Response(

View File

@@ -60,7 +60,6 @@ class ConversationService:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save or update a conversation in the database"""
if decoded_token is None:
@@ -94,11 +93,6 @@ class ConversationService:
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
**(
{f"queries.{index}.metadata": metadata}
if metadata
else {}
),
}
},
)
@@ -130,7 +124,6 @@ class ConversationService:
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
**({"metadata": metadata} if metadata else {}),
}
}
},
@@ -163,24 +156,22 @@ class ConversationService:
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
query_doc = {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
if metadata:
query_doc["metadata"] = metadata
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [query_doc],
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
}
if api_key:

View File

@@ -38,23 +38,13 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
# Maps for classic agent types
CLASSIC_PRESETS = {
preset_mapping = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
# Agentic counterparts — same styles, but with search tool instructions
AGENTIC_PRESETS = {
"default": "agentic/default.txt",
"creative": "agentic/creative.txt",
"strict": "agentic/strict.txt",
}
preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
try:
@@ -101,7 +91,6 @@ class StreamProcessor:
self.is_shared_usage = False
self.shared_token = None
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
@@ -122,29 +111,6 @@ class StreamProcessor:
self._load_conversation_history()
self._process_attachments()
def build_agent(self, question: str):
"""One call to go from request data to a ready-to-run agent.
Combines initialize(), pre_fetch_docs(), pre_fetch_tools(), and
create_agent() into a single convenience method.
"""
self.initialize()
agent_type = self.agent_config.get("agent_type", "classic")
# Agentic/research agents skip pre-fetch — the LLM searches on-demand via tools
if agent_type in ("agentic", "research"):
tools_data = self.pre_fetch_tools()
return self.create_agent(tools_data=tools_data)
docs_together, docs_list = self.pre_fetch_docs(question)
tools_data = self.pre_fetch_tools()
return self.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
@@ -158,17 +124,9 @@ class StreamProcessor:
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history (include metadata if present)
# Original behavior - load all history
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**(
{"metadata": query["metadata"]}
if "metadata" in query
else {}
),
}
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
@@ -177,8 +135,14 @@ class StreamProcessor:
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""Handle conversation compression logic using orchestrator."""
"""
Handle conversation compression logic using orchestrator.
Args:
conversation: Full conversation document
"""
try:
# Use orchestrator to handle all compression logic
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
@@ -189,15 +153,12 @@ class StreamProcessor:
if not result.success:
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
return
# Set compressed summary if compression was performed
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
@@ -208,27 +169,17 @@ class StreamProcessor:
f"+ {len(result.recent_queries)} recent messages"
)
# Build history from recent queries
self.history = result.as_history()
# Preserve metadata from recent queries (as_history only has prompt/response)
recent = result.recent_queries if result.recent_queries else conversation.get("queries", [])
for i, entry in enumerate(self.history):
# Match by index from the end of recent queries
offset = len(recent) - len(self.history)
qi = offset + i
if 0 <= qi < len(recent) and "metadata" in recent[qi]:
entry["metadata"] = recent[qi]["metadata"]
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
# Fallback to original behavior
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
@@ -240,6 +191,9 @@ class StreamProcessor:
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
@@ -248,6 +202,7 @@ class StreamProcessor:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
@@ -277,6 +232,7 @@ class StreamProcessor:
)
self.model_id = requested_model
else:
# Check if agent has a default model configured
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
@@ -329,6 +285,7 @@ class StreamProcessor:
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
@@ -354,6 +311,7 @@ class StreamProcessor:
else:
data["sources"] = []
# Preserve model configuration from agent
data["default_model_id"] = data.get("default_model_id", "")
return data
@@ -392,74 +350,63 @@ class StreamProcessor:
self.source = {}
self.all_sources = []
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
if request_agent_id:
return str(request_agent_id)
if not self.conversation_id or not self.initial_user_id:
return None
try:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
except Exception:
return None
if not conversation:
return None
conversation_agent_id = conversation.get("agent_id")
if conversation_agent_id:
return str(conversation_agent_id)
return None
def _configure_agent(self):
"""Configure the agent based on request data.
Unified flow: resolve the effective API key, then extract config once.
"""
agent_id = self._resolve_agent_id()
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
# Determine the effective API key (explicit > agent-derived)
effective_key = self.data.get("api_key") or self.agent_key
if effective_key:
data_key = self._get_data_from_api_key(effective_key)
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": effective_key,
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
"models": data_key.get("models", []),
}
)
# Set identity context
if self.data.get("api_key"):
# External API key: use the key owner's identity
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
elif self.is_shared_usage:
# Shared agent: keep the caller's identity
pass
else:
# Owner using their own agent
self.decoded_token = {"sub": data_key.get("user")}
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
@@ -476,7 +423,6 @@ class StreamProcessor:
)
self.retriever_config["chunks"] = 2
else:
# No API key — default/workflow configuration
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
@@ -525,7 +471,7 @@ class StreamProcessor:
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
"""Pre-fetch documents for template rendering before agent creation"""
if self.data.get("isNoneDoc", False) and not self.agent_id:
if self.data.get("isNoneDoc", False):
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
try:
@@ -559,7 +505,12 @@ class StreamProcessor:
return None, None
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
"""Pre-fetch tool data for template rendering before agent creation"""
"""Pre-fetch tool data for template rendering before agent creation
Can be controlled via:
1. Global setting: ENABLE_TOOL_PREFETCH in .env
2. Per-request: disable_tool_prefetch in request data
"""
if not settings.ENABLE_TOOL_PREFETCH:
logger.info(
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
@@ -778,21 +729,11 @@ class StreamProcessor:
tools_data: Optional[Dict[str, Any]] = None,
):
"""Create and return the configured agent with rendered prompt"""
agent_type = self.agent_config["agent_type"]
# For agentic agents, swap standard presets for their agentic
# counterparts (which include search tool instructions instead of
# {summaries}). Custom / user-provided prompts pass through as-is.
raw_prompt = self._get_prompt_content()
if raw_prompt is None:
prompt_id = self.agent_config.get("prompt_id", "default")
agentic_presets = {"default", "creative", "strict"}
if agent_type in ("agentic", "research") and prompt_id in agentic_presets:
raw_prompt = get_prompt(
f"agentic_{prompt_id}", self.prompts_collection
)
else:
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
raw_prompt = get_prompt(
self.agent_config["prompt_id"], self.prompts_collection
)
self._prompt_content = raw_prompt
rendered_prompt = self.prompt_renderer.render_prompt(
@@ -812,35 +753,7 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
# Create LLM and handler (dependency injection)
from application.llm.llm_creator import LLMCreator
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.agents.tool_executor import ToolExecutor
# Compute backup models: agent's configured models minus the active one
agent_models = self.agent_config.get("models", [])
backup_models = [m for m in agent_models if m != self.model_id]
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
)
user = self.decoded_token.get("sub") if self.decoded_token else None
tool_executor = ToolExecutor(
user_api_key=self.agent_config["user_api_key"],
user=user,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = self.conversation_id
agent_type = self.agent_config["agent_type"]
# Base agent kwargs
agent_kwargs = {
@@ -857,31 +770,10 @@ class StreamProcessor:
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
# Type-specific kwargs
if agent_type in ("agentic", "research"):
agent_kwargs["retriever_config"] = {
"source": self.source,
"retriever_name": self.retriever_config.get(
"retriever_name", "classic"
),
"chunks": self.retriever_config.get("chunks", 2),
"doc_token_limit": self.retriever_config.get(
"doc_token_limit", 50000
),
"model_id": self.model_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,
"api_key": system_api_key,
"decoded_token": self.decoded_token,
}
elif agent_type == "workflow":
# Workflow-specific kwargs for workflow agents
if agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config

View File

@@ -146,19 +146,20 @@ class ConnectorsCallback(Resource):
session_token = str(uuid.uuid4())
try:
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = auth.sanitize_token_info(token_info)
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
@@ -200,12 +201,12 @@ class ConnectorsCallback(Resource):
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False),
"search_query": fields.String(required=False)
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
@@ -213,8 +214,11 @@ class ConnectorFiles(Resource):
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
@@ -227,12 +231,15 @@ class ConnectorFiles(Resource):
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
generic_keys = {'provider', 'session_token'}
input_config = {
k: v for k, v in data.items() if k not in generic_keys
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
}
input_config['list_only'] = True
if search_query:
input_config['search_query'] = search_query
documents = loader.load_data(input_config)
@@ -299,7 +306,12 @@ class ConnectorValidateSession(Resource):
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
@@ -316,18 +328,12 @@ class ConnectorValidateSession(Resource):
"error": "Session token has expired. Please reconnect."
}), 401)
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {
return make_response(jsonify({
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token'),
**provider_extras,
}
return make_response(jsonify(response_data), 200)
"access_token": token_info.get('access_token')
}), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)

View File

@@ -5,7 +5,7 @@ Provides virtual folder organization for agents (Google Drive-like structure).
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask import jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from application.api import api
@@ -19,11 +19,6 @@ agents_folders_ns = Namespace(
)
def _folder_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": message}), 400)
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
@@ -45,8 +40,8 @@ class AgentFolders(Resource):
for f in folders
]
return make_response(jsonify({"folders": result}), 200)
except Exception as err:
return _folder_error_response("Failed to fetch folders", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Create a new folder")
@api.expect(
@@ -87,8 +82,8 @@ class AgentFolders(Resource):
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
)
except Exception as err:
return _folder_error_response("Failed to create folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/<string:folder_id>")
@@ -122,8 +117,8 @@ class AgentFolder(Resource):
}),
200,
)
except Exception as err:
return _folder_error_response("Failed to fetch folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Update a folder")
def put(self, folder_id):
@@ -150,8 +145,8 @@ class AgentFolder(Resource):
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to update folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
@@ -170,8 +165,8 @@ class AgentFolder(Resource):
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to delete folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/move_agent")
@@ -216,8 +211,8 @@ class MoveAgentToFolder(Resource):
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agent", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/bulk_move")
@@ -262,5 +257,5 @@ class BulkMoveAgents(Resource):
{"$unset": {"folder_id": ""}},
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agents", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)

View File

@@ -104,8 +104,6 @@ AGENT_TYPE_SCHEMAS = {
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["agentic"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["research"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
@@ -742,7 +740,13 @@ class UpdateAgent(Resource):
request, existing_agent.get("image", ""), user, storage
)
if error:
return error
current_app.logger.error(
f"Image upload error for agent {agent_id}: {error}"
)
return make_response(
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
400,
)
update_fields = {}
allowed_fields = [
"name",

View File

@@ -1,36 +1,15 @@
"""File attachments and media routes."""
import os
import tempfile
from pathlib import Path
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.cache import get_redis_instance
from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment
from application.core.settings import settings
from application.stt.constants import (
SUPPORTED_AUDIO_EXTENSIONS,
SUPPORTED_AUDIO_MIME_TYPES,
)
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.stt.live_session import (
apply_live_stt_hypothesis,
create_live_stt_session,
delete_live_stt_session,
finalize_live_stt_session,
get_live_stt_transcript_text,
load_live_stt_session,
save_live_stt_session,
)
from application.stt.stt_creator import STTCreator
from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename
@@ -40,74 +19,6 @@ attachments_ns = Namespace(
)
def _resolve_authenticated_user():
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
if decoded_token:
return safe_filename(decoded_token.get("sub"))
if api_key:
from application.api.user.base import agents_collection
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
return safe_filename(agent.get("user"))
return None
def _get_uploaded_file_size(file) -> int:
try:
current_position = file.stream.tell()
file.stream.seek(0, os.SEEK_END)
size_bytes = file.stream.tell()
file.stream.seek(current_position)
return size_bytes
except Exception:
return 0
def _is_supported_audio_mimetype(mimetype: str) -> bool:
if not mimetype:
return True
normalized = mimetype.split(";")[0].strip().lower()
return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
if not is_audio_filename(filename):
return
size_bytes = _get_uploaded_file_size(file)
if size_bytes:
enforce_audio_file_size_limit(size_bytes)
def _get_store_attachment_user_error(exc: Exception) -> str:
if isinstance(exc, AudioFileTooLargeError):
return build_stt_file_size_limit_message()
return "Failed to process file"
def _require_live_stt_redis():
redis_client = get_redis_instance()
if redis_client:
return redis_client
return make_response(
jsonify({"success": False, "message": "Live transcription is unavailable"}),
503,
)
def _parse_bool_form_value(value: str | None) -> bool:
if value is None:
return False
return value.strip().lower() in {"1", "true", "yes", "on"}
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
@api.expect(
@@ -125,9 +36,8 @@ class StoreAttachment(Resource):
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
)
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
files = request.files.getlist("file")
if not files:
@@ -141,16 +51,22 @@ class StoreAttachment(Resource):
400,
)
user = auth_user
if not user:
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
elif api_key:
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
user = safe_filename(agent.get("user"))
else:
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
from application.api.user.tasks import store_attachment
from application.api.user.base import storage
tasks = []
errors = []
original_file_count = len(files)
@@ -159,7 +75,6 @@ class StoreAttachment(Resource):
try:
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
_enforce_uploaded_audio_size_limit(file, original_filename)
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
metadata = storage.save_file(file, relative_path)
@@ -175,31 +90,15 @@ class StoreAttachment(Resource):
"task_id": task.id,
"filename": original_filename,
"attachment_id": str(attachment_id),
"upload_index": idx,
})
except Exception as file_err:
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
errors.append({
"upload_index": idx,
"filename": file.filename,
"error": _get_store_attachment_user_error(file_err),
"error": str(file_err)
})
if not tasks:
if errors and all(
error.get("error") == build_stt_file_size_limit_message()
for error in errors
):
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
"errors": errors,
}
),
413,
)
return make_response(
jsonify({"status": "error", "message": "No valid files to upload"}),
400,
@@ -236,385 +135,11 @@ class StoreAttachment(Resource):
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
@attachments_ns.route("/stt")
class SpeechToText(Resource):
@api.expect(
api.model(
"SpeechToTextModel",
{
"file": fields.Raw(required=True, description="Audio file"),
"language": fields.String(
required=False, description="Optional transcription language hint"
),
},
)
)
@api.doc(description="Transcribe an uploaded audio file")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=request.form.get("language") or settings.STT_LANGUAGE,
timestamps=settings.STT_ENABLE_TIMESTAMPS,
diarize=settings.STT_ENABLE_DIARIZATION,
)
return make_response(jsonify({"success": True, **transcript}), 200)
except Exception as err:
current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/start")
class LiveSpeechToTextStart(Resource):
@api.doc(description="Start a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_state = create_live_stt_session(
user=auth_user,
language=payload.get("language") or settings.STT_LANGUAGE,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_state["session_id"],
"language": session_state.get("language"),
"committed_text": "",
"mutable_text": "",
"previous_hypothesis": "",
"latest_hypothesis": "",
"finalized_text": "",
"pending_text": "",
"transcript_text": "",
}
),
200,
)
@attachments_ns.route("/stt/live/chunk")
class LiveSpeechToTextChunk(Resource):
@api.expect(
api.model(
"LiveSpeechToTextChunkModel",
{
"session_id": fields.String(
required=True, description="Live transcription session ID"
),
"chunk_index": fields.Integer(
required=True, description="Sequential chunk index"
),
"is_silence": fields.Boolean(
required=False,
description="Whether the latest capture window was mostly silence",
),
"file": fields.Raw(required=True, description="Audio chunk"),
},
)
)
@api.doc(description="Transcribe a chunk for a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
session_id = request.form.get("session_id", "").strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
chunk_index_raw = request.form.get("chunk_index", "").strip()
if chunk_index_raw == "":
return make_response(
jsonify({"success": False, "message": "Missing chunk_index"}),
400,
)
try:
chunk_index = int(chunk_index_raw)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid chunk_index"}),
400,
)
is_silence = _parse_bool_form_value(request.form.get("is_silence"))
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
session_language = session_state.get("language") or settings.STT_LANGUAGE
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=session_language,
timestamps=False,
diarize=False,
)
if not session_state.get("language") and transcript.get("language"):
session_state["language"] = transcript["language"]
try:
apply_live_stt_hypothesis(
session_state,
str(transcript.get("text", "")),
chunk_index,
is_silence=is_silence,
)
except ValueError:
current_app.logger.warning(
"Invalid live transcription chunk",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"message": "Invalid live transcription chunk",
}
),
409,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"chunk_index": chunk_index,
"chunk_text": transcript.get("text", ""),
"is_silence": is_silence,
"language": session_state.get("language"),
"committed_text": session_state.get("committed_text", ""),
"mutable_text": session_state.get("mutable_text", ""),
"previous_hypothesis": session_state.get(
"previous_hypothesis", ""
),
"latest_hypothesis": session_state.get(
"latest_hypothesis", ""
),
"finalized_text": session_state.get("committed_text", ""),
"pending_text": session_state.get("mutable_text", ""),
"transcript_text": get_live_stt_transcript_text(session_state),
}
),
200,
)
except Exception as err:
current_app.logger.error(
f"Error transcribing live audio chunk: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/finish")
class LiveSpeechToTextFinish(Resource):
@api.doc(description="Finish a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
final_text = finalize_live_stt_session(session_state)
delete_live_stt_session(redis_client, session_id)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"language": session_state.get("language"),
"text": final_text,
}
),
200,
)
@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
try:
from application.api.user.base import storage
file_obj = storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"

View File

@@ -145,22 +145,14 @@ def resolve_tool_details(tool_ids):
Returns:
List of tool details with id, name, and display_name
"""
valid_ids = []
for tid in tool_ids:
try:
valid_ids.append(ObjectId(tid))
except Exception:
continue
tools = user_tools_collection.find(
{"_id": {"$in": valid_ids}}
) if valid_ids else []
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
)
return [
{
"id": str(tool["_id"]),
"name": tool.get("name", ""),
"display_name": tool.get("customName")
or tool.get("displayName")
or tool.get("name", ""),
"display_name": tool.get("displayName", tool.get("name", "")),
}
for tool in tools
]

View File

@@ -14,14 +14,7 @@ from application.api.user.base import sources_collection
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.storage_creator import StorageCreator
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.utils import check_required_fields, safe_filename
@@ -30,12 +23,6 @@ sources_upload_ns = Namespace(
)
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
if not is_audio_filename(filename):
return
enforce_audio_file_size_limit(os.path.getsize(file_path))
@sources_upload_ns.route("/upload")
class UploadFile(Resource):
@api.expect(
@@ -91,7 +78,6 @@ class UploadFile(Resource):
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
_enforce_audio_path_size_limit(temp_file_path, safe_file)
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
# which are technically zip archives but should be processed as-is
@@ -116,10 +102,6 @@ class UploadFile(Resource):
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
_enforce_audio_path_size_limit(
os.path.join(root, extracted_file),
extracted_file,
)
with open(
os.path.join(root, extracted_file), "rb"
@@ -142,23 +124,29 @@ class UploadFile(Resource):
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
list(SUPPORTED_SOURCE_EXTENSIONS),
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)

View File

@@ -1,67 +1,21 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import urlencode, urlparse
from urllib.parse import unquote, urlencode
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import Namespace, Resource, fields
from flask_restx import fields, Namespace, Resource
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.security.encryption import encrypt_credentials
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
_mongo = MongoDB.get_client()
_db = _mongo[settings.MONGO_DB_NAME]
_connector_sessions = _db["connector_sessions"]
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
def _sanitize_mcp_transport(config):
"""Normalise and validate the transport_type field.
Strips ``command`` / ``args`` keys that are only valid for local STDIO
transports and returns the cleaned transport type string.
"""
transport_type = (config.get("transport_type") or "auto").lower()
if transport_type not in _ALLOWED_TRANSPORTS:
raise ValueError(f"Unsupported transport_type: {transport_type}")
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
return transport_type
def _extract_auth_credentials(config):
"""Build an ``auth_credentials`` dict from the raw MCP config."""
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if config.get("api_key"):
auth_credentials["api_key"] = config["api_key"]
if config.get("api_key_header"):
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if config.get("bearer_token"):
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if config.get("username"):
auth_credentials["username"] = config["username"]
if config.get("password"):
auth_credentials["password"] = config["password"]
return auth_credentials
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@@ -89,35 +43,49 @@ class TestMCPServerConfig(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = _extract_auth_credentials(config)
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key" and "api_key" in config:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer" and "bearer_token" in config:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config:
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
if result.get("requires_oauth"):
return make_response(jsonify(result), 200)
# Sanitize the response to avoid exposing internal error details
if not result.get("success") and "message" in result:
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
current_app.logger.error(f"MCP connection test failed: {result.get('message')}")
result["message"] = "Connection test failed"
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify({"success": False, "error": "Connection test failed"}),
jsonify(
{"success": False, "error": "Connection test failed"}
),
500,
)
@@ -157,16 +125,32 @@ class MCPServerSave(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
auth_credentials = _extract_auth_credentials(config)
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
@@ -204,39 +188,30 @@ class MCPServerSave(Resource):
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
tool_id = data.get("id")
existing_encrypted = None
if tool_id:
existing_doc = user_tools_collection.find_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
)
if existing_doc:
existing_encrypted = existing_doc.get("config", {}).get(
"encrypted_credentials"
)
if auth_credentials:
if existing_encrypted:
existing_secrets = decrypt_credentials(existing_encrypted, user)
existing_secrets.update(auth_credentials)
auth_credentials = existing_secrets
storage_config["encrypted_credentials"] = encrypt_credentials(
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
elif existing_encrypted:
storage_config["encrypted_credentials"] = existing_encrypted
storage_config["encrypted_credentials"] = encrypted_credentials_string
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
"redirect_uri",
]:
storage_config.pop(field, None)
transformed_actions = transform_actions(actions_metadata)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
@@ -248,6 +223,7 @@ class MCPServerSave(Resource):
"user": user,
}
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
@@ -282,7 +258,9 @@ class MCPServerSave(Resource):
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify({"success": False, "error": "Failed to save MCP server"}),
jsonify(
{"success": False, "error": "Failed to save MCP server"}
),
500,
)
@@ -313,7 +291,7 @@ class MCPOAuthCallback(Resource):
params = {
"status": "error",
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
"provider": "mcp_tool",
"provider": "mcp_tool"
}
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
if not code or not state:
@@ -326,6 +304,7 @@ class MCPOAuthCallback(Resource):
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
@@ -348,6 +327,10 @@ class MCPOAuthCallback(Resource):
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
@@ -355,14 +338,6 @@ class MCPOAuthStatus(Resource):
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
@@ -370,93 +345,17 @@ class MCPOAuthStatus(Resource):
return make_response(
jsonify(
{
"success": True,
"success": False,
"error": "Task not found or expired",
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
200,
404,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
f"Error getting OAuth status for task {task_id}: {str(e)}", exc_info=True
)
return make_response(
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(
description="Batch check auth status for all MCP tools. "
"Lightweight DB-only check — no network calls to MCP servers."
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
mcp_tools = list(
user_tools_collection.find(
{"user": user, "name": "mcp_tool"},
{"_id": 1, "config": 1},
)
)
if not mcp_tools:
return make_response(jsonify({"success": True, "statuses": {}}), 200)
oauth_server_urls = {}
statuses = {}
for tool in mcp_tools:
tool_id = str(tool["_id"])
config = tool.get("config", {})
auth_type = config.get("auth_type", "none")
if auth_type == "oauth":
server_url = config.get("server_url", "")
if server_url:
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
oauth_server_urls[tool_id] = base_url
else:
statuses[tool_id] = "needs_auth"
else:
statuses[tool_id] = "configured"
if oauth_server_urls:
unique_urls = list(set(oauth_server_urls.values()))
sessions = list(
_connector_sessions.find(
{"user_id": user, "server_url": {"$in": unique_urls}},
{"server_url": 1, "tokens": 1},
)
)
url_has_tokens = {
doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
for doc in sessions
}
for tool_id, base_url in oauth_server_urls.items():
if url_has_tokens.get(base_url):
statuses[tool_id] = "connected"
else:
statuses[tool_id] = "needs_auth"
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
except Exception as e:
current_app.logger.error(
"Error checking MCP auth status: %s", e, exc_info=True
)
return make_response(
jsonify({"success": False, "error": "Failed to check auth status"}),
500,
jsonify({"success": False, "error": "Failed to get OAuth status", "task_id": task_id}), 500
)

View File

@@ -15,114 +15,6 @@ tool_config = {}
tool_manager = ToolManager(config=tool_config)
def _encrypt_secret_fields(config, config_requirements, user_id):
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret") and key in config and config[key]
]
if not secret_keys:
return config
storage_config = config.copy()
secret_values = {k: config[k] for k in secret_keys}
storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id)
for key in secret_keys:
storage_config.pop(key, None)
return storage_config
def _validate_config(config, config_requirements, has_existing_secrets=False):
errors = {}
for key, spec in config_requirements.items():
depends_on = spec.get("depends_on")
if depends_on:
if not all(config.get(dk) == dv for dk, dv in depends_on.items()):
continue
if spec.get("required") and not config.get(key):
if has_existing_secrets and spec.get("secret"):
continue
errors[key] = f"{spec.get('label', key)} is required"
value = config.get(key)
if value is not None and value != "":
if spec.get("type") == "number":
try:
num = float(value)
if key == "timeout" and (num < 1 or num > 300):
errors[key] = "Timeout must be between 1 and 300"
except (ValueError, TypeError):
errors[key] = f"{spec.get('label', key)} must be a number"
if spec.get("enum") and value not in spec["enum"]:
errors[key] = f"Invalid value for {spec.get('label', key)}"
return errors
def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id):
"""Merge incoming config with existing encrypted secrets and re-encrypt.
For updates, the client may omit unchanged secret values. This helper
decrypts any previously stored secrets, overlays whatever the client *did*
send, strips plain-text secrets from the stored config, and re-encrypts
the merged result.
Returns the final ``config`` dict ready for persistence.
"""
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret")
]
if not secret_keys:
return new_config
existing_secrets = {}
if "encrypted_credentials" in existing_config:
existing_secrets = decrypt_credentials(
existing_config["encrypted_credentials"], user_id
)
merged_secrets = existing_secrets.copy()
for key in secret_keys:
if key in new_config and new_config[key]:
merged_secrets[key] = new_config[key]
# Start from existing non-secret values, then overlay incoming non-secrets
storage_config = {
k: v for k, v in existing_config.items()
if k not in secret_keys and k != "encrypted_credentials"
}
storage_config.update(
{k: v for k, v in new_config.items() if k not in secret_keys}
)
if merged_secrets:
storage_config["encrypted_credentials"] = encrypt_credentials(
merged_secrets, user_id
)
else:
storage_config.pop("encrypted_credentials", None)
storage_config.pop("has_encrypted_credentials", None)
return storage_config
def transform_actions(actions_metadata):
"""Set default flags on action metadata for storage.
Marks each action as active, sets ``filled_by_llm`` and ``value`` on every
parameter property. Used by both the generic create_tool and MCP save routes.
"""
transformed = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
props = action["parameters"].get("properties", {})
for param_details in props.values():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed.append(action)
return transformed
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
@@ -137,15 +29,12 @@ class AvailableTools(Resource):
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
config_req = tool_instance.get_config_requirements()
actions = tool_instance.get_actions_metadata()
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": config_req,
"actions": actions,
"configRequirements": tool_instance.get_config_requirements(),
}
)
except Exception as err:
@@ -171,21 +60,6 @@ class GetTools(Resource):
tool_copy = {**tool}
tool_copy["id"] = str(tool["_id"])
tool_copy.pop("_id", None)
config_req = tool_copy.get("configRequirements", {})
if not config_req:
tool_instance = tool_manager.tools.get(tool_copy.get("name"))
if tool_instance:
config_req = tool_instance.get_config_requirements()
tool_copy["configRequirements"] = config_req
has_secrets = any(
spec.get("secret") for spec in config_req.values()
) if config_req else False
if has_secrets and "encrypted_credentials" in tool_copy.get("config", {}):
tool_copy["config"]["has_encrypted_credentials"] = True
tool_copy["config"].pop("encrypted_credentials", None)
user_tools.append(tool_copy)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
@@ -242,32 +116,23 @@ class CreateTool(Resource):
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = transform_actions(actions_metadata)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
config_requirements = tool_instance.get_config_requirements()
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements
)
if validation_errors:
return make_response(
jsonify(
{
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}
),
400,
)
storage_config = _encrypt_secret_fields(
data["config"], config_requirements, user
)
new_tool = {
"user": user,
"name": data["name"],
@@ -275,8 +140,7 @@ class CreateTool(Resource):
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": storage_config,
"configRequirements": config_requirements,
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
@@ -346,37 +210,57 @@ class UpdateTool(Resource):
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
tool_name = tool_doc.get("name", data.get("name"))
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
storage_config.update(config)
existing_credentials = {}
if "encrypted_credentials" in existing_config:
existing_credentials = decrypt_credentials(
existing_config["encrypted_credentials"], user
)
update_data["config"] = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
auth_credentials = existing_credentials.copy()
auth_type = storage_config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config[
"api_key_header"
]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif "encrypted_token" in config and config["encrypted_token"]:
auth_credentials["bearer_token"] = config["encrypted_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
if auth_type != "none" and auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = (
encrypted_credentials_string
)
elif auth_type == "none":
storage_config.pop("encrypted_credentials", None)
for field in [
"api_key",
"bearer_token",
"encrypted_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
update_data["config"] = storage_config
else:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
@@ -414,42 +298,9 @@ class UpdateToolConfig(Resource):
if missing_fields:
return missing_fields
try:
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
final_config = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": final_config}},
{"$set": {"config": data["config"]}},
)
except Exception as err:
current_app.logger.error(
@@ -559,13 +410,11 @@ class DeleteTool(Resource):
{"_id": ObjectId(data["id"]), "user": user}
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
return {"success": False, "message": "Tool not found"}, 404
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
return {"success": False}, 400
return {"success": True}, 200
@tools_ns.route("/parse_spec")
@@ -662,6 +511,7 @@ class GetArtifact(Resource):
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
# Return all todos for the tool
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []

View File

@@ -5,14 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import (
Response,
current_app,
has_app_context,
jsonify,
make_response,
request,
)
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
@@ -326,10 +319,8 @@ def safe_db_operation(
try:
result = operation()
return result, None
except Exception as err:
if has_app_context():
current_app.logger.error(f"{error_message}: {err}", exc_info=True)
return None, error_response(error_message)
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(

View File

@@ -30,11 +30,6 @@ from application.api.user.utils import (
workflows_ns = Namespace("workflows", path="/api")
def _workflow_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return error_response(message)
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
@@ -401,11 +396,11 @@ class WorkflowList(Resource):
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as err:
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return _workflow_error_response("Failed to create workflow structure", err)
return error_response(f"Failed to create workflow structure: {str(e)}")
return success_response({"id": workflow_id}, 201)
@@ -475,14 +470,14 @@ class WorkflowDetail(Resource):
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as err:
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return _workflow_error_response("Failed to update workflow structure", err)
return error_response(f"Failed to update workflow structure: {str(e)}")
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
@@ -540,7 +535,7 @@ class WorkflowDetail(Resource):
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as err:
return _workflow_error_response("Failed to delete workflow", err)
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
return success_response()

View File

@@ -19,10 +19,6 @@ from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
build_stt_file_size_limit_message,
should_reject_stt_request,
)
if platform.system() == "Windows":
@@ -72,11 +68,6 @@ def home():
return "Welcome to DocsGPT Backend!"
@app.route("/api/health")
def health():
return jsonify({"status": "ok"})
@app.route("/api/config")
def get_config():
response = {
@@ -97,23 +88,6 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
return None
if should_reject_stt_request(request.path, request.content_length):
return (
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
return None
@app.before_request
def authenticate_request():
if request.method == "OPTIONS":

View File

@@ -27,8 +27,6 @@ ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
@@ -195,46 +193,6 @@ OPENROUTER_MODELS = [
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",

View File

@@ -114,10 +114,6 @@ class ModelRegistry:
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
@@ -249,21 +245,6 @@ class ModelRegistry:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(

View File

@@ -10,7 +10,6 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,

View File

@@ -5,7 +5,9 @@ from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
class Settings(BaseSettings):
@@ -13,11 +15,15 @@ class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -39,7 +45,9 @@ class Settings(BaseSettings):
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
)
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
@@ -47,18 +55,16 @@ class Settings(BaseSettings):
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
GOOGLE_CLIENT_SECRET: Optional[str] = None # Replace with your actual Google OAuth client secret
GOOGLE_CLIENT_ID: Optional[str] = (
None # Replace with your actual Google OAuth client ID
)
GOOGLE_CLIENT_SECRET: Optional[str] = (
None # Replace with your actual Google OAuth client secret
)
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
# Microsoft Entra ID (Azure AD) integration
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
@@ -66,7 +72,6 @@ class Settings(BaseSettings):
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
MCP_OAUTH_REDIRECT_URI: Optional[str] = None # public callback URL for MCP OAuth
INTERNAL_KEY: Optional[str] = None # internal api key for worker-to-backend auth
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
@@ -78,13 +83,16 @@ class Settings(BaseSettings):
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
NOVITA_API_KEY: Optional[str] = None
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -126,7 +134,9 @@ class Settings(BaseSettings):
# LanceDB vectorstore config
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
@@ -139,12 +149,6 @@ class Settings(BaseSettings):
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
STT_PROVIDER: str = "openai" # openai or faster_whisper
OPENAI_STT_MODEL: str = "gpt-4o-mini-transcribe"
STT_LANGUAGE: Optional[str] = None
STT_MAX_FILE_SIZE_MB: int = 50
STT_ENABLE_TIMESTAMPS: bool = False
STT_ENABLE_DIARIZATION: bool = False
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
@@ -163,7 +167,6 @@ class Settings(BaseSettings):
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"NOVITA_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",

View File

@@ -13,25 +13,3 @@ def response_error(code_status, message=None):
def bad_request(status_code=400, message=''):
return response_error(code_status=status_code, message=message)
def sanitize_api_error(error) -> str:
"""
Convert technical API errors to user-friendly messages.
Works with both Exception objects and error message strings.
"""
error_str = str(error).lower()
if "503" in error_str or "unavailable" in error_str or "high demand" in error_str:
return "The AI service is temporarily unavailable due to high demand. Please try again in a moment."
if "429" in error_str or "rate limit" in error_str or "quota" in error_str:
return "Rate limit exceeded. Please wait a moment before trying again."
if "401" in error_str or "unauthorized" in error_str or "invalid api key" in error_str:
return "Authentication error. Please check your API configuration."
if "timeout" in error_str or "timed out" in error_str:
return "The request timed out. Please try again."
if "connection" in error_str or "network" in error_str:
return "Network error. Please check your connection and try again."
original = str(error)
if len(original) > 200 or "{" in original or "traceback" in error_str:
return "An error occurred while processing your request. Please try again later."
return original

View File

@@ -16,61 +16,22 @@ class BaseLLM(ABC):
agent_id=None,
model_id=None,
base_url=None,
backup_models=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self._backup_models = backup_models or []
self._fallback_llm = None
self._fallback_sequence_index = 0
@property
def fallback_llm(self):
"""Lazy-loaded fallback LLM: tries per-agent backup models first,
then the global FALLBACK_* settings."""
if self._fallback_llm is not None:
return self._fallback_llm
from application.llm.llm_creator import LLMCreator
from application.core.model_utils import (
get_provider_from_model_id,
get_api_key_for_provider,
)
# Try per-agent backup models first
for backup_model_id in self._backup_models:
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
try:
provider = get_provider_from_model_id(backup_model_id)
if not provider:
logger.warning(
f"Could not resolve provider for backup model: {backup_model_id}"
)
continue
api_key = get_api_key_for_provider(provider)
self._fallback_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=backup_model_id,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized from agent backup model: "
f"{provider}/{backup_model_id}"
)
return self._fallback_llm
except Exception as e:
logger.warning(
f"Failed to initialize backup model {backup_model_id}: {str(e)}"
)
continue
from application.llm.llm_creator import LLMCreator
# Fall back to global FALLBACK_* settings
if settings.FALLBACK_LLM_PROVIDER:
try:
self._fallback_llm = LLMCreator.create_llm(
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
@@ -80,14 +41,12 @@ class BaseLLM(ABC):
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized from global settings: "
f"{settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
)
except Exception as e:
logger.error(
f"Failed to initialize fallback LLM: {str(e)}", exc_info=True
)
return self._fallback_llm
@staticmethod
@@ -115,60 +74,20 @@ class BaseLLM(ABC):
method = decorator(method)
return method(self, *args, **kwargs)
is_stream = "stream" in method_name
if is_stream:
return self._stream_with_fallback(
decorated_method, method_name, *args, **kwargs
)
try:
return decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
fallback = self.fallback_llm
logger.warning(
f"Primary LLM failed. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
self.fallback_llm, method_name.replace("_raw_", "")
)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
return fallback_method(*args, **fallback_kwargs)
def _stream_with_fallback(
self, decorated_method, method_name, *args, **kwargs
):
"""
Wrapper generator that catches mid-stream errors and falls back.
Unlike non-streaming calls where exceptions are raised immediately,
streaming generators raise exceptions during iteration. This wrapper
ensures that if the primary LLM fails at any point during streaming
(creation or mid-stream), we fall back to the backup model.
"""
try:
yield from decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(
f"Primary LLM failed and no fallback configured: {str(e)}"
)
raise
fallback = self.fallback_llm
logger.warning(
f"Primary LLM failed mid-stream. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
yield from fallback_method(*args, **fallback_kwargs)
return fallback_method(*args, **kwargs)
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]

View File

@@ -158,11 +158,7 @@ class GoogleLLM(BaseLLM):
if isinstance(content, list):
parts = []
for item in content:
if (
isinstance(item, dict)
and "text" in item
and item["text"] is not None
):
if isinstance(item, dict) and "text" in item and item["text"] is not None:
parts.append(item["text"])
return "\n".join(parts)
return ""
@@ -240,9 +236,7 @@ class GoogleLLM(BaseLLM):
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
system_instruction = (
"\n\n".join(system_instructions) if system_instructions else None
)
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
return cleaned_messages, system_instruction
def _clean_schema(self, schema_obj):
@@ -342,10 +336,7 @@ class GoogleLLM(BaseLLM):
return f"function_call:{name}"
function_response = getattr(part, "function_response", None)
if function_response:
name = (
getattr(function_response, "name", "")
or "function_response"
)
name = getattr(function_response, "name", "") or "function_response"
return f"function_response:{name}"
if isinstance(message, dict):
content = message.get("content")
@@ -516,9 +507,6 @@ class GoogleLLM(BaseLLM):
yield {"type": "thought", "thought": chunk_text}
else:
yield chunk_text
except Exception as e:
logging.error(f"GoogleLLM: Stream error: {e}", exc_info=True)
raise
finally:
if hasattr(response, "close"):
response.close()

View File

@@ -7,7 +7,6 @@ class LLMHandlerCreator:
handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"novita": OpenAILLMHandler, # Novita uses OpenAI-compatible API
"default": OpenAILLMHandler,
}

View File

@@ -38,7 +38,6 @@ class LLMCreator:
decoded_token,
model_id=None,
agent_id=None,
backup_models=None,
*args,
**kwargs,
):
@@ -60,7 +59,6 @@ class LLMCreator:
model_id=model_id,
agent_id=agent_id,
base_url=base_url,
backup_models=backup_models,
*args,
**kwargs,
)

View File

@@ -1,13 +1,13 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
NOVITA_BASE_URL = "https://api.novita.ai/openai"
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
class NovitaLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
api_key=api_key or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or NOVITA_BASE_URL,
*args,

View File

@@ -101,13 +101,9 @@ class OpenAILLM(BaseLLM):
for item in content:
if "function_call" in item:
# Function calls need their own message
args = item["function_call"]["args"]
if isinstance(args, str):
try:
args = json.loads(args)
except (json.JSONDecodeError, TypeError):
pass
cleaned_args = self._remove_null_values(args)
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",

View File

@@ -62,26 +62,15 @@ class BaseConnectorAuth(ABC):
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
"""
Check if a token is expired.
Args:
token_info: Token information dictionary
Returns:
True if token is expired, False otherwise
"""
pass
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
"""Extract the fields safe to persist in the session store.
"""
return {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry"),
**extra_fields,
}
class BaseConnectorLoader(ABC):
"""

View File

@@ -1,7 +1,5 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.connectors.share_point.loader import SharePointLoader
class ConnectorCreator:
@@ -14,12 +12,10 @@ class ConnectorCreator:
connectors = {
"google_drive": GoogleDriveLoader,
"share_point": SharePointLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
"share_point": SharePointAuth,
}
@classmethod

View File

@@ -232,6 +232,10 @@ class GoogleDriveAuth(BaseConnectorAuth):
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'client_id' not in token_info:
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
if 'client_secret' not in token_info:
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
if 'token_uri' not in token_info:
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'

View File

@@ -327,10 +327,15 @@ class GoogleDriveLoader(BaseConnectorLoader):
content_bytes = file_io.getvalue()
try:
return content_bytes.decode('utf-8')
content = content_bytes.decode('utf-8')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
try:
content = content_bytes.decode('latin-1')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
return content
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")

View File

@@ -1,10 +0,0 @@
"""
Share Point connector package for DocsGPT.
This module provides authentication and document loading capabilities for Share Point.
"""
from .auth import SharePointAuth
from .loader import SharePointLoader
__all__ = ['SharePointAuth', 'SharePointLoader']

View File

@@ -1,152 +0,0 @@
import datetime
import logging
from typing import Optional, Dict, Any
from msal import ConfidentialClientApplication
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
logger = logging.getLogger(__name__)
class SharePointAuth(BaseConnectorAuth):
"""
Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive.
Note: Files.Read scope allows access to files the user has granted access to,
similar to Google Drive's drive.file scope.
"""
SCOPES = [
"Files.Read",
"Sites.Read.All",
"User.Read",
]
def __init__(self):
self.client_id = settings.MICROSOFT_CLIENT_ID
self.client_secret = settings.MICROSOFT_CLIENT_SECRET
if not self.client_id:
raise ValueError(
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings."
)
if not self.client_secret:
raise ValueError(
"Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings."
)
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
self.tenant_id = settings.MICROSOFT_TENANT_ID
self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}")
self.auth_app = ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority
)
def get_authorization_url(self, state: Optional[str] = None) -> str:
return self.auth_app.get_authorization_request_url(
scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
)
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_authorization_code(
code=authorization_code,
scopes=self.SCOPES,
redirect_uri=self.redirect_uri
)
if "error" in result:
logger.error("Token exchange failed: %s", result.get("error_description"))
raise ValueError(f"Error acquiring token: {result.get('error_description')}")
return self.map_token_response(result)
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
if "error" in result:
logger.error("Token refresh failed: %s", result.get("error_description"))
raise ValueError(f"Error refreshing token: {result.get('error_description')}")
return self.map_token_response(result)
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
try:
from application.core.mongo_db import MongoDB
from application.core.settings import settings
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
sessions_collection = db["connector_sessions"]
session = sessions_collection.find_one({"session_token": session_token})
if not session:
raise ValueError(f"Invalid session token: {session_token}")
if "token_info" not in session:
raise ValueError("Session missing token information")
token_info = session["token_info"]
if not token_info:
raise ValueError("Invalid token information")
required_fields = ["access_token", "refresh_token"]
missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'token_uri' not in token_info:
token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token"
return token_info
except Exception as e:
logger.error("Failed to retrieve token from session: %s", e)
raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}")
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
if not token_info:
return True
expiry_timestamp = token_info.get("expiry")
if expiry_timestamp is None:
return True
current_timestamp = int(datetime.datetime.now().timestamp())
return (expiry_timestamp - current_timestamp) < 60
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
return super().sanitize_token_info(
token_info,
allows_shared_content=token_info.get("allows_shared_content", False),
**extra_fields,
)
PERSONAL_ACCOUNT_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad"
def _allows_shared_content(self, id_token_claims: Dict[str, Any]) -> bool:
"""Return True when the account is a work/school tenant that can access SharePoint shared content."""
tid = id_token_claims.get("tid", "")
return bool(tid) and tid != self.PERSONAL_ACCOUNT_TENANT_ID
def map_token_response(self, result) -> Dict[str, Any]:
claims = result.get("id_token_claims", {})
return {
"access_token": result.get("access_token"),
"refresh_token": result.get("refresh_token"),
"token_uri": claims.get("iss"),
"scopes": result.get("scope"),
"expiry": claims.get("exp"),
"allows_shared_content": self._allows_shared_content(claims),
"user_info": {
"name": claims.get("name"),
"email": claims.get("preferred_username"),
},
}

View File

@@ -1,649 +0,0 @@
"""
SharePoint/OneDrive loader for DocsGPT.
Loads documents from SharePoint/OneDrive using Microsoft Graph API.
"""
import functools
import logging
import os
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import quote
import requests
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.schema.base import Document
def _retry_on_auth_failure(func):
"""Retry once after refreshing the access token on 401/403 responses."""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code in (401, 403):
logging.info(f"Auth failure in {func.__name__}, refreshing token and retrying")
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info.get('access_token')
except Exception as refresh_error:
raise ValueError(
f"Authentication failed and could not be refreshed: {refresh_error}"
) from e
return func(self, *args, **kwargs)
raise
return wrapper
class SharePointLoader(BaseConnectorLoader):
SUPPORTED_MIME_TYPES = {
'application/pdf': '.pdf',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
'application/msword': '.doc',
'application/vnd.ms-powerpoint': '.ppt',
'application/vnd.ms-excel': '.xls',
'text/plain': '.txt',
'text/csv': '.csv',
'text/html': '.html',
'text/markdown': '.md',
'text/x-rst': '.rst',
'application/json': '.json',
'application/epub+zip': '.epub',
'application/rtf': '.rtf',
'image/jpeg': '.jpg',
'image/png': '.png',
}
EXTENSION_TO_MIME = {v: k for k, v in SUPPORTED_MIME_TYPES.items()}
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
def __init__(self, session_token: str):
self.auth = SharePointAuth()
self.session_token = session_token
token_info = self.auth.get_token_info_from_session(session_token)
self.access_token = token_info.get('access_token')
self.refresh_token = token_info.get('refresh_token')
self.allows_shared_content = token_info.get('allows_shared_content', False)
if not self.access_token:
raise ValueError("No access token found in session")
self.next_page_token = None
def _get_headers(self) -> Dict[str, str]:
return {
'Authorization': f'Bearer {self.access_token}',
'Accept': 'application/json'
}
def _ensure_valid_token(self):
if not self.access_token:
raise ValueError("No access token available")
token_info = {'access_token': self.access_token, 'expiry': None}
if self.auth.is_token_expired(token_info):
logging.info("Token expired, attempting refresh")
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info.get('access_token')
except Exception:
raise ValueError("Failed to refresh access token")
def _get_item_url(self, item_ref: str) -> str:
if ':' in item_ref:
drive_id, item_id = item_ref.split(':', 1)
return f"{self.GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}"
return f"{self.GRAPH_API_BASE}/me/drive/items/{item_ref}"
def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
try:
drive_item_id = file_metadata.get('id')
file_name = file_metadata.get('name', 'Unknown')
file_data = file_metadata.get('file', {})
mime_type = file_data.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES:
logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
return None
doc_metadata = {
'file_name': file_name,
'mime_type': mime_type,
'size': file_metadata.get('size'),
'created_time': file_metadata.get('createdDateTime'),
'modified_time': file_metadata.get('lastModifiedDateTime'),
'source': 'share_point'
}
if not load_content:
return Document(
text="",
doc_id=drive_item_id,
extra_info=doc_metadata
)
content = self._download_file_content(drive_item_id)
if content is None:
logging.warning(f"Could not load content for file {file_name} ({drive_item_id})")
return None
return Document(
text=content,
doc_id=drive_item_id,
extra_info=doc_metadata
)
except Exception as e:
logging.error(f"Error processing file: {e}")
return None
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
try:
documents: List[Document] = []
folder_id = inputs.get('folder_id')
file_ids = inputs.get('file_ids', [])
limit = inputs.get('limit', 100)
list_only = inputs.get('list_only', False)
load_content = not list_only
page_token = inputs.get('page_token')
search_query = inputs.get('search_query')
self.next_page_token = None
shared = inputs.get('shared', False)
if file_ids:
for file_id in file_ids:
try:
doc = self._load_file_by_id(file_id, load_content=load_content)
if doc:
if not search_query or (
search_query.lower() in doc.extra_info.get('file_name', '').lower()
):
documents.append(doc)
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
continue
elif shared:
if not self.allows_shared_content:
logging.warning("Shared content is only available for work/school Microsoft accounts")
return []
documents = self._list_shared_items(
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
else:
parent_id = folder_id if folder_id else 'root'
documents = self._list_items_in_parent(
parent_id,
limit=limit,
load_content=load_content,
page_token=page_token,
search_query=search_query
)
logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive")
return documents
except Exception as e:
logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True)
raise
@_retry_on_auth_failure
def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
self._ensure_valid_token()
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
file_metadata = response.json()
return self._process_file(file_metadata, load_content=load_content)
except requests.exceptions.HTTPError:
raise
except Exception as e:
logging.error(f"Error loading file {file_id}: {e}")
return None
@_retry_on_auth_failure
def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
self._ensure_valid_token()
documents: List[Document] = []
try:
url = f"{self._get_item_url(parent_id)}/children"
params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'}
if page_token:
params['$skipToken'] = page_token
if search_query:
encoded_query = quote(search_query, safe='')
if ':' in parent_id:
drive_id = parent_id.split(':', 1)[0]
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
else:
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
response = requests.get(search_url, headers=self._get_headers(), params=params)
else:
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
results = response.json()
items = results.get('value', [])
for item in items:
if 'folder' in item:
doc_metadata = {
'file_name': item.get('name', 'Unknown'),
'mime_type': 'folder',
'size': item.get('size'),
'created_time': item.get('createdDateTime'),
'modified_time': item.get('lastModifiedDateTime'),
'source': 'share_point',
'is_folder': True
}
documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
else:
doc = self._process_file(item, load_content=load_content)
if doc:
documents.append(doc)
if limit and len(documents) >= limit:
break
next_link = results.get('@odata.nextLink')
if next_link:
from urllib.parse import urlparse, parse_qs
parsed = urlparse(next_link)
query_params = parse_qs(parsed.query)
skiptoken_list = query_params.get('$skiptoken')
if skiptoken_list:
self.next_page_token = skiptoken_list[0]
else:
self.next_page_token = None
else:
self.next_page_token = None
return documents
except Exception as e:
logging.error(f"Error listing items under parent {parent_id}: {e}")
return documents
def _resolve_mime_type(self, resource: Dict[str, Any]) -> Tuple[str, bool]:
"""Resolve mime type from resource, falling back to file extension."""
file_data = resource.get('file', {})
mime_type = file_data.get('mimeType') if file_data else None
if mime_type and mime_type in self.SUPPORTED_MIME_TYPES:
return mime_type, True
name = resource.get('name', '')
ext = os.path.splitext(name)[1].lower()
if ext in self.EXTENSION_TO_MIME:
return self.EXTENSION_TO_MIME[ext], True
return mime_type or 'application/octet-stream', False
def _get_user_drive_web_url(self) -> Optional[str]:
"""Fetch the current user's OneDrive web URL for KQL path exclusion."""
try:
response = requests.get(
f"{self.GRAPH_API_BASE}/me/drive",
headers=self._get_headers(),
params={'$select': 'webUrl'}
)
response.raise_for_status()
return response.json().get('webUrl')
except Exception as e:
logging.warning(f"Could not fetch user drive web URL: {e}")
return None
def _build_shared_kql_query(self, search_query: Optional[str], user_drive_url: Optional[str]) -> str:
"""Build KQL query string that excludes the user's own drive items."""
base_query = search_query if search_query else "*"
if user_drive_url:
return f'{base_query} AND -path:"{user_drive_url}"'
return base_query
def _list_shared_items(self, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
"""Fetch shared drive items using Microsoft Graph Search API with local offset paging.
We always fetch up to a fixed maximum number of hits from Graph (single request),
then page through that array locally using `page_token` as a simple integer offset.
This avoids relying on buggy or inconsistent remote `from`/`size` semantics.
"""
self._ensure_valid_token()
documents: List[Document] = []
try:
user_drive_url = self._get_user_drive_web_url()
query_text = self._build_shared_kql_query(search_query, user_drive_url)
url = f"{self.GRAPH_API_BASE}/search/query"
page_size = 500 # maximum number of hits we care about for selection
body = {
"requests": [
{
"entityTypes": ["driveItem"],
"query": {"queryString": query_text},
"from": 0,
"size": page_size,
}
]
}
headers = self._get_headers()
headers["Content-Type"] = "application/json"
response = requests.post(url, headers=headers, json=body)
response.raise_for_status()
results = response.json()
search_response = results.get("value", [])
if not search_response:
logging.warning("Search API returned empty value array")
self.next_page_token = None
return documents
hits_containers = search_response[0].get("hitsContainers", [])
if not hits_containers:
logging.warning("Search API returned no hitsContainers")
self.next_page_token = None
return documents
container = hits_containers[0]
total = container.get("total", 0)
raw_hits = container.get("hits", [])
# Deduplicate by effective item ID (driveId:itemId) to avoid the same
# resource appearing multiple times across the result set.
deduped_hits = []
seen_ids = set()
for hit in raw_hits:
resource = hit.get("resource", {})
item_id = resource.get("id")
drive_id = resource.get("parentReference", {}).get("driveId")
effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
if not effective_id or effective_id in seen_ids:
continue
seen_ids.add(effective_id)
deduped_hits.append(hit)
hits = deduped_hits
logging.info(
f"Search API returned {total} total results, {len(raw_hits)} raw hits, {len(hits)} unique hits in this batch"
)
try:
offset = int(page_token) if page_token is not None else 0
except (TypeError, ValueError):
logging.warning(
f"Invalid page_token '{page_token}' for shared items search, defaulting to 0"
)
offset = 0
if offset < 0:
offset = 0
if offset >= len(hits):
self.next_page_token = None
return documents
end_index = offset + limit if limit else len(hits)
end_index = min(end_index, len(hits))
for hit in hits[offset:end_index]:
resource = hit.get("resource", {})
item_name = resource.get("name", "Unknown")
item_id = resource.get("id")
drive_id = resource.get("parentReference", {}).get("driveId")
effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
is_folder = "folder" in resource
if is_folder:
doc_metadata = {
"file_name": item_name,
"mime_type": "folder",
"size": resource.get("size"),
"created_time": resource.get("createdDateTime"),
"modified_time": resource.get("lastModifiedDateTime"),
"source": "share_point",
"is_folder": True,
}
documents.append(
Document(text="", doc_id=effective_id, extra_info=doc_metadata)
)
else:
mime_type, supported = self._resolve_mime_type(resource)
if not supported:
logging.info(
f"Skipping unsupported shared file: {item_name} (mime: {mime_type})"
)
continue
doc_metadata = {
"file_name": item_name,
"mime_type": mime_type,
"size": resource.get("size"),
"created_time": resource.get("createdDateTime"),
"modified_time": resource.get("lastModifiedDateTime"),
"source": "share_point",
}
content = ""
if load_content:
content = self._download_file_content(effective_id) or ""
documents.append(
Document(text=content, doc_id=effective_id, extra_info=doc_metadata)
)
if limit and end_index < len(hits):
self.next_page_token = str(end_index)
else:
self.next_page_token = None
return documents
except Exception as e:
logging.error(f"Error listing shared items via search API: {e}", exc_info=True)
return documents
@_retry_on_auth_failure
def _download_file_content(self, file_id: str) -> Optional[str]:
self._ensure_valid_token()
try:
url = f"{self._get_item_url(file_id)}/content"
response = requests.get(url, headers=self._get_headers())
response.raise_for_status()
try:
return response.content.decode('utf-8')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
except requests.exceptions.HTTPError:
raise
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}")
return None
def _download_single_file(self, file_id: str, local_dir: str) -> bool:
try:
url = self._get_item_url(file_id)
params = {'$select': 'id,name,file'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
metadata = response.json()
file_name = metadata.get('name', 'unknown')
file_data = metadata.get('file', {})
mime_type = file_data.get('mimeType', 'application/octet-stream')
if mime_type not in self.SUPPORTED_MIME_TYPES:
logging.info(f"Skipping unsupported file type: {mime_type}")
return False
os.makedirs(local_dir, exist_ok=True)
full_path = os.path.join(local_dir, file_name)
download_url = f"{self._get_item_url(file_id)}/content"
download_response = requests.get(download_url, headers=self._get_headers())
download_response.raise_for_status()
with open(full_path, 'wb') as f:
f.write(download_response.content)
return True
except Exception as e:
logging.error(f"Error in _download_single_file: {e}")
return False
def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
files_downloaded = 0
try:
os.makedirs(local_dir, exist_ok=True)
url = f"{self._get_item_url(folder_id)}/children"
params = {'$top': 1000}
while url:
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
results = response.json()
items = results.get('value', [])
logging.info(f"Found {len(items)} items in folder {folder_id}")
for item in items:
item_name = item.get('name', 'unknown')
item_id = item.get('id')
if 'folder' in item:
if recursive:
subfolder_path = os.path.join(local_dir, item_name)
os.makedirs(subfolder_path, exist_ok=True)
subfolder_files = self._download_folder_recursive(
item_id,
subfolder_path,
recursive
)
files_downloaded += subfolder_files
logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
else:
success = self._download_single_file(item_id, local_dir)
if success:
files_downloaded += 1
logging.info(f"Downloaded file: {item_name}")
else:
logging.warning(f"Failed to download file: {item_name}")
url = results.get('@odata.nextLink')
return files_downloaded
except Exception as e:
logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
return files_downloaded
def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
try:
self._ensure_valid_token()
return self._download_folder_recursive(folder_id, local_dir, recursive)
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
return 0
def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
try:
self._ensure_valid_token()
return self._download_single_file(file_id, local_dir)
except Exception as e:
logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
return False
def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
if source_config is None:
source_config = {}
config = source_config if source_config else getattr(self, 'config', {})
files_downloaded = 0
try:
folder_ids = config.get('folder_ids', [])
file_ids = config.get('file_ids', [])
recursive = config.get('recursive', True)
if file_ids:
if isinstance(file_ids, str):
file_ids = [file_ids]
for file_id in file_ids:
if self._download_file_to_directory(file_id, local_dir):
files_downloaded += 1
if folder_ids:
if isinstance(folder_ids, str):
folder_ids = [folder_ids]
for folder_id in folder_ids:
try:
url = self._get_item_url(folder_id)
params = {'$select': 'id,name'}
response = requests.get(url, headers=self._get_headers(), params=params)
response.raise_for_status()
folder_metadata = response.json()
folder_name = folder_metadata.get('name', '')
folder_path = os.path.join(local_dir, folder_name)
os.makedirs(folder_path, exist_ok=True)
folder_files = self._download_folder_recursive(
folder_id,
folder_path,
recursive
)
files_downloaded += folder_files
logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
except Exception as e:
logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
if not file_ids and not folder_ids:
raise ValueError("No folder_ids or file_ids provided for download")
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": files_downloaded == 0,
"source_type": "share_point",
"config_used": config
}
except Exception as e:
return {
"files_downloaded": files_downloaded,
"directory_path": local_dir,
"empty_result": True,
"source_type": "share_point",
"config_used": config,
"error": str(e)
}

View File

@@ -1,48 +0,0 @@
from pathlib import Path
from typing import Dict, Union
from application.core.settings import settings
from application.parser.file.base_parser import BaseParser
from application.stt.stt_creator import STTCreator
from application.stt.upload_limits import enforce_audio_file_size_limit
class AudioParser(BaseParser):
def __init__(self, parser_config=None):
super().__init__(parser_config=parser_config)
self._transcript_metadata: Dict[str, Dict] = {}
def _init_parser(self) -> Dict:
return {}
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
_ = errors
try:
enforce_audio_file_size_limit(file.stat().st_size)
except OSError:
pass
stt = STTCreator.create_stt(settings.STT_PROVIDER)
result = stt.transcribe(
file,
language=settings.STT_LANGUAGE,
timestamps=settings.STT_ENABLE_TIMESTAMPS,
diarize=settings.STT_ENABLE_DIARIZATION,
)
transcript_metadata = {
"transcript_duration_s": result.get("duration_s"),
"transcript_language": result.get("language"),
"transcript_provider": result.get("provider"),
}
if result.get("segments"):
transcript_metadata["transcript_segments"] = result["segments"]
self._transcript_metadata[str(file)] = {
key: value
for key, value in transcript_metadata.items()
if value not in (None, [], {})
}
return result.get("text", "")
def get_file_metadata(self, file: Path) -> Dict:
return self._transcript_metadata.get(str(file), {})

View File

@@ -36,8 +36,3 @@ class BaseParser:
@abstractmethod
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file."""
def get_file_metadata(self, file: Path) -> Dict:
"""Return parser-specific metadata for the most recently parsed file."""
_ = file
return {}

View File

@@ -14,17 +14,11 @@ from application.parser.file.tabular_parser import PandasCSVParser, ExcelParser
from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.file.audio_parser import AudioParser
from application.parser.schema.base import Document
from application.stt.constants import SUPPORTED_AUDIO_EXTENSIONS
from application.utils import num_tokens_from_string
from application.core.settings import settings
def _build_audio_parser_mapping() -> Dict[str, BaseParser]:
return {extension: AudioParser() for extension in SUPPORTED_AUDIO_EXTENSIONS}
def get_default_file_extractor(
ocr_enabled: Optional[bool] = None,
) -> Dict[str, BaseParser]:
@@ -76,7 +70,6 @@ def get_default_file_extractor(
".webp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
# Media/subtitles
".vtt": DoclingVTTParser(),
**_build_audio_parser_mapping(),
# Specialized XML formats
".xml": DoclingXMLParser(),
# Formats docling doesn't support - use standard parsers
@@ -103,7 +96,6 @@ def get_default_file_extractor(
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
**_build_audio_parser_mapping(),
}
@@ -229,13 +221,11 @@ class SimpleDirectoryReader(BaseReader):
for input_file in self.input_files:
suffix_lower = input_file.suffix.lower()
parser_metadata = {}
if suffix_lower in self.file_extractor:
parser = self.file_extractor[suffix_lower]
if not parser.parser_config_set:
parser.init_parser()
data = parser.parse_file(input_file, errors=self.errors)
parser_metadata = parser.get_file_metadata(input_file)
else:
# do standard read
with open(input_file, "r", errors=self.errors) as f:
@@ -254,8 +244,6 @@ class SimpleDirectoryReader(BaseReader):
'title': input_file.name,
'token_count': file_tokens,
}
if parser_metadata:
base_metadata.update(parser_metadata)
if hasattr(self, 'input_dir'):
try:

View File

@@ -1,27 +0,0 @@
"""Shared file-extension constants for parsing and ingestion flows."""
from application.stt.constants import SUPPORTED_AUDIO_EXTENSIONS
SUPPORTED_SOURCE_DOCUMENT_EXTENSIONS = (
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
)
SUPPORTED_SOURCE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
SUPPORTED_SOURCE_EXTENSIONS = (
*SUPPORTED_SOURCE_DOCUMENT_EXTENSIONS,
*SUPPORTED_SOURCE_IMAGE_EXTENSIONS,
*SUPPORTED_AUDIO_EXTENSIONS,
)

View File

@@ -1,16 +0,0 @@
You are a helpful AI assistant, DocsGPT. You are proactive and helpful. Try to use tools, if they are available to you,
be proactive and fill in missing information.
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
If appropriate, your answers can include code examples, formatted as follows:
```(language)
(code)
```
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses.
Try to respond with mermaid charts if visualization helps with users queries.
You effectively utilize chat history, ensuring relevant and tailored responses.
You have access to a search tool that searches the user's uploaded documents and knowledge base.
Use the search_internal tool to find relevant information before answering questions.
You may search multiple times with different queries if needed.
Do not guess when documents are available — search first, then answer based on what you find.
If no relevant documents are found, use your general knowledge and tool capabilities.
Allow yourself to be very creative and use your imagination.

View File

@@ -1,15 +0,0 @@
You are a helpful AI assistant, DocsGPT. You are proactive and helpful. Try to use tools, if they are available to you,
be proactive and fill in missing information.
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
If appropriate, your answers can include code examples, formatted as follows:
```(language)
(code)
```
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses.
Try to respond with mermaid charts if visualization helps with users queries.
You effectively utilize chat history, ensuring relevant and tailored responses.
You have access to a search tool that searches the user's uploaded documents and knowledge base.
Use the search_internal tool to find relevant information before answering questions.
You may search multiple times with different queries if needed.
Do not guess when documents are available — search first, then answer based on what you find.
If no relevant documents are found, use your general knowledge and tool capabilities.

View File

@@ -1,16 +0,0 @@
You are a helpful AI assistant, DocsGPT. You are proactive and helpful. Try to use tools, if they are available to you,
be proactive and fill in missing information.
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
If appropriate, your answers can include code examples, formatted as follows:
```(language)
(code)
```
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses.
Try to respond with mermaid charts if visualization helps with users queries.
You effectively utilize chat history, ensuring relevant and tailored responses.
You have access to a search tool that searches the user's uploaded documents and knowledge base.
Use the search_internal tool to find relevant information before answering questions.
You may search multiple times with different queries if needed.
You MUST search before answering any factual question. Do not guess or use general knowledge when documents are available.
If you dont have enough information from the search results or tools, answer "I don't know" or "I don't have enough information".
Never make up information or provide false information!

View File

@@ -0,0 +1,3 @@
Query: {query}
Observations: {observations}
Now, using the insights from the observations, formulate a well-structured and precise final answer.

View File

@@ -0,0 +1,13 @@
You are an AI assistant and talk like you're thinking out loud. Given the following query, outline a concise thought process that includes key steps and considerations necessary for effective analysis and response. Avoid pointwise formatting. The goal is to break down the query into manageable components without excessive detail, focusing on clarity and logical progression.
Include the following elements in your thought and execution process:
1. Identify the main objective of the query.
2. Determine any relevant context or background information needed to understand the query.
3. List potential approaches or methods to address the query.
4. Highlight any critical factors or constraints that may influence the outcome.
5. Plan with available tools to help you with the analysis but dont execute them. Tools will be executed by another AI.
Query: {query}
Summaries: {summaries}
Prompt: {prompt}
Observations(potentially previous tool calls): {observations}

View File

@@ -1,23 +0,0 @@
You are a research assistant evaluating whether a user's question is clear enough to begin in-depth research.
Decide whether the question can be researched as-is, or whether you need to ask clarifying questions first.
Proceed WITHOUT clarification if:
- The question has a clear topic and intent
- There is enough context to form a research plan
- The question is broad but can be broken into sub-topics
- Documents/sources are available to search
Ask for clarification ONLY if:
- The question is critically ambiguous (e.g., "review the thing" with no context about what "thing" means)
- Key information is missing that would make research impossible (not just imperfect)
- The user seems to reference something specific but hasn't provided it
Err on the side of proceeding. Most questions are clear enough to start researching.
You MUST respond with ONLY a valid JSON object (no markdown, no code fences):
{
"needs_clarification": true or false,
"questions": ["question 1", "question 2"] (only if needs_clarification is true, max 3 questions),
"reason": "brief explanation of why clarification is or isn't needed"
}

View File

@@ -1,23 +0,0 @@
You are a research planner. Your job is to analyze the user's question and create a focused research plan.
IMPORTANT: Every step must be a concrete research action — something you can search for and find information about. Never generate steps that ask the user for more information or request documents. Work with what you have.
Assess the question's complexity:
- "simple": Can be answered with 1-2 focused searches (e.g., "What is our refund policy?")
- "moderate": Needs 3-4 searches across different topics (e.g., "Compare our pricing across product lines")
- "complex": Requires 5-6 searches with synthesis (e.g., "Audit our compliance docs against regulation X")
Break the question into sub-questions appropriate to its complexity. Don't over-decompose simple questions.
Consider:
- What distinct aspects of the question need separate investigation?
- What order should they be investigated in?
- Are there any dependencies between sub-questions?
You MUST respond with ONLY a valid JSON object in this exact format (no markdown, no code fences):
{
"complexity": "simple|moderate|complex",
"steps": [
{"query": "specific sub-question to investigate", "rationale": "why this step is needed"}
]
}

View File

@@ -1,13 +0,0 @@
You are a research agent investigating a specific topic. Your goal is to find comprehensive, accurate information.
Your current research task: {step_query}
Instructions:
1. Use the available tools to search for and gather relevant information.
2. If you have a search_internal tool, use it to search the knowledge base. You may search multiple times with different queries.
3. If you have a reason_think tool, use it to analyze what you've found before drawing conclusions.
4. After gathering sufficient information, provide a detailed summary of your findings.
5. Cite specific documents and passages you found. Reference sources by their titles or filenames.
6. If you cannot find relevant information through tools, use your general knowledge but clearly indicate this.
Be thorough — prefer completeness over brevity. Include all relevant details you find.

View File

@@ -1,22 +0,0 @@
You are compiling a comprehensive research report based on multiple research steps.
Original question: {question}
Research plan:
{plan_summary}
Intermediate findings:
{findings}
Write a well-structured, thorough report that:
1. Directly addresses the user's original question
2. Synthesizes findings from all research steps into a coherent narrative
3. Uses inline citations [N] for every factual claim, where N maps to the source list below
4. Highlights key insights and connections across different research steps
5. Notes any gaps or areas where information was insufficient
6. Includes a "References" section at the end listing all cited sources
Available sources for citation:
{references}
Format the report with clear headings and sections. Be comprehensive but well-organized.

View File

@@ -10,7 +10,7 @@ docling>=2.16.0
rapidocr>=1.4.0
onnxruntime>=1.19.0
docx2txt==0.9
ddgs>=8.0.0
duckduckgo-search==8.1.1
ebooklib==0.20
escodegen==1.0.11
esprima==4.0.1
@@ -47,7 +47,6 @@ markupsafe==3.0.3
marshmallow>=3.18.0,<5.0.0
mpmath==1.3.0
multidict==6.7.0
msal==1.34.0
mypy-extensions==1.1.0
networkx==3.6.1
numpy==2.4.0
@@ -96,4 +95,4 @@ werkzeug>=3.1.0
yarl==1.22.0
markdownify==1.2.2
tldextract==5.3.0
websockets==15.0.1
websockets==15.0.1

View File

@@ -18,7 +18,7 @@ agents:
- name: "Researcher"
description: "A specialized research agent that performs deep dives into subjects."
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-researcher.svg"
agent_type: "classic"
agent_type: "react"
prompt:
name: "Researcher-Agent"
content: |

View File

@@ -1 +0,0 @@
"""Speech-to-text providers."""

View File

@@ -1,15 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional
class BaseSTT(ABC):
@abstractmethod
def transcribe(
self,
file_path: Path,
language: Optional[str] = None,
timestamps: bool = False,
diarize: bool = False,
) -> Dict[str, Any]:
pass

View File

@@ -1,15 +0,0 @@
SUPPORTED_AUDIO_EXTENSIONS = (".wav", ".mp3", ".m4a", ".ogg", ".webm")
SUPPORTED_AUDIO_MIME_TYPES = {
"application/ogg",
"audio/aac",
"audio/mp3",
"audio/mp4",
"audio/mpeg",
"audio/ogg",
"audio/wav",
"audio/webm",
"audio/x-m4a",
"audio/x-wav",
"video/webm",
}

View File

@@ -1,70 +0,0 @@
from pathlib import Path
from typing import Dict, Optional
from application.stt.base import BaseSTT
class FasterWhisperSTT(BaseSTT):
def __init__(
self,
model_size: str = "base",
device: str = "auto",
compute_type: str = "int8",
):
self.model_size = model_size
self.device = device
self.compute_type = compute_type
self._model = None
def _get_model(self):
if self._model is None:
try:
from faster_whisper import WhisperModel
except ImportError as exc:
raise ImportError(
"faster-whisper is required to use the faster_whisper STT provider."
) from exc
self._model = WhisperModel(
self.model_size,
device=self.device,
compute_type=self.compute_type,
)
return self._model
def transcribe(
self,
file_path: Path,
language: Optional[str] = None,
timestamps: bool = False,
diarize: bool = False,
) -> Dict[str, object]:
_ = diarize
model = self._get_model()
segments_iter, info = model.transcribe(
str(file_path),
language=language,
word_timestamps=timestamps,
)
segments = []
text_parts = []
for segment in segments_iter:
segment_text = getattr(segment, "text", "").strip()
if segment_text:
text_parts.append(segment_text)
segments.append(
{
"start": getattr(segment, "start", None),
"end": getattr(segment, "end", None),
"text": segment_text,
}
)
return {
"text": " ".join(text_parts).strip(),
"language": getattr(info, "language", language),
"duration_s": getattr(info, "duration", None),
"segments": segments if timestamps else [],
"provider": "faster_whisper",
}

View File

@@ -1,201 +0,0 @@
import json
import re
import uuid
from typing import Dict, Optional
LIVE_STT_SESSION_PREFIX = "stt_live_session:"
LIVE_STT_SESSION_TTL_SECONDS = 15 * 60
LIVE_STT_MUTABLE_TAIL_WORDS = 8
LIVE_STT_SILENCE_MUTABLE_TAIL_WORDS = 2
LIVE_STT_MIN_COMMITTED_OVERLAP_WORDS = 2
def normalize_transcript_text(text: str) -> str:
return " ".join((text or "").split()).strip()
def join_transcript_parts(*parts: str) -> str:
return " ".join(part for part in map(normalize_transcript_text, parts) if part)
def _normalize_word(word: str) -> str:
normalized = re.sub(r"[^\w]+", "", word.casefold(), flags=re.UNICODE)
return normalized or word.casefold()
def _split_words(text: str) -> list[str]:
normalized = normalize_transcript_text(text)
return normalized.split() if normalized else []
def _common_prefix_length(left_words: list[str], right_words: list[str]) -> int:
max_index = min(len(left_words), len(right_words))
prefix_length = 0
for index in range(max_index):
if _normalize_word(left_words[index]) != _normalize_word(right_words[index]):
break
prefix_length += 1
return prefix_length
def _find_suffix_prefix_overlap(
left_words: list[str], right_words: list[str], min_overlap: int
) -> int:
max_overlap = min(len(left_words), len(right_words))
if max_overlap < min_overlap:
return 0
left_keys = [_normalize_word(word) for word in left_words]
right_keys = [_normalize_word(word) for word in right_words]
for overlap_size in range(max_overlap, min_overlap - 1, -1):
if left_keys[-overlap_size:] == right_keys[:overlap_size]:
return overlap_size
return 0
def strip_committed_prefix(committed_text: str, hypothesis_text: str) -> str:
committed_words = _split_words(committed_text)
hypothesis_words = _split_words(hypothesis_text)
if not committed_words or not hypothesis_words:
return normalize_transcript_text(hypothesis_text)
full_prefix_length = _common_prefix_length(committed_words, hypothesis_words)
if full_prefix_length == len(committed_words):
return " ".join(hypothesis_words[full_prefix_length:])
overlap_size = _find_suffix_prefix_overlap(
committed_words,
hypothesis_words,
LIVE_STT_MIN_COMMITTED_OVERLAP_WORDS,
)
if overlap_size:
return " ".join(hypothesis_words[overlap_size:])
return " ".join(hypothesis_words)
def _calculate_commit_count(
previous_hypothesis: str, current_hypothesis: str, is_silence: bool
) -> int:
previous_words = _split_words(previous_hypothesis)
current_words = _split_words(current_hypothesis)
if not current_words:
return 0
if not previous_words:
if is_silence:
return max(0, len(current_words) - LIVE_STT_SILENCE_MUTABLE_TAIL_WORDS)
return 0
stable_prefix_length = _common_prefix_length(previous_words, current_words)
if not stable_prefix_length:
return 0
mutable_tail_words = (
LIVE_STT_SILENCE_MUTABLE_TAIL_WORDS
if is_silence
else LIVE_STT_MUTABLE_TAIL_WORDS
)
max_committable_by_tail = max(0, len(current_words) - mutable_tail_words)
return min(stable_prefix_length, max_committable_by_tail)
def create_live_stt_session(
user: str, language: Optional[str] = None
) -> Dict[str, object]:
return {
"session_id": str(uuid.uuid4()),
"user": user,
"language": language,
"committed_text": "",
"mutable_text": "",
"previous_hypothesis": "",
"latest_hypothesis": "",
"last_chunk_index": -1,
}
def get_live_stt_session_key(session_id: str) -> str:
return f"{LIVE_STT_SESSION_PREFIX}{session_id}"
def save_live_stt_session(redis_client, session_state: Dict[str, object]) -> None:
redis_client.setex(
get_live_stt_session_key(str(session_state["session_id"])),
LIVE_STT_SESSION_TTL_SECONDS,
json.dumps(session_state),
)
def load_live_stt_session(redis_client, session_id: str) -> Optional[Dict[str, object]]:
raw_session = redis_client.get(get_live_stt_session_key(session_id))
if not raw_session:
return None
if isinstance(raw_session, bytes):
raw_session = raw_session.decode("utf-8")
return json.loads(raw_session)
def delete_live_stt_session(redis_client, session_id: str) -> None:
redis_client.delete(get_live_stt_session_key(session_id))
def apply_live_stt_hypothesis(
session_state: Dict[str, object],
hypothesis_text: str,
chunk_index: int,
is_silence: bool = False,
) -> Dict[str, object]:
last_chunk_index = int(session_state.get("last_chunk_index", -1))
if chunk_index < 0:
raise ValueError("chunk_index must be non-negative")
if chunk_index < last_chunk_index:
raise ValueError("chunk_index is older than the last processed chunk")
if chunk_index == last_chunk_index:
return session_state
committed_text = normalize_transcript_text(str(session_state.get("committed_text", "")))
previous_hypothesis = normalize_transcript_text(
str(session_state.get("latest_hypothesis", ""))
)
current_hypothesis = strip_committed_prefix(committed_text, hypothesis_text)
if not current_hypothesis and is_silence and previous_hypothesis:
committed_text = join_transcript_parts(committed_text, previous_hypothesis)
previous_hypothesis = ""
commit_count = _calculate_commit_count(
previous_hypothesis,
current_hypothesis,
is_silence=is_silence,
)
current_words = _split_words(current_hypothesis)
if commit_count:
committed_text = join_transcript_parts(
committed_text,
" ".join(current_words[:commit_count]),
)
current_hypothesis = " ".join(current_words[commit_count:])
session_state["committed_text"] = committed_text
session_state["mutable_text"] = normalize_transcript_text(current_hypothesis)
session_state["previous_hypothesis"] = previous_hypothesis
session_state["latest_hypothesis"] = normalize_transcript_text(current_hypothesis)
session_state["last_chunk_index"] = chunk_index
return session_state
def get_live_stt_transcript_text(session_state: Dict[str, object]) -> str:
return join_transcript_parts(
str(session_state.get("committed_text", "")),
str(session_state.get("mutable_text", "")),
)
def finalize_live_stt_session(session_state: Dict[str, object]) -> str:
return join_transcript_parts(
str(session_state.get("committed_text", "")),
str(session_state.get("latest_hypothesis", "")),
)

View File

@@ -1,60 +0,0 @@
from pathlib import Path
from typing import Any, Dict, Optional
from openai import OpenAI
from application.core.settings import settings
from application.stt.base import BaseSTT
class OpenAISTT(BaseSTT):
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: Optional[str] = None,
):
self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
self.base_url = base_url or settings.OPENAI_BASE_URL or "https://api.openai.com/v1"
self.model = model or settings.OPENAI_STT_MODEL
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
def transcribe(
self,
file_path: Path,
language: Optional[str] = None,
timestamps: bool = False,
diarize: bool = False,
) -> Dict[str, Any]:
_ = diarize
request: Dict[str, Any] = {
"file": file_path,
"model": self.model,
"response_format": "verbose_json",
}
if language:
request["language"] = language
if timestamps:
request["timestamp_granularities"] = ["segment"]
with open(file_path, "rb") as audio_file:
request["file"] = audio_file
response = self.client.audio.transcriptions.create(**request)
response_dict = self._to_dict(response)
segments = response_dict.get("segments") or []
return {
"text": response_dict.get("text", ""),
"language": response_dict.get("language") or language,
"duration_s": response_dict.get("duration"),
"segments": [self._to_dict(segment) for segment in segments],
"provider": "openai",
}
@staticmethod
def _to_dict(value: Any) -> Dict[str, Any]:
if hasattr(value, "model_dump"):
return value.model_dump()
if isinstance(value, dict):
return value
return {}

View File

@@ -1,17 +0,0 @@
from application.stt.base import BaseSTT
from application.stt.faster_whisper_stt import FasterWhisperSTT
from application.stt.openai_stt import OpenAISTT
class STTCreator:
stt_providers = {
"openai": OpenAISTT,
"faster_whisper": FasterWhisperSTT,
}
@classmethod
def create_stt(cls, stt_type, *args, **kwargs) -> BaseSTT:
stt_class = cls.stt_providers.get(stt_type.lower())
if not stt_class:
raise ValueError(f"No stt class found for type {stt_type}")
return stt_class(*args, **kwargs)

View File

@@ -1,43 +0,0 @@
from pathlib import Path
from application.core.settings import settings
from application.stt.constants import SUPPORTED_AUDIO_EXTENSIONS
from application.utils import safe_filename
STT_REQUEST_SIZE_OVERHEAD_BYTES = 1024 * 1024
STT_SIZE_LIMITED_PATHS = frozenset(("/api/stt", "/api/stt/live/chunk"))
class AudioFileTooLargeError(ValueError):
pass
def get_stt_max_file_size_bytes() -> int:
return max(0, settings.STT_MAX_FILE_SIZE_MB) * 1024 * 1024
def build_stt_file_size_limit_message() -> str:
return f"Audio file exceeds {settings.STT_MAX_FILE_SIZE_MB}MB limit"
def is_audio_filename(filename: str | Path | None) -> bool:
if not filename:
return False
safe_name = safe_filename(Path(str(filename)).name)
return Path(safe_name).suffix.lower() in SUPPORTED_AUDIO_EXTENSIONS
def enforce_audio_file_size_limit(size_bytes: int) -> None:
max_size_bytes = get_stt_max_file_size_bytes()
if max_size_bytes and size_bytes > max_size_bytes:
raise AudioFileTooLargeError(build_stt_file_size_limit_message())
def should_reject_stt_request(path: str, content_length: int | None) -> bool:
if path not in STT_SIZE_LIMITED_PATHS or content_length is None:
return False
max_request_size_bytes = (
get_stt_max_file_size_bytes() + STT_REQUEST_SIZE_OVERHEAD_BYTES
)
return content_length > max_request_size_bytes

View File

@@ -3,12 +3,12 @@ from typing import Any, Dict, List, Optional, Set
from jinja2 import (
ChainableUndefined,
Environment,
nodes,
select_autoescape,
TemplateSyntaxError,
)
from jinja2.exceptions import SecurityError, UndefinedError
from jinja2.sandbox import SandboxedEnvironment
from jinja2.exceptions import UndefinedError
logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ class TemplateEngine:
"""Jinja2-based template engine for dynamic prompt rendering"""
def __init__(self):
self._env = SandboxedEnvironment(
self._env = Environment(
undefined=ChainableUndefined,
trim_blocks=True,
lstrip_blocks=True,
@@ -57,10 +57,6 @@ class TemplateEngine:
error_msg = f"Undefined variable in template: {e.message}"
logger.error(error_msg)
raise TemplateRenderError(error_msg) from e
except SecurityError as e:
error_msg = f"Template security violation: {str(e)}"
logger.error(error_msg)
raise TemplateRenderError(error_msg) from e
except Exception as e:
error_msg = f"Template rendering failed: {str(e)}"
logger.error(error_msg)

View File

@@ -26,7 +26,6 @@ from application.parser.chunking import Chunker
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.embedding_pipeline import embed_and_store_documents
from application.parser.file.bulk import SimpleDirectoryReader, get_default_file_extractor
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.parser.remote.remote_creator import RemoteCreator
from application.parser.schema.base import Document
from application.retriever.retriever_creator import RetrieverCreator
@@ -647,7 +646,23 @@ def reingest_source_worker(self, source_id, user):
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=list(SUPPORTED_SOURCE_EXTENSIONS),
required_exts=[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
@@ -1125,25 +1140,17 @@ def attachment_worker(self, file_info, user):
file_extractor = get_default_file_extractor(
ocr_enabled=settings.DOCLING_OCR_ATTACHMENTS_ENABLED
)
attachment_document = storage.process_file(
content = storage.process_file(
relative_path,
lambda local_path, **kwargs: SimpleDirectoryReader(
input_files=[local_path],
exclude_hidden=True,
errors="ignore",
file_extractor=file_extractor,
file_metadata=metadata_from_filename,
)
.load_data()[0],
.load_data()[0]
.text,
)
content = attachment_document.text
parser_metadata = {
key: value
for key, value in (attachment_document.extra_info or {}).items()
if key.startswith("transcript_")
}
if parser_metadata:
metadata = {**metadata, **parser_metadata}
token_count = num_tokens_from_string(content)
if token_count > 100000:
@@ -1326,7 +1333,23 @@ def ingest_connector(
reader = SimpleDirectoryReader(
input_dir=temp_dir,
recursive=True,
required_exts=list(SUPPORTED_SOURCE_EXTENSIONS),
required_exts=[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
exclude_hidden=True,
file_metadata=metadata_from_filename,
)
@@ -1426,28 +1449,34 @@ def ingest_connector(
def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, Any]:
"""Worker to handle MCP OAuth flow asynchronously."""
logging.info(
"[MCP OAuth] Worker started for user_id=%s, config=%s", user_id, config
)
try:
import asyncio
from application.agents.tools.mcp_tool import MCPTool
task_id = self.request.id
logging.info("[MCP OAuth] Task ID: %s", task_id)
redis_client = get_redis_instance()
def update_status(status_data: Dict[str, Any]):
logging.info("[MCP OAuth] Updating status: %s", status_data)
status_key = f"mcp_oauth_status:{task_id}"
redis_client.setex(status_key, 600, json.dumps(status_data))
update_status(
{
"status": "in_progress",
"message": "Starting OAuth...",
"message": "Starting OAuth flow...",
"task_id": task_id,
}
)
tool_config = config.copy()
tool_config["oauth_task_id"] = task_id
logging.info("[MCP OAuth] Initializing MCPTool with config: %s", tool_config)
mcp_tool = MCPTool(tool_config, user_id)
async def run_oauth_discovery():
@@ -1458,7 +1487,7 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
update_status(
{
"status": "awaiting_redirect",
"message": "Awaiting OAuth redirect...",
"message": "Waiting for OAuth redirect...",
"task_id": task_id,
}
)
@@ -1467,40 +1496,66 @@ def mcp_oauth(self, config: Dict[str, Any], user_id: str = None) -> Dict[str, An
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(run_oauth_discovery())
logging.info("[MCP OAuth] Starting event loop for OAuth discovery...")
tools_response = loop.run_until_complete(run_oauth_discovery())
logging.info(
"[MCP OAuth] Tools response after async call: %s", tools_response
)
status_key = f"mcp_oauth_status:{task_id}"
redis_status = redis_client.get(status_key)
if redis_status:
logging.info(
"[MCP OAuth] Redis status after async call: %s", redis_status
)
else:
logging.warning(
"[MCP OAuth] No Redis status found after async call for key: %s",
status_key,
)
tools = mcp_tool.get_actions_metadata()
update_status(
{
"status": "completed",
"message": f"Connected \u2014 found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"message": f"OAuth completed successfully. Found {len(tools)} tools.",
"tools": tools,
"tools_count": len(tools),
"task_id": task_id,
}
)
logging.info(
"[MCP OAuth] OAuth flow completed successfully for task_id=%s", task_id
)
return {"success": True, "tools": tools, "tools_count": len(tools)}
except Exception as e:
error_msg = f"OAuth failed: {str(e)}"
logging.error("MCP OAuth discovery failed: %s", error_msg, exc_info=True)
error_msg = f"OAuth flow failed: {str(e)}"
logging.error(
"[MCP OAuth] Exception in OAuth discovery: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)
return {"success": False, "error": error_msg}
finally:
logging.info("[MCP OAuth] Closing event loop for task_id=%s", task_id)
loop.close()
except Exception as e:
error_msg = f"OAuth init failed: {str(e)}"
logging.error("MCP OAuth init failed: %s", error_msg, exc_info=True)
error_msg = f"Failed to initialize OAuth flow: {str(e)}"
logging.error(
"[MCP OAuth] Exception during initialization: %s", error_msg, exc_info=True
)
update_status(
{
"status": "error",
"message": error_msg,
"error": str(e),
"task_id": task_id,
}
)

View File

@@ -44,40 +44,36 @@ The main set of instructions or system [prompt](/Guides/Customising-prompts) tha
## Understanding Agent Types
DocsGPT supports several agent types, each with a distinct way of processing information. The code for these can be found in the `application/agents/` directory.
DocsGPT allows for different "types" of agents, each with a distinct way of processing information and generating responses. The code for these agent types can be found in the `application/agents/` directory.
### 1. Classic Agent
### 1. Classic Agent (`classic_agent.py`)
The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach: it retrieves relevant document chunks, augments the prompt context with them, and generates a response. It can also use configured tools if the LLM decides they are necessary.
**How it works:** The Classic Agent follows a traditional Retrieval Augmented Generation (RAG) approach.
1. **Retrieve:** When a query is made, it first searches the selected Source documents for relevant information.
2. **Augment:** This retrieved data is then added to the context, along with the main Prompt and the user's query.
3. **Generate:** The LLM generates a response based on this augmented context. It can also utilize any configured tools if the LLM decides they are necessary.
**Best for:** Direct question-answering over a specific set of documents and straightforward tool use.
**Best for:**
* Direct question-answering over a specific set of documents.
* Tasks where the primary goal is to extract and synthesize information from the provided sources.
* Simpler tool integrations where the decision to use a tool is straightforward.
### 2. Agentic Agent
### 2. ReAct Agent (`react_agent.py`)
Unlike Classic which pre-fetches documents into the prompt, the Agentic Agent gives the LLM an `internal_search` tool so it can decide **when, what, and whether** to search. This means the LLM controls its own retrieval — it can search multiple times, refine queries, or skip retrieval entirely if the question doesn't need it.
**How it works:** The ReAct Agent employs a more sophisticated "Reason and Act" framework. This involves a multi-step process:
1. **Plan (Thought):** Based on the query, its prompt, and available tools/sources, the LLM first generates a plan or a sequence of thoughts on how to approach the problem. You might see this output as a "thought" process during generation.
2. **Act:** The agent then executes actions based on this plan. This might involve querying its sources, using a tool, or performing internal reasoning.
3. **Observe:** It gathers observations from the results of its actions (e.g., data from a tool, snippets from documents).
4. **Repeat (if necessary):** Steps 2 and 3 can be repeated as the agent refines its approach or gathers more information.
5. **Conclude:** Finally, it generates the final answer based on the initial query and all accumulated observations.
**Best for:** Tasks where the agent needs to dynamically decide how to gather information, use multiple tools in sequence, or combine retrieval with external tool calls.
### 3. Research Agent
A multi-phase agent designed for in-depth research tasks:
1. **Clarification** — Determines if the question needs clarification before proceeding.
2. **Planning** — Decomposes the question into research steps with adaptive depth based on complexity.
3. **Research** — Executes each step, calling tools and refining queries as needed.
4. **Synthesis** — Compiles findings into a final cited report.
Includes budget controls for max steps, timeout, and token limits to keep research bounded.
**Best for:** Complex questions that require multi-step investigation, gathering information from multiple sources, and producing structured reports with citations.
### 4. Workflow Agent
Executes predefined workflows composed of connected nodes (AI Agent, Set State, Condition). See the [Workflow Nodes](/Agents/nodes) page for details on building workflows.
**Best for:** Structured, multi-step processes with branching logic and shared state between steps.
**Best for:**
* More complex tasks that require multi-step reasoning or problem-solving.
* Scenarios where the agent needs to dynamically decide which tools to use and in what order, based on intermediate results.
* Interactive tasks where the agent needs to "think" through a problem.
<Callout type="info">
The legacy "ReAct" agent type is still accepted for backwards compatibility but maps to the Classic Agent internally. New agents should use Classic, Agentic, or Research instead.
Developers looking to introduce new agent architectures can explore the `application/agents/` directory. `classic_agent.py` and `react_agent.py` serve as excellent starting points, demonstrating how to inherit from `BaseAgent` and structure agent logic.
</Callout>
## Navigating and Managing Agents in DocsGPT

View File

@@ -70,9 +70,9 @@ Inside the DocsGPT folder create a `.env` file and copy the contents of `.env_sa
Make sure your `.env` file looks like this:
```
API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
OPENAI_API_KEY=(Your OpenAI API key)
VITE_API_STREAMING=true
SELF_HOSTED_MODEL=false
```
To save the file, press CTRL+X, then Y, and then ENTER.

View File

@@ -65,8 +65,6 @@ Here are some of the most fundamental settings you'll likely want to configure:
- **`OPENAI_BASE_URL`**: Specifically used when `LLM_PROVIDER` is set to `openai` but you are connecting to a local inference engine (like Ollama, Llama.cpp, etc.) that exposes an OpenAI-compatible API. This setting tells DocsGPT where to find your local LLM server.
- **`STT_PROVIDER`**: Selects the speech-to-text provider used for microphone transcription in chat and for audio file ingestion through the parser pipeline.
## Configuration Examples
Let's look at some concrete examples of how to configure these settings in your `.env` file.
@@ -97,49 +95,6 @@ EMBEDDINGS_NAME=huggingface_sentence-transformers/all-mpnet-base-v2 # You can al
In this case, even though you are using Ollama locally, `LLM_PROVIDER` is set to `openai` because Ollama (and many other local inference engines) are designed to be API-compatible with OpenAI. `OPENAI_BASE_URL` points DocsGPT to the local Ollama server.
## Speech-to-Text Settings
DocsGPT can transcribe audio in two places:
- Voice input in the chat.
- Audio file ingestion. Uploaded `.wav`, `.mp3`, `.m4a`, `.ogg`, and `.webm` files are transcribed first and then passed through the normal parser, chunking, embedding, and indexing pipeline.
The settings below control speech-to-text behaviour for both voice input and audio file ingestion.
| Setting | Purpose | Typical values |
| --- | --- | --- |
| `STT_PROVIDER` | Speech-to-text backend provider. | `openai`, `faster_whisper` |
| `OPENAI_STT_MODEL` | OpenAI transcription model used when `STT_PROVIDER=openai`. | `gpt-4o-mini-transcribe` |
| `STT_LANGUAGE` | Optional language hint passed to the provider. Leave unset for auto-detection when supported. | `en`, `es`, unset |
| `STT_MAX_FILE_SIZE_MB` | Maximum file size accepted by the synchronous `/api/stt` endpoint. | `50` |
| `STT_ENABLE_TIMESTAMPS` | Include timestamp segments in the normalized transcript response and stored parser metadata. | `true`, `false` |
| `STT_ENABLE_DIARIZATION` | Reserved provider option for speaker diarization. Some providers may ignore it. | `true`, `false` |
### Example: OpenAI Speech-to-Text
```env
STT_PROVIDER=openai
OPENAI_API_KEY=YOUR_OPENAI_API_KEY
OPENAI_STT_MODEL=gpt-4o-mini-transcribe
STT_LANGUAGE=
STT_MAX_FILE_SIZE_MB=50
STT_ENABLE_TIMESTAMPS=false
STT_ENABLE_DIARIZATION=false
```
If you already use `API_KEY` for OpenAI, DocsGPT can reuse that key for transcription. Set `OPENAI_API_KEY` only when you want a dedicated key.
### Example: Local `faster_whisper`
```env
STT_PROVIDER=faster_whisper
STT_LANGUAGE=en
STT_ENABLE_TIMESTAMPS=true
STT_ENABLE_DIARIZATION=false
```
`faster_whisper` is an optional backend dependency. Install it in the Python environment used by the DocsGPT API and worker before selecting this provider.
## Authentication Settings
DocsGPT includes a JWT (JSON Web Token) based authentication feature for managing sessions or securing local deployments while allowing access.
@@ -214,31 +169,6 @@ If you have configured `AUTH_TYPE=simple_jwt`, the DocsGPT frontend will prompt
}}
/>
## S3 Storage Backend
By default DocsGPT stores files locally. Set `STORAGE_TYPE=s3` to use Amazon S3 instead.
| Setting | Description | Default |
| --- | --- | --- |
| `STORAGE_TYPE` | `local` or `s3` | `local` |
| `S3_BUCKET_NAME` | S3 bucket name | `docsgpt-test-bucket` |
| `SAGEMAKER_ACCESS_KEY` | AWS access key ID | — |
| `SAGEMAKER_SECRET_KEY` | AWS secret access key | — |
| `SAGEMAKER_REGION` | AWS region | — |
| `URL_STRATEGY` | `backend` (proxy through API) or `s3` (direct S3 URLs) | `backend` |
The S3 credentials use `SAGEMAKER_*` variable names because they are shared with the SageMaker integration.
```env
STORAGE_TYPE=s3
S3_BUCKET_NAME=your-bucket-name
SAGEMAKER_ACCESS_KEY=your-aws-access-key-id
SAGEMAKER_SECRET_KEY=your-aws-secret-access-key
SAGEMAKER_REGION=us-east-1
```
Your IAM user needs these permissions on the bucket: `s3:PutObject`, `s3:GetObject`, `s3:DeleteObject`, `s3:ListBucket`, `s3:HeadObject`.
## Exploring More Settings
These are just the basic settings to get you started. The `settings.py` file contains many more advanced options that you can explore to further customize DocsGPT, such as:
@@ -249,3 +179,5 @@ These are just the basic settings to get you started. The `settings.py` file con
- And many more!
For a complete list of available settings and their descriptions, refer to the `settings.py` file in `application/core`. Remember to restart your Docker containers after making changes to your `.env` file or `settings.py` for the changes to take effect.

View File

@@ -86,9 +86,13 @@ Make sure your `.env` file looks like this:
```
API_KEY=<Your LLM API key>
LLM_NAME=docsgpt
OPENAI_API_KEY=(Your OpenAI API key)
VITE_API_STREAMING=true
SELF_HOSTED_MODEL=false
```

View File

@@ -11,18 +11,18 @@ DocsGPT API keys are essential for developers and users who wish to integrate th
After uploading your document, you can obtain an API key either through the graphical user interface or via an API call:
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the Agents option, and press 'Create New' to generate a new agent (which includes an API key).
- **API Call:** Alternatively, you can use the `/api/create_agent` endpoint to create a new agent. An API key is automatically generated for each agent. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
- **Graphical User Interface:** Navigate to the Settings section of the DocsGPT web app, find the API Keys option, and press 'Create New' to generate your key.
- **API Call:** Alternatively, you can use the `/api/create_api_key` endpoint to create a new API key. For detailed instructions, visit [DocsGPT API Documentation](https://gptcloud.arc53.com/).
## Understanding Key Variables
Upon creating your agent, you will encounter several key variables. Each serves a specific purpose:
Upon creating your API key, you will encounter several key variables. Each serves a specific purpose:
- **Name:** Assign a name to your agent for easy identification.
- **Source:** Indicates the source document(s) linked to your agent, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your agent. You can view this by making a call to `/api/get_agents`.
- **Key:** The API key for the agent, which will be used in your application to authenticate API requests.
- **Name:** Assign a name to your API key for easy identification.
- **Source:** Indicates the source document(s) linked to your API key, which DocsGPT will use to generate responses.
- **ID:** A unique identifier for your API key. You can view this by making a call to `/api/get_api_keys`.
- **Key:** The API key itself, which will be used in your application to authenticate API requests.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the agent, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
With your API key ready, you can now integrate DocsGPT into your application, such as the DocsGPT Widget or any other software, via `/api/answer` or `/stream` endpoints. The source document is preset with the API key, allowing you to bypass fields like `selectDocs` and `active_docs` during implementation.
Congratulations on taking the first step towards enhancing your applications with DocsGPT!

View File

@@ -64,7 +64,7 @@ flowchart LR
* **Technology:** Supports multiple vector databases.
* **Responsibility:** Vector Stores are used to store and retrieve vector embeddings of document chunks. This enables semantic search and retrieval of relevant document snippets in response to user queries.
* **Key Features:**
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, MongoDB Atlas Vector Search, and pgvector.
* Supports vector databases including FAISS, Elasticsearch, Qdrant, Milvus, and LanceDB.
* Provides storage and indexing of high-dimensional vector embeddings.
* Enables editing and updating of vector indexes including specific chunks.

View File

@@ -16,7 +16,7 @@ Training on other documentation sources can greatly enhance the versatility and
Make sure you have the document on which you want to train on ready with you on the device which you are using .You can also use links to the documentation to train on.
<Callout type="warning" emoji="⚠️">
Note: Supported file formats include .pdf, .txt, .rst, .docx, .md, .mdx, .csv, .epub, .html, .json, .xlsx, .pptx, .png, .jpg, .jpeg, and audio files (.wav, .mp3, .m4a, .ogg, .webm). You can also train using the link of the documentation.
Note: The document should be either of the given file formats .pdf, .txt, .rst, .docx, .md, .zip and limited to 25mb.You can also train using the link of the documentation.
</Callout>

View File

@@ -35,34 +35,8 @@ Choose the LLM of your choice.
For open source version please edit `LLM_PROVIDER`, `LLM_NAME` and others in the .env file. Refer to [⚙️ App Configuration](/Deploying/DocsGPT-Settings) for more information.
### Step 2
Visit [☁️ Cloud Providers](/Models/cloud-providers) for the updated list of online models. Make sure you have the right API_KEY and correct LLM_PROVIDER.
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
For self-hosted please visit [🖥️ Local Inference](/Models/local-inference).
</Steps>
## Fallback LLM
DocsGPT can automatically switch to a fallback LLM when the primary model fails, including mid-stream. This works with both streaming and non-streaming requests.
**Fallback order:**
1. Per-agent backup models (other models configured on the same agent)
2. Global fallback (`FALLBACK_LLM_*` env vars below)
3. Error returned if all fail
| Setting | Description | Default |
| --- | --- | --- |
| `FALLBACK_LLM_PROVIDER` | Provider name (e.g., `openai`, `anthropic`, `google`) | — |
| `FALLBACK_LLM_NAME` | Model name (e.g., `gpt-4o`, `claude-sonnet-4-20250514`) | — |
| `FALLBACK_LLM_API_KEY` | API key for the fallback provider | Falls back to `API_KEY` |
All three (`FALLBACK_LLM_PROVIDER`, `FALLBACK_LLM_NAME`, and an API key) must resolve for the global fallback to activate.
```env
FALLBACK_LLM_PROVIDER=anthropic
FALLBACK_LLM_NAME=claude-sonnet-4-20250514
FALLBACK_LLM_API_KEY=sk-ant-your-anthropic-key
```
<Callout type="info">
For maximum resilience, use a fallback provider from a different cloud than your primary. Each agent can also have multiple models configured — the other models are tried first before the global fallback.
</Callout>

View File

@@ -2,13 +2,5 @@ export default {
"google-drive-connector": {
"title": "🔗 Google Drive",
"href": "/Guides/Integrations/google-drive-connector"
},
"sharepoint-connector": {
"title": "🔗 SharePoint / OneDrive",
"href": "/Guides/Integrations/sharepoint-connector"
},
"mcp-tool-integration": {
"title": "🔗 MCP Tools",
"href": "/Guides/Integrations/mcp-tool-integration"
}
}

View File

@@ -1,66 +0,0 @@
---
title: MCP Tool Integration
description: Connect external tools to DocsGPT agents using the Model Context Protocol (MCP) standard.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# MCP Tool Integration
The [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) integration lets you connect external tool servers to DocsGPT. Your agents can then discover and call tools provided by those servers during conversations — for example, querying a CRM, running code, or accessing a database.
## Setup
<Steps>
### Step 1: Configure Environment Variables (Optional)
Only needed if your MCP servers use OAuth authentication:
```env
MCP_OAUTH_REDIRECT_URI=https://yourdomain.com/api/mcp_server/callback
```
If not set, falls back to `API_URL/api/mcp_server/callback`.
### Step 2: Add an MCP Server
Go to **Settings** > **Tools** > **Add Tool** > **MCP Server**. Enter the server URL, select an auth type, and click **Test Connection** to verify, then **Save**.
### Step 3: Enable for Your Agent
In your agent configuration, enable the MCP tools you want the agent to use.
</Steps>
## Authentication Types
| Auth Type | Config Fields |
|-----------|---------------|
| **None** | — |
| **Bearer** | `bearer_token` |
| **API Key** | `api_key`, `api_key_header` (default: `X-API-Key`) |
| **Basic** | `username`, `password` |
| **OAuth** | `oauth_scopes` (optional) |
<Callout type="warning">
For OAuth in production, `MCP_OAUTH_REDIRECT_URI` must be a publicly accessible URL pointing to your DocsGPT backend.
</Callout>
## API Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/api/mcp_server/test` | POST | Test a connection without saving |
| `/api/mcp_server/save` | POST | Save or update a server configuration |
| `/api/mcp_server/callback` | GET | OAuth callback handler |
| `/api/mcp_server/oauth_status/<task_id>` | GET | Poll OAuth flow status |
| `/api/mcp_server/auth_status` | GET | Batch check auth status for all MCP tools |
## Troubleshooting
- **Connection refused** — Verify the URL and that the server is reachable from your backend.
- **403 Forbidden** — Check credentials and permissions.
- **Timed out** — Default is 30s; increase timeout in tool config (max 300s).
- **OAuth "needs_auth" persists** — Verify `MCP_OAUTH_REDIRECT_URI` is correct and Redis is running.

View File

@@ -1,63 +0,0 @@
---
title: SharePoint / OneDrive Connector
description: Connect your Microsoft SharePoint or OneDrive as an external knowledge base to upload and process files directly.
---
import { Callout } from 'nextra/components'
import { Steps } from 'nextra/components'
# SharePoint / OneDrive Connector
Connect your SharePoint or OneDrive account to upload and process files directly as an external knowledge base. Supports Office files, PDFs, text files, CSVs, images, and more. Authentication is handled via Microsoft Entra ID (Azure AD) with automatic token refresh.
## Setup
<Steps>
### Step 1: Create an App Registration in Azure
1. Go to the [Azure Portal](https://portal.azure.com/) > **Microsoft Entra ID** > **App registrations** > **New registration**
2. Set **Redirect URI** (Web) to:
- Local: `http://localhost:7091/api/connectors/callback?provider=share_point`
- Production: `https://yourdomain.com/api/connectors/callback?provider=share_point`
### Step 2: Configure API Permissions
In your App Registration, go to **API permissions** > **Add a permission** > **Microsoft Graph** > **Delegated permissions** and add: `Files.Read`, `Files.Read.All`, `Sites.Read.All`. Grant admin consent if possible.
### Step 3: Create a Client Secret
Go to **Certificates & secrets** > **New client secret**. Copy the secret value immediately (it won't be shown again).
### Step 4: Configure Environment Variables
Add to your `.env` file:
```env
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
```
| Variable | Description | Required | Default |
|----------|-------------|----------|---------|
| `MICROSOFT_CLIENT_ID` | Application (client) ID from App Registration overview | Yes | — |
| `MICROSOFT_CLIENT_SECRET` | Client secret value | Yes | — |
| `MICROSOFT_TENANT_ID` | Directory (tenant) ID | No | `common` |
| `MICROSOFT_AUTHORITY` | Login endpoint override | No | Auto-constructed |
<Callout type="warning">
`MICROSOFT_TENANT_ID=common` (the default) allows any Microsoft account to authenticate. Set this to your specific tenant ID in production.
</Callout>
### Step 5: Restart and Use
Restart your application, then go to the upload section in DocsGPT and select **SharePoint / OneDrive** as the source. You'll be redirected to Microsoft to sign in, then can browse and select files to process.
</Steps>
## Troubleshooting
- **Option not appearing** — Verify `MICROSOFT_CLIENT_ID` and `MICROSOFT_CLIENT_SECRET` are set, then restart.
- **Authentication failed** — Check that the redirect URI matches exactly, including `?provider=share_point`.
- **Permission denied** — Ensure admin consent is granted and the user has access to the target files.

View File

@@ -7,10 +7,20 @@ description:
If your AI uses external knowledge and is not explicit enough, it is ok, because we try to make DocsGPT friendly.
But if you want to adjust it, prompts are now managed through the UI and API using a template-based system. See the [Customising Prompts](/Guides/Customising-prompts) guide for details.
But if you want to adjust it, here is a simple way:-
- Got to `application/prompts/chat_combine_prompt.txt`
- And change it to
To make the AI stricter about staying on-topic, edit your active prompt template (via **Sidebar → Settings → Active Prompt**) to include instructions like:
```
You are a DocsGPT, friendly and helpful AI assistant by Arc53 that provides help with documents. You give thorough answers with code examples, if possible.
Write an answer for the question below based on the provided context.
If the context provides insufficient information, reply "I cannot answer".
You have access to chat history and can use it to help answer the question.
----------------
{summaries}
```

Some files were not shown because too many files have changed in this diff Show More