Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
4c7a6a78aa chore(deps): bump docker/setup-qemu-action from 3 to 4
Bumps [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) from 3 to 4.
- [Release notes](https://github.com/docker/setup-qemu-action/releases)
- [Commits](https://github.com/docker/setup-qemu-action/compare/v3...v4)

---
updated-dependencies:
- dependency-name: docker/setup-qemu-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-04 20:54:17 +00:00
426 changed files with 12883 additions and 95703 deletions

View File

@@ -3,14 +3,6 @@ LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
@@ -20,17 +12,4 @@ EMBEDDINGS_KEY=
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -25,7 +25,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -27,7 +27,7 @@ jobs:
- name: Set up QEMU # Only needed for emulation, not for native arm64 builds
if: matrix.platform == 'linux/arm64'
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

View File

@@ -1,114 +0,0 @@
name: Publish npm libraries
on:
workflow_dispatch:
inputs:
version:
description: >
Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
Applies to both docsgpt and docsgpt-react.
required: true
default: patch
permissions:
contents: write
pull-requests: write
jobs:
publish:
runs-on: ubuntu-latest
environment: npm-release
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
registry-url: https://registry.npmjs.org
- name: Install dependencies
run: npm ci
# ── docsgpt (HTML embedding bundle) ──────────────────────────────────
# Uses the `build` script (parcel build src/browser.tsx) and keeps
# the `targets` field so Parcel produces browser-optimised bundles.
- name: Set package name → docsgpt
run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt)
id: version_docsgpt
run: |
VERSION="${{ github.event.inputs.version }}"
NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"
- name: Build docsgpt
run: npm run build
- name: Publish docsgpt
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── docsgpt-react (React library bundle) ─────────────────────────────
# Uses `build:react` script (parcel build src/index.ts) and strips
# the `targets` field so Parcel treats the output as a plain library
# without browser-specific target resolution, producing a smaller bundle.
- name: Reset package.json from source control
run: git checkout -- package.json
- name: Set package name → docsgpt-react
run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Remove targets field (react library build)
run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt-react) to match docsgpt
run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version
- name: Clean dist before react build
run: rm -rf dist
- name: Build docsgpt-react
run: npm run build:react
- name: Publish docsgpt-react
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── Commit the bumped version back to the repository ─────────────────
- name: Reset package.json and write final version
run: |
git checkout -- package.json
jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
package.json > _tmp.json && mv _tmp.json package.json
npm install --package-lock-only
- name: Commit version bump and create PR
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
git checkout -b "$BRANCH"
git add package.json package-lock.json
git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
git push origin "$BRANCH"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Create PR
run: |
gh pr create \
--title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
--body "Automated version bump after npm publish." \
--base main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -1,34 +0,0 @@
name: React Widget Build
on:
push:
paths:
- 'extensions/react-widget/**'
pull_request:
paths:
- 'extensions/react-widget/**'
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-latest
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
cache: npm
cache-dependency-path: extensions/react-widget/package-lock.json
- name: Install dependencies
run: npm ci
- name: Build
run: npm run build

1
.gitignore vendored
View File

@@ -2,7 +2,6 @@
__pycache__/
*.py[cod]
*$py.class
results.txt
experiments/
experiments

134
AGENTS.md
View File

@@ -1,134 +0,0 @@
# AGENTS.md
- Read `CONTRIBUTING.md` before making non-trivial changes.
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users usually configure `.env` themselves.
- Try to follow red/green TDD
### Check existing dev prerequisites first
For feature work, do **not** assume the environment needs to be recreated.
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether MongoDB is already running.
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
## Normal local development commands
Use these commands once the dev prerequisites above are satisfied.
### Backend
```bash
source .venv/bin/activate # macOS/Linux
uv pip install -r application/requirements.txt # or: pip install -r application/requirements.txt
```
Run the Flask API (if needed):
```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
Run the Celery worker in a separate terminal (if needed):
```bash
celery -A application.app.celery worker -l INFO
```
On macOS, prefer the solo pool for Celery:
```bash
python -m celery -A application.app.celery worker -l INFO --pool=solo
```
### Frontend
Install dependencies only when needed, then run the dev server:
```bash
cd frontend
npm install --include=dev
npm run dev
```
### Docs site
```bash
cd docs
npm install
```
### Python / backend changes validation
```bash
ruff check .
python -m pytest
```
### Frontend changes
```bash
cd frontend && npm run lint
cd frontend && npm run build
```
### Documentation changes
```bash
cd docs && npm run build
```
If Vale is installed locally and you edited prose, also run:
```bash
vale .
```
## Repository map
- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
- `tests/`: backend unit/integration tests and test-only Python dependencies.
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
- `deployment/`: Docker Compose variants and Kubernetes manifests.
## Coding rules
### Backend
- Follow PEP 8 and keep Python line length at or under 120 characters.
- Use type hints for function arguments and return values.
- Add Google-style docstrings to new or substantially changed functions and classes.
- Add or update tests under `tests/` for backend behavior changes.
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.
### Backend Abstractions
- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
- Vector stores are abstracted in `application/vectorstore/`.
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
### Frontend
- Follow the existing ESLint + Prettier setup.
- Prefer small, reusable functional components and hooks.
- If shared state must be added, use Redux rather than introducing a new global state library.
- Avoid broad UI refactors unless the task explicitly asks for them.
- Do not re-create components if we already have some in the app.
## PR readiness
Before opening a PR:
- run the relevant validation commands above
- confirm backend changes still work end-to-end after ingesting sample data when applicable
- clearly summarize user-visible behavior changes
- mention any config, dependency, or deployment implications
- Ask the user to attach a screenshot or a video to the PR

View File

@@ -22,11 +22,6 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!
- We have a frontend built on React (Vite) and a backend in Python.
> **Required for every PR:** Please attach screenshots or a short screen
> recording that shows the working version of your changes. This makes the
> requirement visible to reviewers and helps them quickly verify what you are
> submitting.
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart); note that the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
@@ -130,7 +125,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
```
9. **Submit a Pull Request (PR):**
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
10. **Collaborate:**
- Be responsive to comments and feedback on your PR.

View File

@@ -7,7 +7,7 @@
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
</p>
<div align="center">
@@ -29,14 +29,13 @@
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
</div>
<h3 align="left">
<strong>Key Features:</strong>
</h3>
<ul align="left">
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
@@ -159,3 +158,4 @@ The source code license is [MIT](https://opensource.org/license/mit/), as descri
</a>
</p>

View File

@@ -8,13 +8,7 @@ Currently, we support security patches by committing changes and bumping the ver
## Reporting a Vulnerability
Preferred method: use GitHub's private vulnerability reporting flow:
https://github.com/arc53/DocsGPT/security
Then click **Report a vulnerability**.
Alternatively:
Found a vulnerability? Please email us:
security@arc53.com

View File

@@ -1,8 +1,7 @@
import logging
from application.agents.agentic_agent import AgenticAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
@@ -11,9 +10,7 @@ logger = logging.getLogger(__name__)
class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ClassicAgent, # backwards compat: react falls back to classic
"agentic": AgenticAgent,
"research": ResearchAgent,
"react": ReActAgent,
"workflow": WorkflowAgent,
}

View File

@@ -1,63 +0,0 @@
import logging
from typing import Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.logging import LogContext
logger = logging.getLogger(__name__)
class AgenticAgent(BaseAgent):
    """Agent where the LLM drives retrieval through tool calls.

    ClassicAgent pre-fetches documents into the prompt; this agent instead
    registers an ``internal_search`` tool so the model itself decides when,
    what, and whether to search.
    """

    def __init__(
        self,
        retriever_config: Optional[Dict] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Configuration forwarded to the internal_search tool; empty when absent.
        self.retriever_config = retriever_config or {}

    def _gen_inner(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Run one generation turn, letting the LLM call internal_search."""
        available_tools = self.tool_executor.get_tools()
        add_internal_search_tool(available_tools, self.retriever_config)
        self._prepare_tools(available_tools)

        # The prompt carries no pre-fetched documents — retrieval is tool-driven.
        messages = self._build_messages(self.prompt, query)

        # The response handler manages the tool-call loop with the LLM.
        response = self._llm_gen(messages, log_context)
        yield from self._handle_response(
            response, available_tools, messages, log_context
        )

        # Surface any documents fetched by the internal search tool.
        self._collect_internal_sources()
        yield {"sources": self.retrieved_docs}
        yield {"tool_calls": self._get_truncated_tool_calls()}
        log_context.stacks.append(
            {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
        )

    def _collect_internal_sources(self):
        """Copy retrieved docs from the cached InternalSearchTool instance, if any."""
        key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
        cached = self.tool_executor._loaded_tools.get(key)
        if cached and hasattr(cached, "retrieved_docs") and cached.retrieved_docs:
            self.retrieved_docs = cached.retrieved_docs

View File

@@ -1,16 +1,18 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional
from typing import Dict, Generator, List, Optional
from application.agents.tool_executor import ToolExecutor
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.base import ToolCall
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
@@ -38,10 +40,6 @@ class BaseAgent(ABC):
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
llm=None,
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
@@ -52,42 +50,22 @@ class BaseAgent(ABC):
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
self.tool_config: Dict = {}
self.tools: List[Dict] = []
self.tool_calls: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
# Dependency injection for LLM — fall back to creating if not provided
if llm is not None:
self.llm = llm
else:
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
)
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
self.retrieved_docs = retrieved_docs or []
if llm_handler is not None:
self.llm_handler = llm_handler
else:
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
# Tool executor — injected or created
if tool_executor is not None:
self.tool_executor = tool_executor
else:
self.tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.user,
decoded_token=decoded_token,
)
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.attachments = attachments or []
self.json_schema = None
if json_schema is not None:
@@ -115,219 +93,327 @@ class BaseAgent(ABC):
) -> Generator[Dict, None, None]:
pass
def gen_continuation(
    self,
    messages: List[Dict],
    tools_dict: Dict,
    pending_tool_calls: List[Dict],
    tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
    """Resume generation after paused tool calls are resolved by the client.

    Applies the client-provided *tool_actions* (approvals, denials, or
    client-side results) to every pending call, appends the resulting
    assistant/tool messages, then hands control back to the LLM to continue
    the conversation.

    Args:
        messages: The saved messages array from the pause point.
        tools_dict: The saved tools dictionary.
        pending_tool_calls: The pending tool call descriptors from the pause.
        tool_actions: Client-provided actions resolving the pending calls.
    """
    self._prepare_tools(tools_dict)
    resolution_by_id = {a["call_id"]: a for a in tool_actions}

    # Build ONE assistant message holding every tool call so the history
    # matches the shape LLM providers expect: a single assistant message
    # with N tool_calls, followed by N tool-result messages.
    assistant_tool_calls: List[Dict[str, Any]] = []
    for pending in pending_tool_calls:
        call_id = pending["call_id"]
        args = pending["arguments"]
        serialized_args = (
            json.dumps(args) if isinstance(args, dict) else (args or "{}")
        )
        entry: Dict[str, Any] = {
            "id": call_id,
            "type": "function",
            "function": {
                "name": pending["name"],
                "arguments": serialized_args,
            },
        }
        if pending.get("thought_signature"):
            entry["thought_signature"] = pending["thought_signature"]
        assistant_tool_calls.append(entry)
    messages.append({
        "role": "assistant",
        "content": None,
        "tool_calls": assistant_tool_calls,
    })

    # Resolve each pending call and append the matching tool-result message.
    for pending in pending_tool_calls:
        call_id = pending["call_id"]
        args = pending["arguments"]
        action = resolution_by_id.get(call_id)
        if not action:
            # Unanswered calls are treated as denials so the LLM is informed.
            action = {
                "call_id": call_id,
                "decision": "denied",
                "comment": "No response provided",
            }
        if action.get("decision") == "approved":
            # Execute the tool server-side, streaming its events through.
            call = ToolCall(
                id=call_id,
                name=pending["name"],
                arguments=(
                    json.dumps(args) if isinstance(args, dict) else args
                ),
            )
            runner = self._execute_tool_action(tools_dict, call)
            outcome = None
            while True:
                try:
                    yield next(runner)
                except StopIteration as stop:
                    # Generator's return value carries (result, call_id).
                    outcome, _ = stop.value
                    break
            messages.append(
                self.llm_handler.create_tool_message(call, outcome)
            )
        elif action.get("decision") == "denied":
            reason = action.get("comment", "")
            denial_text = (
                f"Tool execution denied by user. Reason: {reason}"
                if reason
                else "Tool execution denied by user."
            )
            call = ToolCall(id=call_id, name=pending["name"], arguments=args)
            messages.append(
                self.llm_handler.create_tool_message(call, denial_text)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": pending.get("llm_name", pending["name"]),
                    "arguments": args,
                    "status": "denied",
                },
            }
        elif "result" in action:
            # The client executed the tool itself and sent back the result.
            client_result = action["result"]
            result_text = (
                client_result
                if isinstance(client_result, str)
                else json.dumps(client_result)
            )
            call = ToolCall(id=call_id, name=pending["name"], arguments=args)
            messages.append(
                self.llm_handler.create_tool_message(call, result_text)
            )
            yield {
                "type": "tool_call",
                "data": {
                    "tool_name": pending.get("tool_name", "unknown"),
                    "call_id": call_id,
                    "action_name": pending.get("llm_name", pending["name"]),
                    "arguments": args,
                    "result": (
                        result_text[:50] + "..."
                        if len(result_text) > 50
                        else result_text
                    ),
                    "status": "completed",
                },
            }

    # Hand back to the LLM with the updated message history.
    continuation = self._llm_gen(messages)
    yield from self._handle_response(
        continuation, tools_dict, messages, None
    )
    yield {"sources": self.retrieved_docs}
    yield {"tool_calls": self._get_truncated_tool_calls()}
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property
def tool_calls(self) -> List[Dict]:
    """Recorded tool invocations, delegated to the underlying ToolExecutor."""
    return self.tool_executor.tool_calls

@tool_calls.setter
def tool_calls(self, value: List[Dict]):
    # Keep the executor as the single source of truth for tool-call state.
    self.tool_executor.tool_calls = value
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
return tools_by_id
def _get_user_tools(self, user="local"):
return self.tool_executor._get_user_tools(user)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def _build_tool_parameters(self, action):
return self.tool_executor._build_tool_parameters(action)
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
}
if v.get("required", False):
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _execute_tool_action(self, tools_dict, call):
return self.tool_executor.execute(
tools_dict, call, self.llm.__class__.__name__
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
# Check if parsing failed
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
logger.error(error_message)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": getattr(call, "name", "unknown"),
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
# Check if tool_id exists in available tools
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(error_message)
# Return error result
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
def _get_truncated_tool_calls(self):
return self.tool_executor.get_truncated_tool_calls()
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
# ---- Context / token management ----
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
tm = ToolManager(config={})
# Prepare tool_config and add tool_id for memory tools
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": headers,
"query_params": query_params,
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
if hasattr(self, "conversation_id") and self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
)
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return [
{
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
    """Total token count across *messages*.

    Thin delegation to the compression service's TokenCounter; imported
    lazily to avoid a module-level dependency cycle.
    """
    from application.api.answer.services.compression.token_counter import (
        TokenCounter,
    )

    total = TokenCounter.count_message_tokens(messages)
    return total
def _check_context_limit(self, messages: List[Dict]) -> bool:
    """Report whether the running context has hit the compression threshold.

    Counts tokens in *messages*, records the total on
    ``self.current_token_count``, and compares it against the configured
    fraction (COMPRESSION_THRESHOLD_PERCENTAGE) of the model's context
    window. Any internal failure is logged and treated as "not at limit".
    """
    from application.core.model_utils import get_token_limit
    from application.core.settings import settings

    try:
        token_total = self._calculate_current_context_tokens(messages)
        self.current_token_count = token_total

        window = get_token_limit(self.model_id)
        cutoff = int(window * settings.COMPRESSION_THRESHOLD_PERCENTAGE)

        # Below the threshold: nothing to do.
        if token_total < cutoff:
            return False

        logger.warning(
            f"Context limit approaching: {token_total}/{window} tokens "
            f"({(token_total/window)*100:.1f}%)"
        )
        return True
    except Exception as e:
        logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
        return False
def _validate_context_size(self, messages: List[Dict]) -> None:
"""
Pre-flight validation before calling LLM. Logs warnings but never raises errors.
Args:
messages: Messages to be sent to LLM
"""
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(self.model_id)
percentage = (current_tokens / context_limit) * 100
# Log based on usage level
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
@@ -342,31 +428,43 @@ class BaseAgent(ABC):
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
    """
    Truncate text by removing content from the middle, preserving start and end.

    Args:
        text: Text to truncate
        max_tokens: Maximum tokens allowed

    Returns:
        Truncated text with middle removed if needed
    """
    from application.utils import num_tokens_from_string

    current_tokens = num_tokens_from_string(text)
    if current_tokens <= max_tokens:
        return text

    # Estimate chars per token (roughly 4 chars per token for English)
    chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
    target_chars = int(max_tokens * chars_per_token * 0.95)  # 5% safety margin
    if target_chars <= 0:
        return ""

    # Split: keep 40% from start, 40% from end, remove middle
    start_chars = int(target_chars * 0.4)
    end_chars = int(target_chars * 0.4)
    truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
    # Guard: text[-0:] would return the WHOLE string, making the "truncated"
    # output longer than the input when end_chars rounds down to zero.
    tail = text[-end_chars:] if end_chars > 0 else ""
    truncated = text[:start_chars] + truncation_marker + tail

    logger.info(
        f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
        f"(removed middle section)"
    )
    return truncated
def _build_messages(
self,
@@ -377,6 +475,7 @@ class BaseAgent(ABC):
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
# Append compression summary to system prompt if present
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
@@ -390,18 +489,23 @@ class BaseAgent(ABC):
context_limit = get_token_limit(self.model_id)
system_tokens = num_tokens_from_string(system_prompt)
# Reserve 10% for response/tools
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
# Max tokens for query: 80% of available space (leave room for history)
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
# Truncate query from middle if it exceeds 80% of available context
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
# Calculate remaining budget for chat history
available_for_history = max(available_after_system - query_tokens, 0)
# Truncate chat history to fit within available budget
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
@@ -416,35 +520,28 @@ class BaseAgent(ABC):
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
messages.append({"role": "user", "content": query})
return messages
@@ -453,6 +550,16 @@ class BaseAgent(ABC):
history: List[Dict],
max_tokens: int,
) -> List[Dict]:
"""
Truncate chat history to fit within token budget, keeping most recent messages.
Args:
history: Full chat history
max_tokens: Maximum tokens allowed for history
Returns:
Truncated history (most recent messages that fit)
"""
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
@@ -461,6 +568,7 @@ class BaseAgent(ABC):
truncated = []
current_tokens = 0
# Iterate from newest to oldest
for message in reversed(history):
message_tokens = 0
@@ -480,7 +588,7 @@ class BaseAgent(ABC):
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message)
truncated.insert(0, message) # Maintain chronological order
else:
break
@@ -492,13 +600,13 @@ class BaseAgent(ABC):
return truncated
# ---- LLM generation ----
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
# Pre-flight context validation - fail fast if over limit
self._validate_context_size(messages)
gen_kwargs = {"model": self.model_id, "messages": messages}
if self.attachments:
# Usage accounting only; stripped before provider invocation.
gen_kwargs["_usage_attachments"] = self.attachments
if (

View File

@@ -15,7 +15,11 @@ class ClassicAgent(BaseAgent):
) -> Generator[Dict, None, None]:
"""Core generator function for ClassicAgent execution flow"""
tools_dict = self.tool_executor.get_tools()
tools_dict = (
self._get_user_tools(self.user)
if not self.user_api_key
else self._get_tools(self.user_api_key)
)
self._prepare_tools(tools_dict)
messages = self._build_messages(self.prompt, query)

View File

@@ -0,0 +1,238 @@
import logging
import os
from typing import Any, Dict, Generator, List
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
logger = logging.getLogger(__name__)
# Upper bound on think-act-observe iterations before forced synthesis.
MAX_ITERATIONS_REASONING = 10

# Repository root: three levels up from this file.
current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)


def _read_prompt(filename: str) -> str:
    """Read a prompt template from application/prompts.

    Explicit UTF-8: prompt files may contain non-ASCII characters and the
    platform default encoding is not guaranteed to be UTF-8.
    """
    path = os.path.join(current_dir, "application/prompts", filename)
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


# Loaded once at import time; used by ReActAgent's planning/synthesis phases.
PLANNING_PROMPT_TEMPLATE = _read_prompt("react_planning_prompt.txt")
FINAL_PROMPT_TEMPLATE = _read_prompt("react_final_prompt.txt")
class ReActAgent(BaseAgent):
    """
    Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.

    Implements a think-act-observe loop for complex problem-solving:
    1. Creates a strategic plan based on the query
    2. Executes tools and gathers observations
    3. Iteratively refines approach until satisfied
    4. Synthesizes final answer from all observations
    """

    def __init__(self, *args, **kwargs):
        # Delegates construction to BaseAgent; only adds per-query reasoning state.
        super().__init__(*args, **kwargs)
        # Latest strategic plan text (regenerated on every reasoning iteration).
        self.plan: str = ""
        # Running log of plans, tool executions and LLM responses.
        self.observations: List[str] = []

    def _gen_inner(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Execute ReAct reasoning loop with planning, action, and observation cycles"""
        self._reset_state()
        # API-key runs use key-scoped tools; otherwise the user's own tools.
        tools_dict = (
            self._get_tools(self.user_api_key)
            if self.user_api_key
            else self._get_user_tools(self.user)
        )
        self._prepare_tools(tools_dict)
        for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
            yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}
            yield from self._planning_phase(query, log_context)
            if not self.plan:
                # Without a plan there is nothing to execute; stop early.
                logger.warning(
                    f"ReActAgent: No plan generated in iteration {iteration}"
                )
                break
            self.observations.append(f"Plan (iteration {iteration}): {self.plan}")
            # `yield from` forwards streamed events and captures the bool
            # returned by _execution_phase.
            satisfied = yield from self._execution_phase(query, tools_dict, log_context)
            if satisfied:
                logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
                break
        yield from self._synthesis_phase(query, log_context)

    def _reset_state(self):
        """Reset agent state for new query"""
        self.plan = ""
        self.observations = []

    def _planning_phase(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Generate strategic plan for query"""
        logger.info("ReActAgent: Creating plan...")
        plan_prompt = self._build_planning_prompt(query)
        messages = [{"role": "user", "content": plan_prompt}]
        plan_stream = self.llm.gen_stream(
            model=self.model_id,
            messages=messages,
            tools=self.tools if self.tools else None,
        )
        if log_context:
            log_context.stacks.append(
                {"component": "planning_llm", "data": build_stack_data(self.llm)}
            )
        plan_parts = []
        # Stream plan tokens to the client as "thought" events while
        # buffering the full text for the later phases.
        for chunk in plan_stream:
            content = self._extract_content(chunk)
            if content:
                plan_parts.append(content)
                yield {"thought": content}
        self.plan = "".join(plan_parts)

    def _execution_phase(
        self, query: str, tools_dict: Dict, log_context: LogContext
    ) -> Generator[Dict, None, bool]:
        """Execute plan with tool calls and observations.

        Yields dict events (sources, tool_calls); the generator's return
        value is True when the model signalled SATISFIED.
        """
        execution_prompt = self._build_execution_prompt(query)
        messages = self._build_messages(execution_prompt, query)
        llm_response = self._llm_gen(messages, log_context)
        initial_content = self._extract_content(llm_response)
        if initial_content:
            self.observations.append(f"Initial response: {initial_content}")
        processed_response = self._llm_handler(
            llm_response, tools_dict, messages, log_context
        )
        # Record every executed tool call as an observation (result clipped
        # to 200 chars to bound prompt growth).
        for tool_call in self.tool_calls:
            observation = (
                f"Executed: {tool_call.get('tool_name', 'Unknown')} "
                f"with args {tool_call.get('arguments', {})}. "
                f"Result: {str(tool_call.get('result', ''))[:200]}"
            )
            self.observations.append(observation)
        final_content = self._extract_content(processed_response)
        if final_content:
            self.observations.append(f"Response after tools: {final_content}")
        if log_context:
            log_context.stacks.append(
                {
                    "component": "agent_tool_calls",
                    "data": {"tool_calls": self.tool_calls.copy()},
                }
            )
        yield {"sources": self.retrieved_docs}
        yield {"tool_calls": self._get_truncated_tool_calls()}
        # Consumed by `yield from` in _gen_inner as the loop-exit signal.
        return "SATISFIED" in (final_content or "")

    def _synthesis_phase(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Synthesize final answer from all observations"""
        logger.info("ReActAgent: Generating final answer...")
        final_prompt = self._build_final_answer_prompt(query)
        messages = [{"role": "user", "content": final_prompt}]
        # Tools disabled: synthesis is pure text generation.
        final_stream = self.llm.gen_stream(
            model=self.model_id, messages=messages, tools=None
        )
        if log_context:
            log_context.stacks.append(
                {"component": "final_answer_llm", "data": build_stack_data(self.llm)}
            )
        for chunk in final_stream:
            content = self._extract_content(chunk)
            if content:
                yield {"answer": content}

    def _build_planning_prompt(self, query: str) -> str:
        """Build planning phase prompt"""
        # str.replace (not str.format) so stray braces in user/query text
        # cannot break formatting.
        prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
        prompt = prompt.replace("{prompt}", self.prompt or "")
        prompt = prompt.replace("{summaries}", "")
        prompt = prompt.replace("{observations}", "\n".join(self.observations))
        return prompt

    def _build_execution_prompt(self, query: str) -> str:
        """Build execution phase prompt with plan and observations"""
        observations_str = "\n".join(self.observations)
        # Hard character cap keeps the prompt bounded across many iterations.
        if len(observations_str) > 20000:
            observations_str = observations_str[:20000] + "\n...[truncated]"
        return (
            f"{self.prompt or ''}\n\n"
            f"Follow this plan:\n{self.plan}\n\n"
            f"Observations:\n{observations_str}\n\n"
            f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
            f"Otherwise, continue executing the plan."
        )

    def _build_final_answer_prompt(self, query: str) -> str:
        """Build final synthesis prompt"""
        observations_str = "\n".join(self.observations)
        # Tighter cap than execution: the final prompt template adds its own text.
        if len(observations_str) > 10000:
            observations_str = observations_str[:10000] + "\n...[truncated]"
            logger.warning("ReActAgent: Observations truncated for final answer")
        return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)

    def _extract_content(self, response: Any) -> str:
        """Extract text content from various LLM response formats"""
        if not response:
            return ""
        collected = []
        # Non-streaming shapes first: plain string, .message.content,
        # OpenAI-style choices[0].message.content, list-of-content-blocks.
        if isinstance(response, str):
            return response
        if hasattr(response, "message") and hasattr(response.message, "content"):
            if response.message.content:
                return response.message.content
        if hasattr(response, "choices") and response.choices:
            if hasattr(response.choices[0], "message"):
                content = response.choices[0].message.content
                if content:
                    return content
        if hasattr(response, "content") and isinstance(response.content, list):
            if response.content and hasattr(response.content[0], "text"):
                return response.content[0].text
        # Streaming shapes: iterate chunks and concatenate text deltas.
        try:
            for chunk in response:
                content_piece = ""
                if hasattr(chunk, "choices") and chunk.choices:
                    if hasattr(chunk.choices[0], "delta"):
                        delta_content = chunk.choices[0].delta.content
                        if delta_content:
                            content_piece = delta_content
                elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
                        content_piece = chunk.delta.text
                elif isinstance(chunk, str):
                    content_piece = chunk
                if content_piece:
                    collected.append(content_piece)
        except (TypeError, AttributeError):
            # Response was not iterable; fall through with whatever was collected.
            logger.debug(
                f"Response not iterable or unexpected format: {type(response)}"
            )
        except Exception as e:
            logger.error(f"Error extracting content: {e}")
        return "".join(collected)

View File

@@ -1,698 +0,0 @@
import json
import logging
import os
import time
from typing import Dict, Generator, List, Optional
from application.agents.base import BaseAgent
from application.agents.tool_executor import ToolExecutor
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
from application.logging import LogContext
logger = logging.getLogger(__name__)
# Defaults (can be overridden via constructor)
DEFAULT_MAX_STEPS = 6
DEFAULT_MAX_SUB_ITERATIONS = 5
DEFAULT_TIMEOUT_SECONDS = 300  # 5 minutes
DEFAULT_TOKEN_BUDGET = 100_000
DEFAULT_PARALLEL_WORKERS = 3

# Adaptive depth caps per complexity level
COMPLEXITY_CAPS = {
    "simple": 2,
    "moderate": 4,
    "complex": 6,
}

_PROMPTS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "prompts",
    "research",
)


def _load_prompt(name: str) -> str:
    """Read prompt template *name* from the research prompts directory.

    Explicit UTF-8: prompt files may contain non-ASCII characters and the
    platform default encoding is not guaranteed to be UTF-8.
    """
    with open(os.path.join(_PROMPTS_DIR, name), "r", encoding="utf-8") as f:
        return f.read()


# Loaded once at import time; used by the ResearchAgent phases below.
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
PLANNING_PROMPT = _load_prompt("planning.txt")
STEP_PROMPT = _load_prompt("step.txt")
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
# ---------------------------------------------------------------------------
# CitationManager
# ---------------------------------------------------------------------------
class CitationManager:
    """Tracks and deduplicates citations across research steps.

    Citation numbers are assigned in first-seen order starting at 1 and stay
    stable for the lifetime of the manager. Deduplication is keyed on the
    (source, title) pair of each document.
    """

    def __init__(self):
        # citation number -> source document, in assignment order
        self.citations: Dict[int, Dict] = {}
        self._counter = 0
        # (source, title) -> citation number. O(1) dedup instead of scanning
        # every registered citation on each add(); also makes docs missing
        # both keys dedup consistently (the old scan compared None vs "").
        self._seen: Dict[tuple, int] = {}

    def add(self, doc: Dict) -> int:
        """Register a source, return its citation number. Deduplicates by source."""
        key = (doc.get("source", ""), doc.get("title", ""))
        existing = self._seen.get(key)
        if existing is not None:
            return existing
        self._counter += 1
        self.citations[self._counter] = doc
        self._seen[key] = self._counter
        return self._counter

    def add_docs(self, docs: List[Dict]) -> str:
        """Register multiple docs, return formatted citation mapping text."""
        mapping_lines = []
        for doc in docs:
            num = self.add(doc)
            title = doc.get("title", "Untitled")
            mapping_lines.append(f"[{num}] {title}")
        return "\n".join(mapping_lines)

    def format_references(self) -> str:
        """Generate [N] -> source mapping for report footer."""
        if not self.citations:
            return "No sources found."
        lines = []
        for num, doc in sorted(self.citations.items()):
            title = doc.get("title", "Untitled")
            source = doc.get("source", "Unknown")
            filename = doc.get("filename", "")
            display = filename or title
            # NOTE(review): display and source are concatenated with no
            # separator — possibly a delimiter lost upstream; confirm intent.
            lines.append(f"[{num}] {display}{source}")
        return "\n".join(lines)

    def get_all_docs(self) -> List[Dict]:
        """All registered documents, in citation-number order of insertion."""
        return list(self.citations.values())
# ---------------------------------------------------------------------------
# ResearchAgent
# ---------------------------------------------------------------------------
class ResearchAgent(BaseAgent):
"""Multi-step research agent with parallel execution and budget controls.
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
"""
    def __init__(
        self,
        retriever_config: Optional[Dict] = None,
        max_steps: int = DEFAULT_MAX_STEPS,
        max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
        timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
        token_budget: int = DEFAULT_TOKEN_BUDGET,
        parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
        *args,
        **kwargs,
    ):
        """Configure research limits and citation tracking, then defer to BaseAgent.

        Args:
            retriever_config: Settings passed to the internal search tool.
            max_steps: Hard cap on planned research steps.
            max_sub_iterations: Cap on LLM/tool rounds within one step.
            timeout_seconds: Wall-clock budget for the entire run.
            token_budget: Approximate LLM token budget for the entire run.
            parallel_workers: Worker count for parallel step execution.
        """
        super().__init__(*args, **kwargs)
        self.retriever_config = retriever_config or {}
        self.max_steps = max_steps
        self.max_sub_iterations = max_sub_iterations
        self.timeout_seconds = timeout_seconds
        self.token_budget = token_budget
        self.parallel_workers = parallel_workers
        # Cross-step citation registry (deduplicated source numbering).
        self.citations = CitationManager()
        # Runtime accounting; initialized for real at _gen_inner start.
        self._start_time: float = 0
        self._tokens_used: int = 0
        self._last_token_snapshot: int = 0
# ------------------------------------------------------------------
# Budget & timeout helpers
# ------------------------------------------------------------------
def _is_timed_out(self) -> bool:
return (time.monotonic() - self._start_time) >= self.timeout_seconds
def _elapsed(self) -> float:
return round(time.monotonic() - self._start_time, 1)
def _track_tokens(self, count: int):
self._tokens_used += count
def _budget_remaining(self) -> int:
return max(self.token_budget - self._tokens_used, 0)
def _is_over_budget(self) -> bool:
return self._tokens_used >= self.token_budget
def _snapshot_llm_tokens(self) -> int:
"""Read current token usage from LLM and return delta since last snapshot."""
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
delta = current - self._last_token_snapshot
self._last_token_snapshot = current
return delta
# ------------------------------------------------------------------
# Main orchestration
# ------------------------------------------------------------------
    def _gen_inner(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Orchestrate the full research run for *query*.

        Phases: optional clarification (phase 0), planning with adaptive
        depth (phase 1), per-step research under time/token budgets
        (phase 2), then a streamed synthesis (phase 3). All output is
        yielded as event dicts for the streaming layer.
        """
        self._start_time = time.monotonic()
        tools_dict = self._setup_tools()

        # Phase 0: Clarification (skip if user is responding to a prior clarification)
        if not self._is_follow_up():
            clarification = self._clarification_phase(query)
            if clarification:
                # The metadata flag is how _is_follow_up recognizes the
                # user's next turn as a clarification answer.
                yield {"metadata": {"is_clarification": True}}
                yield {"answer": clarification}
                yield {"sources": []}
                yield {"tool_calls": []}
                log_context.stacks.append(
                    {"component": "agent", "data": {"clarification": True}}
                )
                return

        # Phase 1: Planning (with adaptive depth)
        yield {"type": "research_progress", "data": {"status": "planning"}}
        plan, complexity = self._planning_phase(query)
        if not plan:
            logger.warning("ResearchAgent: Planning produced no steps, falling back")
            plan = [{"query": query, "rationale": "Direct investigation"}]
            complexity = "simple"
        yield {
            "type": "research_plan",
            "data": {"steps": plan, "complexity": complexity},
        }

        # Phase 2: Research each step (yields progress events in real-time)
        intermediate_reports = []
        for i, step in enumerate(plan):
            step_num = i + 1
            step_query = step.get("query", query)
            # Budget guards: stop early rather than overrun time or tokens;
            # whatever was gathered so far still gets synthesized.
            if self._is_timed_out():
                logger.warning(
                    f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
                    f"({self._elapsed()}s)"
                )
                break
            if self._is_over_budget():
                logger.warning(
                    f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
                )
                break
            yield {
                "type": "research_progress",
                "data": {
                    "step": step_num,
                    "total": len(plan),
                    "query": step_query,
                    "status": "researching",
                },
            }
            report = self._research_step(step_query, tools_dict)
            intermediate_reports.append({"step": step, "content": report})
            yield {
                "type": "research_progress",
                "data": {
                    "step": step_num,
                    "total": len(plan),
                    "query": step_query,
                    "status": "complete",
                },
            }

        # Phase 3: Synthesis (streaming)
        if self._is_timed_out():
            logger.warning(
                f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
                f"synthesizing with {len(intermediate_reports)} reports"
            )
        yield {
            "type": "research_progress",
            "data": {
                "status": "synthesizing",
                "elapsed_seconds": self._elapsed(),
                "tokens_used": self._tokens_used,
            },
        }
        yield from self._synthesis_phase(
            query, plan, intermediate_reports, tools_dict, log_context
        )

        # Sources and tool calls
        self.retrieved_docs = self.citations.get_all_docs()
        yield {"sources": self.retrieved_docs}
        yield {"tool_calls": self._get_truncated_tool_calls()}
        logger.info(
            f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
            f"{self._elapsed()}s, ~{self._tokens_used} tokens"
        )
        log_context.stacks.append(
            {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
        )
# ------------------------------------------------------------------
# Tool setup
# ------------------------------------------------------------------
def _setup_tools(self) -> Dict:
"""Build tools_dict with user tools + internal search + think."""
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
think_entry = dict(THINK_TOOL_ENTRY)
think_entry["config"] = {}
tools_dict[THINK_TOOL_ID] = think_entry
self._prepare_tools(tools_dict)
return tools_dict
# ------------------------------------------------------------------
# Phase 0: Clarification
# ------------------------------------------------------------------
def _is_follow_up(self) -> bool:
"""Check if the user is responding to a prior clarification.
Uses the metadata flag stored in the conversation DB — no string matching.
Only skip clarification when the last query was explicitly flagged
as a clarification by this agent.
"""
if not self.chat_history:
return False
last = self.chat_history[-1]
meta = last.get("metadata", {})
return bool(meta.get("is_clarification"))
def _clarification_phase(self, question: str) -> Optional[str]:
"""Ask the LLM whether the question needs clarification.
Returns formatted clarification text if needed, or None to proceed.
Uses response_format to force valid JSON output.
"""
messages = [
{"role": "system", "content": CLARIFICATION_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent clarification response: {text[:300]}")
data = self._parse_clarification_json(text)
if not data or not data.get("needs_clarification"):
return None
questions = data.get("questions", [])
if not questions:
return None
# Format as a friendly response
lines = [
"Before I begin researching, I'd like to clarify a few things:\n"
]
for i, q in enumerate(questions[:3], 1):
lines.append(f"{i}. {q}")
lines.append(
"\nPlease provide these details and I'll start the research."
)
return "\n".join(lines)
except Exception as e:
logger.error(f"Clarification phase failed: {e}", exc_info=True)
return None # proceed with research on failure
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
"""Parse clarification JSON from LLM response."""
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try extracting from code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
return json.loads(text[start:end].strip())
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
return json.loads(text[i : j + 1])
except json.JSONDecodeError:
continue
break
return None
# ------------------------------------------------------------------
# Phase 1: Planning (with adaptive depth)
# ------------------------------------------------------------------
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
"""Decompose the question into research steps via LLM.
Returns (steps, complexity) where complexity is simple/moderate/complex.
"""
messages = [
{"role": "system", "content": PLANNING_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent planning LLM response: {text[:500]}")
plan_data = self._parse_plan_json(text)
if isinstance(plan_data, dict):
complexity = plan_data.get("complexity", "moderate")
steps = plan_data.get("steps", [])
else:
complexity = "moderate"
steps = plan_data
# Adaptive depth: cap steps based on assessed complexity
cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
cap = min(cap, self.max_steps)
steps = steps[:cap]
logger.info(
f"ResearchAgent plan: complexity={complexity}, "
f"steps={len(steps)} (cap={cap})"
)
return steps, complexity
except Exception as e:
logger.error(f"Planning phase failed: {e}", exc_info=True)
return (
[{"query": question, "rationale": "Direct investigation (planning failed)"}],
"simple",
)
def _parse_plan_json(self, text: str):
"""Extract JSON plan from LLM response. Returns dict or list."""
# Try direct parse
try:
data = json.loads(text)
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except json.JSONDecodeError:
pass
# Try extracting from markdown code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
data = json.loads(text[start:end].strip())
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object in text
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
data = json.loads(text[i : j + 1])
if isinstance(data, dict) and "steps" in data:
return data
except json.JSONDecodeError:
continue
break
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
return []
# ------------------------------------------------------------------
# Phase 2: Research step (core loop)
# ------------------------------------------------------------------
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
"""Run a focused research loop for one sub-question (sequential path)."""
report = self._research_step_with_executor(
step_query, tools_dict, self.tool_executor
)
self._collect_step_sources()
return report
    def _research_step_with_executor(
        self, step_query: str, tools_dict: Dict, executor: ToolExecutor
    ) -> str:
        """Core research loop. Works with any ToolExecutor instance.

        Alternates LLM calls and tool execution for up to
        ``self.max_sub_iterations`` rounds, stopping early on timeout or
        token-budget exhaustion, and finally asks the model to summarize
        whatever was gathered.
        """
        system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": step_query},
        ]
        # Tracks whether the previous search call came back empty so a
        # refinement hint can be injected on consecutive misses.
        last_search_empty = False
        for iteration in range(self.max_sub_iterations):
            # Check timeout and budget
            if self._is_timed_out():
                logger.info(
                    f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
                )
                break
            if self._is_over_budget():
                logger.info(
                    f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
                )
                break
            try:
                response = self.llm.gen(
                    model=self.model_id,
                    messages=messages,
                    tools=self.tools if self.tools else None,
                )
                self._track_tokens(self._snapshot_llm_tokens())
            except Exception as e:
                logger.error(
                    f"Research step LLM call failed (iteration {iteration}): {e}",
                    exc_info=True,
                )
                break
            parsed = self.llm_handler.parse_response(response)
            if not parsed.requires_tool_call:
                # Model answered directly — that text is the step's report.
                return parsed.content or "No findings for this step."
            # Execute tool calls
            messages, last_search_empty = self._execute_step_tools_with_refinement(
                parsed.tool_calls, tools_dict, messages, executor, last_search_empty
            )
        # Max iterations / timeout / budget — ask for summary
        messages.append(
            {
                "role": "user",
                "content": "Please summarize your findings so far based on the information gathered.",
            }
        )
        try:
            response = self.llm.gen(
                model=self.model_id, messages=messages, tools=None
            )
            self._track_tokens(self._snapshot_llm_tokens())
            text = self._extract_text(response)
            return text or "Research step completed."
        except Exception:
            # Best-effort: never let the summary call break the step.
            return "Research step completed."
    def _execute_step_tools_with_refinement(
        self,
        tool_calls,
        tools_dict: Dict,
        messages: List[Dict],
        executor: ToolExecutor,
        last_search_empty: bool,
    ) -> tuple[List[Dict], bool]:
        """Execute tool calls with query refinement on empty results.

        Returns (updated_messages, was_last_search_empty).
        """
        search_returned_empty = False
        for call in tool_calls:
            gen = executor.execute(
                tools_dict, call, self.llm.__class__.__name__
            )
            result = None
            call_id = None
            # Drain the executor generator; the final (result, call_id)
            # pair arrives via StopIteration.value.
            while True:
                try:
                    event = next(gen)
                    # Log tool_call status events instead of discarding them
                    if isinstance(event, dict) and event.get("type") == "tool_call":
                        logger.debug(
                            "Tool %s status: %s",
                            event.get("data", {}).get("action_name", ""),
                            event.get("data", {}).get("status", ""),
                        )
                except StopIteration as e:
                    result, call_id = e.value
                    break
            # Detect empty search results for refinement
            is_search = "search" in (call.name or "").lower()
            result_str = str(result) if result else ""
            if is_search and "No documents found" in result_str:
                search_returned_empty = True
                if last_search_empty:
                    # Two consecutive empty searches — inject refinement hint
                    result_str += (
                        "\n\nHint: Previous search also returned no results. "
                        "Try a very different query with different keywords, "
                        "or broaden your search terms."
                    )
                result = result_str
            import json as _json

            # OpenAI tool-calling shape requires string arguments.
            args_str = (
                _json.dumps(call.arguments)
                if isinstance(call.arguments, dict)
                else call.arguments
            )
            # Assistant tool_call entry followed by the matching tool reply,
            # so the transcript stays well-formed for the next LLM turn.
            messages.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": call_id,
                    "type": "function",
                    "function": {"name": call.name, "arguments": args_str},
                }],
            })
            tool_message = self.llm_handler.create_tool_message(call, result)
            messages.append(tool_message)
        return messages, search_returned_empty
def _collect_step_sources(self):
"""Collect sources from InternalSearchTool and register with CitationManager."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs"):
for doc in tool.retrieved_docs:
self.citations.add(doc)
# ------------------------------------------------------------------
# Phase 3: Synthesis
# ------------------------------------------------------------------
def _synthesis_phase(
    self,
    question: str,
    plan: List[Dict],
    intermediate_reports: List[Dict],
    tools_dict: Dict,
    log_context: LogContext,
) -> Generator[Dict, None, None]:
    """Compile all findings into a final cited report (streaming)."""
    # Summarize the executed plan, one numbered line per step.
    plan_summary = "\n".join(
        f"{i}. {step.get('query', 'Unknown')}{step.get('rationale', '')}"
        for i, step in enumerate(plan, 1)
    )
    # Concatenate the intermediate step reports with visible separators.
    findings = "\n\n".join(
        f"--- Step {i}: {report['step'].get('query', 'Unknown')} ---\n{report['content']}"
        for i, report in enumerate(intermediate_reports, 1)
    )
    references = self.citations.format_references()
    # Fill the synthesis prompt template placeholder by placeholder.
    prompt = SYNTHESIS_PROMPT
    for placeholder, value in (
        ("{question}", question),
        ("{plan_summary}", plan_summary),
        ("{findings}", findings),
        ("{references}", references),
    ):
        prompt = prompt.replace(placeholder, value)
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": f"Please write the research report for: {question}"},
    ]
    llm_response = self.llm.gen_stream(
        model=self.model_id, messages=messages, tools=None
    )
    if log_context:
        from application.logging import build_stack_data
        log_context.stacks.append(
            {"component": "synthesis_llm", "data": build_stack_data(self.llm)}
        )
    yield from self._handle_response(
        llm_response, tools_dict, messages, log_context
    )
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _extract_text(self, response) -> str:
"""Extract text content from a non-streaming LLM response."""
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
return response.message.content or ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and hasattr(choice.message, "content"):
return choice.message.content or ""
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text or ""
return str(response) if response else ""

View File

@@ -1,477 +0,0 @@
import logging
import uuid
from collections import Counter
from typing import Dict, List, Optional, Tuple
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
class ToolExecutor:
    """Handles tool discovery, preparation, and execution.
    Extracted from BaseAgent to separate concerns and enable tool caching.
    """

    def __init__(
        self,
        user_api_key: Optional[str] = None,
        user: Optional[str] = None,
        decoded_token: Optional[Dict] = None,
    ):
        # Context used to decide whose tools to load (agent API key vs. user).
        self.user_api_key = user_api_key
        self.user = user
        self.decoded_token = decoded_token
        # History of executed calls; see get_truncated_tool_calls().
        self.tool_calls: List[Dict] = []
        # Instantiated-tool cache keyed "name:tool_id:user" (api_tool excluded).
        self._loaded_tools: Dict[str, object] = {}
        self.conversation_id: Optional[str] = None
        # Optional client-provided function definitions (see merge_client_tools).
        self.client_tools: Optional[List[Dict]] = None
        # LLM-visible action name -> (tool_id, action_name), and its inverse.
        self._name_to_tool: Dict[str, Tuple[str, str]] = {}
        self._tool_to_name: Dict[Tuple[str, str], str] = {}

    def get_tools(self) -> Dict[str, Dict]:
        """Load tool configs from DB based on user context.
        If *client_tools* have been set on this executor, they are
        automatically merged into the returned dict.
        """
        if self.user_api_key:
            tools = self._get_tools_by_api_key(self.user_api_key)
        else:
            tools = self._get_user_tools(self.user or "local")
        if self.client_tools:
            self.merge_client_tools(tools, self.client_tools)
        return tools

    def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
        """Load the tools attached to the agent identified by *api_key*.

        Returns a dict keyed by each tool's MongoDB ObjectId string;
        empty dict when the agent is unknown or has no tools.
        """
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        agents_collection = db["agents"]
        tools_collection = db["user_tools"]
        agent_data = agents_collection.find_one({"key": api_key})
        tool_ids = agent_data.get("tools", []) if agent_data else []
        tools = (
            tools_collection.find(
                {"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
            )
            if tool_ids
            else []
        )
        tools = list(tools)
        return {str(tool["_id"]): tool for tool in tools} if tools else {}

    def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
        """Load all enabled (status=True) tools owned by *user*.

        NOTE: keys here are stringified enumerate indexes, not ObjectIds as
        in _get_tools_by_api_key.
        """
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        user_tools_collection = db["user_tools"]
        user_tools = user_tools_collection.find({"user": user, "status": True})
        user_tools = list(user_tools)
        return {str(i): tool for i, tool in enumerate(user_tools)}

    def merge_client_tools(
        self, tools_dict: Dict, client_tools: List[Dict]
    ) -> Dict:
        """Merge client-provided tool definitions into tools_dict.
        Client tools use the standard function-calling format::
            [{"type": "function", "function": {"name": "get_weather",
            "description": "...", "parameters": {...}}}]
        They are stored in *tools_dict* with ``client_side: True`` so that
        :meth:`check_pause` returns a pause signal instead of trying to
        execute them server-side.
        Args:
            tools_dict: The mutable server tools dict (will be modified in place).
            client_tools: List of tool definitions in function-calling format.
        Returns:
            The updated *tools_dict* (same reference, for convenience).
        """
        for i, ct in enumerate(client_tools):
            func = ct.get("function", ct)  # tolerate bare {"name":..} too
            name = func.get("name", f"clienttool{i}")
            # Synthetic IDs ("ct0", "ct1", ...) keep client tools distinct
            # from DB-backed tool ids.
            tool_id = f"ct{i}"
            tools_dict[tool_id] = {
                "name": name,
                "client_side": True,
                "actions": [
                    {
                        "name": name,
                        "description": func.get("description", ""),
                        "active": True,
                        "parameters": func.get("parameters", {}),
                    }
                ],
            }
        return tools_dict

    def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
        """Convert tool configs to LLM function schemas.
        Action names are kept clean for the LLM:
        - Unique action names appear as-is (e.g. ``get_weather``).
        - Duplicate action names get numbered suffixes (e.g. ``search_1``,
          ``search_2``).
        A reverse mapping is stored in ``_name_to_tool`` so that tool calls
        can be routed back to the correct ``(tool_id, action_name)`` without
        brittle string splitting.
        """
        # Pass 1: collect entries and count action name occurrences
        entries: List[Tuple[str, str, Dict, bool]] = []  # (tool_id, action_name, action, is_client)
        name_counts: Counter = Counter()
        for tool_id, tool in tools_dict.items():
            is_api = tool["name"] == "api_tool"
            is_client = tool.get("client_side", False)
            # api_tool stores actions under config["actions"]; other tools
            # store them directly under "actions".
            if is_api and "actions" not in tool.get("config", {}):
                continue
            if not is_api and "actions" not in tool:
                continue
            actions = (
                tool["config"]["actions"].values()
                if is_api
                else tool["actions"]
            )
            for action in actions:
                if not action.get("active", True):
                    continue
                entries.append((tool_id, action["name"], action, is_client))
                name_counts[action["name"]] += 1
        # Pass 2: assign LLM-visible names and build mappings
        self._name_to_tool = {}
        self._tool_to_name = {}
        collision_counters: Dict[str, int] = {}
        all_llm_names: set = set()
        result = []
        for tool_id, action_name, action, is_client in entries:
            if name_counts[action_name] == 1:
                llm_name = action_name
            else:
                counter = collision_counters.get(action_name, 1)
                candidate = f"{action_name}_{counter}"
                # Skip if candidate collides with a unique action name
                while candidate in all_llm_names or (
                    candidate in name_counts and name_counts[candidate] == 1
                ):
                    counter += 1
                    candidate = f"{action_name}_{counter}"
                collision_counters[action_name] = counter + 1
                llm_name = candidate
            all_llm_names.add(llm_name)
            self._name_to_tool[llm_name] = (tool_id, action_name)
            self._tool_to_name[(tool_id, action_name)] = llm_name
            if is_client:
                # Client tools already carry a ready-made JSON schema.
                params = action.get("parameters", {})
            else:
                params = self._build_tool_parameters(action)
            result.append({
                "type": "function",
                "function": {
                    "name": llm_name,
                    "description": action.get("description", ""),
                    "parameters": params,
                },
            })
        return result

    def _build_tool_parameters(self, action: Dict) -> Dict:
        """Build a JSON-schema ``parameters`` object from an action config.

        Only properties with filled_by_llm (default True) are exposed to
        the LLM; config-only keys (filled_by_llm, value, required) are
        stripped from each property, and required names are collected.
        """
        params = {"type": "object", "properties": {}, "required": []}
        for param_type in ["query_params", "headers", "body", "parameters"]:
            if param_type in action and action[param_type].get("properties"):
                for k, v in action[param_type]["properties"].items():
                    if v.get("filled_by_llm", True):
                        params["properties"][k] = {
                            key: value
                            for key, value in v.items()
                            if key not in ("filled_by_llm", "value", "required")
                        }
                        if v.get("required", False):
                            params["required"].append(k)
        return params

    def check_pause(
        self, tools_dict: Dict, call, llm_class_name: str
    ) -> Optional[Dict]:
        """Check if a tool call requires pausing for approval or client execution.
        Returns a dict describing the pending action if pause is needed, None otherwise.
        """
        parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
        tool_id, action_name, call_args = parser.parse_args(call)
        call_id = getattr(call, "id", None) or str(uuid.uuid4())
        llm_name = getattr(call, "name", "")
        if tool_id is None or action_name is None or tool_id not in tools_dict:
            return None  # Will be handled as error by execute()
        tool_data = tools_dict[tool_id]
        # Client-side tools
        if tool_data.get("client_side"):
            return {
                "call_id": call_id,
                "name": llm_name,
                "tool_name": tool_data.get("name", "unknown"),
                "tool_id": tool_id,
                "action_name": action_name,
                "llm_name": llm_name,
                "arguments": call_args if isinstance(call_args, dict) else {},
                "pause_type": "requires_client_execution",
                "thought_signature": getattr(call, "thought_signature", None),
            }
        # Approval required
        if tool_data["name"] == "api_tool":
            action_data = tool_data.get("config", {}).get("actions", {}).get(
                action_name, {}
            )
        else:
            action_data = next(
                (a for a in tool_data.get("actions", []) if a["name"] == action_name),
                {},
            )
        if action_data.get("require_approval"):
            return {
                "call_id": call_id,
                "name": llm_name,
                "tool_name": tool_data.get("name", "unknown"),
                "tool_id": tool_id,
                "action_name": action_name,
                "llm_name": llm_name,
                "arguments": call_args if isinstance(call_args, dict) else {},
                "pause_type": "awaiting_approval",
                "thought_signature": getattr(call, "thought_signature", None),
            }
        return None

    def execute(self, tools_dict: Dict, call, llm_class_name: str):
        """Execute a tool call. Yields status events, returns (result, call_id)."""
        # NOTE: this is a generator whose return value is delivered to the
        # caller via StopIteration.value.
        parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
        tool_id, action_name, call_args = parser.parse_args(call)
        llm_name = getattr(call, "name", "unknown")
        call_id = getattr(call, "id", None) or str(uuid.uuid4())
        if tool_id is None or action_name is None:
            error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
            logger.error(error_message)
            tool_call_data = {
                "tool_name": "unknown",
                "call_id": call_id,
                "action_name": llm_name,
                "arguments": call_args or {},
                "result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
            }
            yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
            self.tool_calls.append(tool_call_data)
            return "Failed to parse tool call.", call_id
        if tool_id not in tools_dict:
            error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
            logger.error(error_message)
            tool_call_data = {
                "tool_name": "unknown",
                "call_id": call_id,
                "action_name": llm_name,
                "arguments": call_args,
                "result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
            }
            yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
            self.tool_calls.append(tool_call_data)
            return f"Tool with ID {tool_id} not found.", call_id
        tool_call_data = {
            "tool_name": tools_dict[tool_id]["name"],
            "call_id": call_id,
            "action_name": llm_name,
            "arguments": call_args,
        }
        yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
        tool_data = tools_dict[tool_id]
        action_data = (
            tool_data["config"]["actions"][action_name]
            if tool_data["name"] == "api_tool"
            else next(
                action
                for action in tool_data["actions"]
                if action["name"] == action_name
            )
        )
        # Split arguments into their transport buckets; start with any
        # pre-configured defaults the LLM did not supply.
        query_params, headers, body, parameters = {}, {}, {}, {}
        param_types = {
            "query_params": query_params,
            "headers": headers,
            "body": body,
            "parameters": parameters,
        }
        for param_type, target_dict in param_types.items():
            if param_type in action_data and action_data[param_type].get("properties"):
                for param, details in action_data[param_type]["properties"].items():
                    if (
                        param not in call_args
                        and "value" in details
                        and details["value"]
                    ):
                        target_dict[param] = details["value"]
        # Route each LLM-supplied argument to the bucket that declares it.
        for param, value in call_args.items():
            for param_type, target_dict in param_types.items():
                if param_type in action_data and param in action_data[param_type].get(
                    "properties", {}
                ):
                    target_dict[param] = value
        # Load tool (with caching)
        tool = self._get_or_load_tool(
            tool_data, tool_id, action_name,
            headers=headers, query_params=query_params,
        )
        resolved_arguments = (
            {"query_params": query_params, "headers": headers, "body": body}
            if tool_data["name"] == "api_tool"
            else parameters
        )
        if tool_data["name"] == "api_tool":
            logger.debug(
                f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
            )
            result = tool.execute_action(action_name, **body)
        else:
            logger.debug(f"Executing tool: {action_name} with args: {call_args}")
            result = tool.execute_action(action_name, **parameters)
        # Optionally extract an artifact id from non-api tools that expose
        # a get_artifact_id hook.
        get_artifact_id = (
            getattr(tool, "get_artifact_id", None)
            if tool_data["name"] != "api_tool"
            else None
        )
        artifact_id = None
        if callable(get_artifact_id):
            try:
                artifact_id = get_artifact_id(action_name, **parameters)
            except Exception:
                logger.exception(
                    "Failed to extract artifact_id from tool %s for action %s",
                    tool_data["name"],
                    action_name,
                )
        artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
        if artifact_id:
            tool_call_data["artifact_id"] = artifact_id
        # Keep the full result in the record; stream only a 50-char preview.
        result_full = str(result)
        tool_call_data["resolved_arguments"] = resolved_arguments
        tool_call_data["result_full"] = result_full
        tool_call_data["result"] = (
            f"{result_full[:50]}..." if len(result_full) > 50 else result_full
        )
        stream_tool_call_data = {
            key: value
            for key, value in tool_call_data.items()
            if key not in {"result_full", "resolved_arguments"}
        }
        yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
        self.tool_calls.append(tool_call_data)
        return result, call_id

    def _get_or_load_tool(
        self, tool_data: Dict, tool_id: str, action_name: str,
        headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
    ):
        """Load a tool, using cache when possible."""
        cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
        if cache_key in self._loaded_tools:
            return self._loaded_tools[cache_key]
        tm = ToolManager(config={})
        if tool_data["name"] == "api_tool":
            # api_tool config is built per-action from the action definition.
            action_config = tool_data["config"]["actions"][action_name]
            tool_config = {
                "url": action_config["url"],
                "method": action_config["method"],
                "headers": headers or {},
                "query_params": query_params or {},
            }
            if "body_content_type" in action_config:
                tool_config["body_content_type"] = action_config.get(
                    "body_content_type", "application/json"
                )
                tool_config["body_encoding_rules"] = action_config.get(
                    "body_encoding_rules", {}
                )
        else:
            tool_config = tool_data["config"].copy() if tool_data["config"] else {}
            # Decrypt stored credentials into the working config; the
            # encrypted blob itself is not passed to the tool.
            if tool_config.get("encrypted_credentials") and self.user:
                decrypted = decrypt_credentials(
                    tool_config["encrypted_credentials"], self.user
                )
                tool_config.update(decrypted)
                tool_config["auth_credentials"] = decrypted
                tool_config.pop("encrypted_credentials", None)
        tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
        if self.conversation_id:
            tool_config["conversation_id"] = self.conversation_id
        if tool_data["name"] == "mcp_tool":
            tool_config["query_mode"] = True
        tool = tm.load_tool(
            tool_data["name"],
            tool_config=tool_config,
            user_id=self.user,
        )
        # Don't cache api_tool since config varies by action
        if tool_data["name"] != "api_tool":
            self._loaded_tools[cache_key] = tool
        return tool

    def get_truncated_tool_calls(self) -> List[Dict]:
        """Return a compact view of recorded calls (results cut to 50 chars)."""
        return [
            {
                "tool_name": tool_call.get("tool_name"),
                "call_id": tool_call.get("call_id"),
                "action_name": tool_call.get("action_name"),
                "arguments": tool_call.get("arguments"),
                "artifact_id": tool_call.get("artifact_id"),
                "result": (
                    f"{str(tool_call['result'])[:50]}..."
                    if len(str(tool_call["result"])) > 50
                    else tool_call["result"]
                ),
                "status": "completed",
            }
            for tool_call in self.tool_calls
        ]

View File

@@ -2,8 +2,6 @@ from abc import ABC, abstractmethod
class Tool(ABC):
    """Abstract base class for all agent tools."""

    # Subclasses may set this to True to mark application-provided tools
    # (e.g. InternalSearchTool sets internal = True).
    internal: bool = False

    @abstractmethod
    def execute_action(self, action_name: str, **kwargs):
        """Run the named action with the given keyword arguments."""
        pass

View File

@@ -1,11 +1,6 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class BraveSearchTool(Tool):
"""
@@ -46,7 +41,7 @@ class BraveSearchTool(Tool):
"""
Performs a web search using the Brave Search API.
"""
logger.debug("Performing Brave web search for: %s", query)
print(f"Performing Brave web search for: {query}")
url = f"{self.base_url}/web/search"
@@ -99,7 +94,7 @@ class BraveSearchTool(Tool):
"""
Performs an image search using the Brave Search API.
"""
logger.debug("Performing Brave image search for: %s", query)
print(f"Performing Brave image search for: {query}")
url = f"{self.base_url}/images/search"
@@ -182,10 +177,6 @@ class BraveSearchTool(Tool):
return {
"token": {
"type": "string",
"label": "API Key",
"description": "Brave Search API key for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -1,14 +1,5 @@
import logging
import time
from typing import Any, Dict, Optional
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
RETRY_DELAY = 2.0
DEFAULT_TIMEOUT = 15
from duckduckgo_search import DDGS
class DuckDuckGoSearchTool(Tool):
@@ -19,123 +10,71 @@ class DuckDuckGoSearchTool(Tool):
def __init__(self, config):
self.config = config
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
def _get_ddgs_client(self):
from ddgs import DDGS
return DDGS(timeout=self.timeout)
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
results = operation()
return {
"status_code": 200,
"results": list(results) if results else [],
"message": f"{operation_name} completed successfully.",
}
except Exception as e:
last_error = e
error_str = str(e).lower()
if "ratelimit" in error_str or "429" in error_str:
if attempt < MAX_RETRIES:
delay = RETRY_DELAY * attempt
logger.warning(
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
)
time.sleep(delay)
continue
logger.error(f"{operation_name} failed: {e}")
break
return {
"status_code": 500,
"results": [],
"message": f"{operation_name} failed: {str(last_error)}",
}
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
"ddg_news_search": self._news_search,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _web_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo web search: {query}")
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
def operation():
client = self._get_ddgs_client()
return client.text(
try:
results = DDGS().text(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
max_results=max_results,
)
return self._execute_with_retry(operation, "Web search")
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
def _image_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo image search: {query}")
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
def operation():
client = self._get_ddgs_client()
return client.images(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 50),
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
)
return self._execute_with_retry(operation, "Image search")
def _news_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo news search: {query}")
def operation():
client = self._get_ddgs_client()
return client.news(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return self._execute_with_retry(operation, "News search")
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
"description": "Perform a web search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
@@ -145,15 +84,7 @@ class DuckDuckGoSearchTool(Tool):
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month), y (year)",
"description": "Number of results to return (default: 5)",
},
},
"required": ["query"],
@@ -161,43 +92,17 @@ class DuckDuckGoSearchTool(Tool):
},
{
"name": "ddg_image_search",
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
"description": "Perform an image search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Image search query",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 50)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_news_search",
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "News search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month)",
"description": "Number of results to return (default: 5, max: 50)",
},
},
"required": ["query"],

View File

@@ -1,438 +0,0 @@
import json
import logging
from typing import Dict, List, Optional
from application.agents.tools.base import Tool
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
logger = logging.getLogger(__name__)
class InternalSearchTool(Tool):
    """Wraps the ClassicRAG retriever as an LLM-callable tool.
    Instead of pre-fetching docs into the prompt, the LLM decides
    when and what to search. Supports multiple searches per session.
    Optional capabilities (enabled when sources have directory_structure):
    - path_filter on search: restrict results to a specific file/folder
    - list_files action: browse the file/folder structure
    """

    internal = True

    def __init__(self, config: Dict):
        self.config = config
        # Every unique doc returned across searches, kept for citation tracking.
        self.retrieved_docs: List[Dict] = []
        # Lazily-created retriever (see _get_retriever).
        self._retriever = None
        # Lazily-loaded directory structure; the flag distinguishes
        # "not loaded yet" from "loaded, but none available".
        self._directory_structure: Optional[Dict] = None
        self._dir_structure_loaded = False

    def _get_retriever(self):
        """Lazily build the retriever from config (defaults to 'classic')."""
        if self._retriever is None:
            self._retriever = RetrieverCreator.create_retriever(
                self.config.get("retriever_name", "classic"),
                source=self.config.get("source", {}),
                chat_history=[],
                prompt="",
                chunks=int(self.config.get("chunks", 2)),
                doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
                model_id=self.config.get("model_id", "docsgpt-local"),
                user_api_key=self.config.get("user_api_key"),
                agent_id=self.config.get("agent_id"),
                llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
                api_key=self.config.get("api_key", settings.API_KEY),
                decoded_token=self.config.get("decoded_token"),
            )
        return self._retriever

    def _get_directory_structure(self) -> Optional[Dict]:
        """Load directory structure from MongoDB for the configured sources."""
        if self._dir_structure_loaded:
            return self._directory_structure
        self._dir_structure_loaded = True
        source = self.config.get("source", {})
        active_docs = source.get("active_docs", [])
        if not active_docs:
            return None
        try:
            from bson.objectid import ObjectId
            from application.core.mongo_db import MongoDB

            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]
            sources_collection = db["sources"]
            if isinstance(active_docs, str):
                active_docs = [active_docs]
            merged_structure = {}
            for doc_id in active_docs:
                try:
                    source_doc = sources_collection.find_one(
                        {"_id": ObjectId(doc_id)}
                    )
                    if not source_doc:
                        continue
                    dir_str = source_doc.get("directory_structure")
                    if dir_str:
                        # Structure may be stored as a JSON string or a dict.
                        if isinstance(dir_str, str):
                            dir_str = json.loads(dir_str)
                        source_name = source_doc.get("name", doc_id)
                        # With multiple sources, namespace each structure
                        # under its source name; a single source is used as-is.
                        if len(active_docs) > 1:
                            merged_structure[source_name] = dir_str
                        else:
                            merged_structure = dir_str
                except Exception as e:
                    logger.debug(f"Could not load dir structure for {doc_id}: {e}")
            self._directory_structure = merged_structure if merged_structure else None
        except Exception as e:
            logger.debug(f"Failed to load directory structures: {e}")
        return self._directory_structure

    def execute_action(self, action_name: str, **kwargs):
        """Dispatch to the search or list_files action."""
        if action_name == "search":
            return self._execute_search(**kwargs)
        elif action_name == "list_files":
            return self._execute_list_files(**kwargs)
        return f"Unknown action: {action_name}"

    def _execute_search(self, **kwargs) -> str:
        """Run a retriever search and format the hits for the LLM.

        Accepts 'query' (required) and optional 'path_filter' kwargs.
        """
        query = kwargs.get("query", "")
        path_filter = kwargs.get("path_filter", "")
        if not query:
            return "Error: 'query' parameter is required."
        try:
            retriever = self._get_retriever()
            docs = retriever.search(query)
        except Exception as e:
            logger.error(f"Internal search failed: {e}", exc_info=True)
            return "Search failed: an internal error occurred."
        if not docs:
            return "No documents found matching your query."
        # Apply path filter if specified
        if path_filter:
            path_lower = path_filter.lower()
            docs = [
                d
                for d in docs
                if path_lower in d.get("source", "").lower()
                or path_lower in d.get("filename", "").lower()
                or path_lower in d.get("title", "").lower()
            ]
            if not docs:
                return f"No documents found matching query '{query}' in path '{path_filter}'."
        # Accumulate for source tracking
        for doc in docs:
            if doc not in self.retrieved_docs:
                self.retrieved_docs.append(doc)
        # Format results for the LLM
        formatted = []
        for i, doc in enumerate(docs, 1):
            title = doc.get("title", "Untitled")
            text = doc.get("text", "")
            source = doc.get("source", "Unknown")
            filename = doc.get("filename", "")
            header = filename or title
            formatted.append(f"[{i}] {header} (source: {source})\n{text}")
        return "\n\n---\n\n".join(formatted)

    def _execute_list_files(self, **kwargs) -> str:
        """Browse the directory structure at an optional 'path' kwarg."""
        path = kwargs.get("path", "")
        dir_structure = self._get_directory_structure()
        if not dir_structure:
            return "No file structure available for the current sources."
        # Navigate to the requested path
        current = dir_structure
        if path:
            for part in path.strip("/").split("/"):
                if not part:
                    continue
                if isinstance(current, dict) and part in current:
                    current = current[part]
                else:
                    return f"Path '{path}' not found in the file structure."
        # Format the structure for the LLM
        return self._format_structure(current, path or "/")

    def _format_structure(self, node: Dict, current_path: str) -> str:
        """Render one directory level as a folders/files listing."""
        if not isinstance(node, dict):
            return f"'{current_path}' is a file, not a directory."
        lines = [f"File structure at '{current_path}':\n"]
        folders = []
        files = []
        for name, value in sorted(node.items()):
            if isinstance(value, dict):
                # Check if it's a file metadata dict or a folder
                if "type" in value or "size_bytes" in value or "token_count" in value:
                    # It's a file with metadata
                    size = value.get("token_count", "")
                    ftype = value.get("type", "")
                    info_parts = []
                    if ftype:
                        info_parts.append(ftype)
                    if size:
                        info_parts.append(f"{size} tokens")
                    info = f" ({', '.join(info_parts)})" if info_parts else ""
                    files.append(f" {name}{info}")
                else:
                    # It's a folder
                    count = self._count_files(value)
                    folders.append(f" {name}/ ({count} items)")
            else:
                files.append(f" {name}")
        if folders:
            lines.append("Folders:")
            lines.extend(folders)
        if files:
            lines.append("Files:")
            lines.extend(files)
        if not folders and not files:
            lines.append(" (empty)")
        return "\n".join(lines)

    def _count_files(self, node: Dict) -> int:
        """Recursively count files under a directory-structure node."""
        count = 0
        for value in node.values():
            if isinstance(value, dict):
                # Metadata keys mark a file; otherwise recurse into the folder.
                if "type" in value or "size_bytes" in value or "token_count" in value:
                    count += 1
                else:
                    count += self._count_files(value)
            else:
                count += 1
        return count

    def get_actions_metadata(self):
        """Describe the available actions; structure-dependent extras are
        only advertised when has_directory_structure is set in config."""
        actions = [
            {
                "name": "search",
                "description": (
                    "Search the user's uploaded documents and knowledge base. "
                    "Use this to find relevant information before answering questions. "
                    "You can call this multiple times with different queries."
                ),
                "parameters": {
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query. Be specific and focused.",
                            "filled_by_llm": True,
                            "required": True,
                        },
                    }
                },
            }
        ]
        # Add path_filter and list_files only if directory structure exists
        has_structure = self.config.get("has_directory_structure", False)
        if has_structure:
            actions[0]["parameters"]["properties"]["path_filter"] = {
                "type": "string",
                "description": (
                    "Optional: filter results to a specific file or folder path. "
                    "Use list_files first to see available paths."
                ),
                "filled_by_llm": True,
                "required": False,
            }
            actions.append(
                {
                    "name": "list_files",
                    "description": (
                        "Browse the file and folder structure of the knowledge base. "
                        "Use this to see what files are available before searching. "
                        "Optionally provide a path to browse a specific folder."
                    ),
                    "parameters": {
                        "properties": {
                            "path": {
                                "type": "string",
                                "description": "Optional: folder path to browse. Leave empty for root.",
                                "filled_by_llm": True,
                                "required": False,
                            }
                        }
                    },
                }
            )
        return actions

    def get_config_requirements(self):
        """No user-facing configuration is required for this tool."""
        return {}
# Constants for building synthetic tools_dict entries
# Key under which the internal search tool is registered in tools_dict.
INTERNAL_TOOL_ID = "internal"
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
    """Build the tools_dict entry for InternalSearchTool.

    When *has_directory_structure* is True, the search action gains a
    ``path_filter`` parameter and a ``list_files`` action is added.
    """
    search_action = {
        "name": "search",
        "description": (
            "Search the user's uploaded documents and knowledge base. "
            "Use this to find relevant information before answering questions. "
            "You can call this multiple times with different queries."
        ),
        "active": True,
        "parameters": {
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query. Be specific and focused.",
                    "filled_by_llm": True,
                    "required": True,
                }
            }
        },
    }
    actions = [search_action]
    if has_directory_structure:
        search_action["parameters"]["properties"]["path_filter"] = {
            "type": "string",
            "description": (
                "Optional: filter results to a specific file or folder path. "
                "Use list_files first to see available paths."
            ),
            "filled_by_llm": True,
            "required": False,
        }
        list_files_action = {
            "name": "list_files",
            "description": (
                "Browse the file and folder structure of the knowledge base. "
                "Use this to see what files are available before searching. "
                "Optionally provide a path to browse a specific folder."
            ),
            "active": True,
            "parameters": {
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Optional: folder path to browse. Leave empty for root.",
                        "filled_by_llm": True,
                        "required": False,
                    }
                }
            },
        }
        actions.append(list_files_action)
    return {"name": "internal_search", "actions": actions}
# Keep backward compat
# Legacy module-level entry: no directory-structure extras.
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
def sources_have_directory_structure(source: Dict) -> bool:
    """Check if any of the active sources have directory_structure in MongoDB.

    Args:
        source: Retriever source config; its "active_docs" value is a source
            id string or a list of source ids from the "sources" collection.

    Returns:
        True if at least one referenced source document has a truthy
        ``directory_structure`` field, False otherwise. Any database or
        import error is swallowed (best-effort check) and yields False.
    """
    active_docs = source.get("active_docs", [])
    if not active_docs:
        return False
    if isinstance(active_docs, str):
        active_docs = [active_docs]
    try:
        from bson.objectid import ObjectId

        from application.core.mongo_db import MongoDB

        # Drop ids that are not valid ObjectIds instead of failing the check.
        object_ids = []
        for doc_id in active_docs:
            try:
                object_ids.append(ObjectId(doc_id))
            except Exception:
                continue
        if not object_ids:
            return False
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        sources_collection = db["sources"]
        # One batched $in query instead of a find_one round-trip per id.
        cursor = sources_collection.find(
            {"_id": {"$in": object_ids}},
            {"directory_structure": 1},
        )
        for source_doc in cursor:
            if source_doc.get("directory_structure"):
                return True
    except Exception as e:
        logger.debug(f"Could not check directory structure: {e}")
    return False
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
    """Add the internal search tool to tools_dict if sources are configured.

    Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
    Mutates tools_dict in place; does nothing when retriever_config is
    missing/empty or has no active sources.
    """
    # Guard BEFORE dereferencing: the previous version called
    # retriever_config.get(...) first, so a None config raised AttributeError
    # instead of being treated as "no sources configured".
    if not retriever_config:
        return
    source = retriever_config.get("source", {})
    if not source.get("active_docs"):
        return
    has_dir = sources_have_directory_structure(source)
    internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
    internal_entry["config"] = build_internal_tool_config(
        **retriever_config,
        has_directory_structure=has_dir,
    )
    tools_dict[INTERNAL_TOOL_ID] = internal_entry
def build_internal_tool_config(
    source: Dict,
    retriever_name: str = "classic",
    chunks: int = 2,
    doc_token_limit: int = 50000,
    model_id: str = "docsgpt-local",
    user_api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    llm_name: Optional[str] = None,
    api_key: Optional[str] = None,
    decoded_token: Optional[Dict] = None,
    has_directory_structure: bool = False,
) -> Dict:
    """Build the config dict for InternalSearchTool.

    Args:
        source: Retriever source config (e.g. contains "active_docs").
        retriever_name: Retriever implementation to use.
        chunks: Number of chunks to retrieve per search.
        doc_token_limit: Token budget for retrieved document content.
        model_id: Model identifier passed through to the tool.
        user_api_key: Optional per-user API key.
        agent_id: Optional id of the owning agent.
        llm_name: LLM provider; falls back to settings.LLM_PROVIDER.
        api_key: Provider API key; falls back to settings.API_KEY.
        decoded_token: Optional decoded auth token of the requesting user.
        has_directory_structure: Whether sources expose a directory tree.

    Returns:
        A plain dict consumed by InternalSearchTool as its config.
    """
    # NOTE: llm_name/api_key were previously annotated as plain `str` despite
    # defaulting to None; they are genuinely optional with settings fallbacks.
    return {
        "source": source,
        "retriever_name": retriever_name,
        "chunks": chunks,
        "doc_token_limit": doc_token_limit,
        "model_id": model_id,
        "user_api_key": user_api_key,
        "agent_id": agent_id,
        "llm_name": llm_name or settings.LLM_PROVIDER,
        "api_key": api_key or settings.API_KEY,
        "decoded_token": decoded_token,
        "has_directory_structure": has_directory_structure,
    }

View File

@@ -1,12 +1,20 @@
import asyncio
import base64
import concurrent.futures
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.security.encryption import decrypt_credentials
from fastmcp import Client
from fastmcp.client.auth import BearerAuth
from fastmcp.client.transports import (
@@ -16,19 +24,10 @@ from fastmcp.client.transports import (
)
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from pydantic import AnyHttpUrl, ValidationError
from redis import Redis
from application.agents.tools.base import Tool
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
@@ -57,13 +56,11 @@ class MCPTool(Tool):
- args: Arguments for STDIO transport
- oauth_scopes: OAuth scopes for oauth auth type
- oauth_client_name: OAuth client name for oauth auth type
- query_mode: If True, use non-interactive OAuth (fail-fast on 401)
user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
"""
self.config = config
self.user_id = user_id
raw_url = config.get("server_url", "")
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
self.server_url = config.get("server_url", "")
self.transport_type = config.get("transport_type", "auto")
self.auth_type = config.get("auth_type", "none")
self.timeout = config.get("timeout", 30)
@@ -79,53 +76,23 @@ class MCPTool(Tool):
self.oauth_scopes = config.get("oauth_scopes", [])
self.oauth_task_id = config.get("oauth_task_id", None)
self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))
self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
self.available_tools = []
self._cache_key = self._generate_cache_key()
self._client = None
self.query_mode = config.get("query_mode", False)
# Only validate and setup if server_url is provided and not OAuth
if self.server_url and self.auth_type != "oauth":
self._setup_client()
@staticmethod
def _validate_server_url(server_url: str) -> str:
"""Validate server_url to prevent SSRF to internal networks.
Raises:
ValueError: If the URL points to a private/internal address.
"""
try:
return validate_url(server_url)
except SSRFError as exc:
raise ValueError(f"Invalid MCP server URL: {exc}") from exc
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
if configured_redirect_uri:
return configured_redirect_uri.rstrip("/")
explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
if explicit:
return explicit.rstrip("/")
connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
if connector_base:
parsed = urlparse(connector_base)
if parsed.scheme and parsed.netloc:
return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback"
return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback"
def _generate_cache_key(self) -> str:
"""Generate a unique cache key for this MCP server configuration."""
auth_key = ""
if self.auth_type == "oauth":
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
auth_key = (
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
)
auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
elif self.auth_type in ["bearer"]:
token = self.auth_credentials.get(
"bearer_token", ""
@@ -142,10 +109,11 @@ class MCPTool(Tool):
return f"{self.server_url}#{self.transport_type}#{auth_key}"
def _setup_client(self):
"""Setup FastMCP client with proper transport and authentication."""
global _mcp_clients_cache
if self._cache_key in _mcp_clients_cache:
cached_data = _mcp_clients_cache[self._cache_key]
if time.time() - cached_data["created_at"] < 300:
if time.time() - cached_data["created_at"] < 1800:
self._client = cached_data["client"]
return
else:
@@ -155,25 +123,15 @@ class MCPTool(Tool):
if self.auth_type == "oauth":
redis_client = get_redis_instance()
if self.query_mode:
auth = NonInteractiveOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
db=db,
user_id=self.user_id,
)
else:
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
auth = DocsGPTOAuth(
mcp_url=self.server_url,
scopes=self.oauth_scopes,
redis_client=redis_client,
redirect_uri=self.redirect_uri,
task_id=self.oauth_task_id,
db=db,
user_id=self.user_id,
)
elif self.auth_type == "bearer":
token = self.auth_credentials.get(
"bearer_token", ""
@@ -275,53 +233,38 @@ class MCPTool(Tool):
else:
raise Exception(f"Unknown operation: {operation}")
_ERROR_MAP = [
(concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"),
(ConnectionRefusedError, lambda *_: "Connection refused"),
]
_ERROR_PATTERNS = {
("403", "Forbidden"): "Access denied (403 Forbidden)",
("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
("ECONNREFUSED",): "Connection refused",
("SSL", "certificate"): "SSL/TLS error",
}
def _run_async_operation(self, operation: str, *args, **kwargs):
"""Run async operation in sync context."""
try:
try:
asyncio.get_running_loop()
loop = asyncio.get_running_loop()
import concurrent.futures
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
new_loop.close()
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
self._run_in_new_loop, operation, *args, **kwargs
)
future = executor.submit(run_in_thread)
return future.result(timeout=self.timeout)
except RuntimeError:
return self._run_in_new_loop(operation, *args, **kwargs)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
except Exception as e:
raise self._map_error(operation, e) from e
raise self._map_error(operation, e) from e
def _run_in_new_loop(self, operation, *args, **kwargs):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
self._execute_with_client(operation, *args, **kwargs)
)
finally:
loop.close()
def _map_error(self, operation: str, exc: Exception) -> Exception:
for exc_type, msg_fn in self._ERROR_MAP:
if isinstance(exc, exc_type):
return Exception(msg_fn(operation, self.timeout, exc))
error_msg = str(exc)
for patterns, friendly in self._ERROR_PATTERNS.items():
if any(p.lower() in error_msg.lower() for p in patterns):
return Exception(friendly)
logger.error("MCP %s failed: %s", operation, exc)
return exc
print(f"Error occurred while running async operation: {e}")
raise
def discover_tools(self) -> List[Dict]:
"""
@@ -342,6 +285,16 @@ class MCPTool(Tool):
raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
def execute_action(self, action_name: str, **kwargs) -> Any:
"""
Execute an action on the remote MCP server using FastMCP.
Args:
action_name: Name of the action to execute
**kwargs: Parameters for the action
Returns:
Result from the MCP server
"""
if not self.server_url:
raise Exception("No MCP server configured")
if not self._client:
@@ -357,37 +310,7 @@ class MCPTool(Tool):
)
return self._format_result(result)
except Exception as e:
error_msg = str(e)
lower_msg = error_msg.lower()
is_auth_error = (
"401" in error_msg
or "unauthorized" in lower_msg
or "session expired" in lower_msg
or "re-authorize" in lower_msg
)
if is_auth_error:
if self.auth_type == "oauth":
raise Exception(
f"Action '{action_name}' failed: OAuth session expired. "
"Please re-authorize this MCP server in tool settings."
) from e
global _mcp_clients_cache
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
result = self._run_async_operation(
"call_tool", action_name, **cleaned_kwargs
)
return self._format_result(result)
except Exception as retry_e:
raise Exception(
f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
"Your credentials may have expired — please re-authorize in tool settings."
) from retry_e
raise Exception(
f"Failed to execute action '{action_name}': {error_msg}"
) from e
raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
def _format_result(self, result) -> Dict:
"""Format FastMCP result to match expected format."""
@@ -410,35 +333,23 @@ class MCPTool(Tool):
return result
def test_connection(self) -> Dict:
"""
Test the connection to the MCP server and validate functionality.
Returns:
Dictionary with connection test results including tool count
"""
if not self.server_url:
return {
"success": False,
"message": "No server URL configured",
"tools_count": 0,
}
try:
parsed = urlparse(self.server_url)
if parsed.scheme not in ("http", "https"):
return {
"success": False,
"message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
"tools_count": 0,
}
except Exception:
return {
"success": False,
"message": "Invalid URL format",
"message": "No MCP server URL configured",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": "ConfigurationError",
}
if not self._client:
try:
self._setup_client()
except Exception as e:
return {
"success": False,
"message": f"Client init failed: {str(e)}",
"tools_count": 0,
}
self._setup_client()
try:
if self.auth_type == "oauth":
return self._test_oauth_connection()
@@ -449,93 +360,55 @@ class MCPTool(Tool):
"success": False,
"message": f"Connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def _test_regular_connection(self) -> Dict:
ping_ok = False
ping_error = None
"""Test connection for non-OAuth auth types."""
try:
self._run_async_operation("ping")
ping_ok = True
except Exception as e:
ping_error = str(e)
try:
tools = self.discover_tools()
except Exception as e:
return {
"success": False,
"message": f"Connection failed: {ping_error or str(e)}",
"tools_count": 0,
}
if not tools and not ping_ok:
return {
"success": False,
"message": f"Connection failed: {ping_error or 'No tools found'}",
"tools_count": 0,
}
ping_success = True
except Exception:
ping_success = False
tools = self.discover_tools()
message = f"Successfully connected to MCP server. Found {len(tools)} tools."
if not ping_success:
message += " (Ping not supported, but tool discovery worked)"
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"message": message,
"tools_count": len(tools),
"tools": [
{
"name": tool.get("name", "unknown"),
"description": tool.get("description", ""),
}
for tool in tools
],
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"ping_supported": ping_success,
"tools": [tool.get("name", "unknown") for tool in tools],
}
def _test_oauth_connection(self) -> Dict:
storage = DBTokenStorage(
server_url=self.server_url, user_id=self.user_id, db_client=db
)
loop = asyncio.new_event_loop()
"""Test connection for OAuth auth type with proper async handling."""
try:
tokens = loop.run_until_complete(storage.get_tokens())
finally:
loop.close()
if tokens and tokens.access_token:
self.query_mode = True
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
self._setup_client()
try:
tools = self.discover_tools()
return {
"success": True,
"message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
"tools_count": len(tools),
"tools": [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in tools
],
}
except Exception as e:
logger.warning("OAuth token validation failed: %s", e)
_mcp_clients_cache.pop(self._cache_key, None)
self._client = None
return self._start_oauth_task()
def _start_oauth_task(self) -> Dict:
task_config = self.config.copy()
task_config.pop("query_mode", None)
result = mcp_oauth_task.delay(task_config, self.user_id)
return {
"success": False,
"requires_oauth": True,
"task_id": result.id,
"message": "OAuth authorization required.",
"tools_count": 0,
}
task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
if not task:
raise Exception("Failed to start OAuth authentication")
return {
"success": True,
"requires_oauth": True,
"task_id": task.id,
"status": "pending",
"message": "OAuth flow started",
}
except Exception as e:
return {
"success": False,
"message": f"OAuth connection failed: {str(e)}",
"tools_count": 0,
"transport_type": self.transport_type,
"auth_type": self.auth_type,
"error_type": type(e).__name__,
}
def get_actions_metadata(self) -> List[Dict]:
"""
@@ -582,88 +455,110 @@ class MCPTool(Tool):
return actions
def get_config_requirements(self) -> Dict:
"""Get configuration requirements for the MCP tool."""
transport_enum = ["auto", "sse", "http"]
transport_help = {
"auto": "Automatically detect best transport",
"sse": "Server-Sent Events (for real-time streaming)",
"http": "HTTP streaming (recommended for production)",
}
return {
"server_url": {
"type": "string",
"label": "Server URL",
"description": "URL of the remote MCP server",
"description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
"required": True,
"secret": False,
"order": 1,
},
"transport_type": {
"type": "string",
"description": "Transport type for connection",
"enum": transport_enum,
"default": "auto",
"required": False,
"help": {
**transport_help,
},
},
"auth_type": {
"type": "string",
"label": "Authentication Type",
"description": "Authentication method for the MCP server",
"description": "Authentication type",
"enum": ["none", "bearer", "oauth", "api_key", "basic"],
"default": "none",
"required": True,
"secret": False,
"order": 2,
"help": {
"none": "No authentication",
"bearer": "Bearer token authentication",
"oauth": "OAuth 2.1 authentication (with frontend integration)",
"api_key": "API key authentication",
"basic": "Basic authentication",
},
},
"api_key": {
"type": "string",
"label": "API Key",
"description": "API key for authentication",
"auth_credentials": {
"type": "object",
"description": "Authentication credentials (varies by auth_type)",
"required": False,
"secret": True,
"order": 3,
"depends_on": {"auth_type": "api_key"},
},
"api_key_header": {
"type": "string",
"label": "API Key Header",
"description": "Header name for API key (default: X-API-Key)",
"default": "X-API-Key",
"required": False,
"secret": False,
"order": 4,
"depends_on": {"auth_type": "api_key"},
},
"bearer_token": {
"type": "string",
"label": "Bearer Token",
"description": "Bearer token for authentication",
"required": False,
"secret": True,
"order": 3,
"depends_on": {"auth_type": "bearer"},
},
"username": {
"type": "string",
"label": "Username",
"description": "Username for basic authentication",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "basic"},
},
"password": {
"type": "string",
"label": "Password",
"description": "Password for basic authentication",
"required": False,
"secret": True,
"order": 4,
"depends_on": {"auth_type": "basic"},
"properties": {
"bearer_token": {
"type": "string",
"description": "Bearer token for bearer auth",
},
"access_token": {
"type": "string",
"description": "Access token for OAuth (if pre-obtained)",
},
"api_key": {
"type": "string",
"description": "API key for api_key auth",
},
"api_key_header": {
"type": "string",
"description": "Header name for API key (default: X-API-Key)",
},
"username": {
"type": "string",
"description": "Username for basic auth",
},
"password": {
"type": "string",
"description": "Password for basic auth",
},
},
},
"oauth_scopes": {
"type": "string",
"label": "OAuth Scopes",
"description": "Comma-separated OAuth scopes to request",
"type": "array",
"description": "OAuth scopes to request (for oauth auth_type)",
"items": {"type": "string"},
"required": False,
"default": [],
},
"oauth_client_name": {
"type": "string",
"description": "Client name for OAuth registration (for oauth auth_type)",
"default": "DocsGPT-MCP",
"required": False,
},
"headers": {
"type": "object",
"description": "Custom headers to send with requests",
"required": False,
"secret": False,
"order": 3,
"depends_on": {"auth_type": "oauth"},
},
"timeout": {
"type": "number",
"label": "Timeout (seconds)",
"description": "Request timeout in seconds (1-300)",
"type": "integer",
"description": "Request timeout in seconds",
"default": 30,
"minimum": 1,
"maximum": 300,
"required": False,
},
"command": {
"type": "string",
"description": "Command to run for STDIO transport (e.g., 'python')",
"required": False,
},
"args": {
"type": "array",
"description": "Arguments for STDIO command",
"items": {"type": "string"},
"required": False,
"secret": False,
"order": 10,
},
}
@@ -685,8 +580,23 @@ class DocsGPTOAuth(OAuthClientProvider):
user_id=None,
db=None,
additional_client_metadata: dict[str, Any] | None = None,
skip_redirect_validation: bool = False,
):
"""
Initialize custom OAuth client provider for DocsGPT.
Args:
mcp_url: Full URL to the MCP endpoint
redirect_uri: Custom redirect URI for DocsGPT frontend
redis_client: Redis client for storing auth state
redis_prefix: Prefix for Redis keys
task_id: Task ID for tracking auth status
scopes: OAuth scopes to request
client_name: Name for this client during registration
user_id: User ID for token storage
db: Database instance for token storage
additional_client_metadata: Extra fields for OAuthClientMetadata
"""
self.redirect_uri = redirect_uri
self.redis_client = redis_client
self.redis_prefix = redis_prefix
@@ -709,10 +619,7 @@ class DocsGPTOAuth(OAuthClientProvider):
)
storage = DBTokenStorage(
server_url=self.server_base_url,
user_id=self.user_id,
db_client=self.db,
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
)
super().__init__(
@@ -744,20 +651,22 @@ class DocsGPTOAuth(OAuthClientProvider):
async def redirect_handler(self, authorization_url: str) -> None:
"""Store auth URL and state in Redis for frontend to use."""
auth_url, state = self._process_auth_url(authorization_url)
logger.info("Processed auth_url: %s, state: %s", auth_url, state)
logging.info(
"[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
)
self.auth_url = auth_url
self.extracted_state = state
if self.redis_client and self.extracted_state:
key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
self.redis_client.setex(key, 600, auth_url)
logger.info("Stored auth_url in Redis: %s", key)
logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
if self.task_id:
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "requires_redirect",
"message": "Authorization required",
"message": "OAuth authorization required",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -777,7 +686,7 @@ class DocsGPTOAuth(OAuthClientProvider):
status_key = f"mcp_oauth_status:{self.task_id}"
status_data = {
"status": "awaiting_callback",
"message": "Waiting for authorization...",
"message": "Waiting for OAuth callback...",
"authorization_url": self.auth_url,
"state": self.extracted_state,
"requires_oauth": True,
@@ -802,7 +711,7 @@ class DocsGPTOAuth(OAuthClientProvider):
if self.task_id:
status_data = {
"status": "callback_received",
"message": "Completing authentication...",
"message": "OAuth callback received, completing authentication...",
"task_id": self.task_id,
}
self.redis_client.setex(status_key, 600, json.dumps(status_data))
@@ -822,44 +731,14 @@ class DocsGPTOAuth(OAuthClientProvider):
await asyncio.sleep(poll_interval)
self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
raise Exception("OAuth timeout: no code received within 5 minutes")
class NonInteractiveOAuth(DocsGPTOAuth):
"""OAuth provider that fails fast on 401 instead of starting interactive auth.
Used during query execution to prevent the streaming response from blocking
while waiting for user authorization that will never come.
"""
def __init__(self, **kwargs):
kwargs.setdefault("task_id", None)
kwargs["skip_redirect_validation"] = True
super().__init__(**kwargs)
async def redirect_handler(self, authorization_url: str) -> None:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
async def callback_handler(self) -> tuple[str, str | None]:
raise Exception(
"OAuth session expired — please re-authorize this MCP server in tool settings."
)
raise Exception("OAuth callback timeout: no code received within 5 minutes")
class DBTokenStorage(TokenStorage):
def __init__(
self,
server_url: str,
user_id: str,
db_client,
expected_redirect_uri: Optional[str] = None,
):
def __init__(self, server_url: str, user_id: str, db_client):
self.server_url = server_url
self.user_id = user_id
self.db_client = db_client
self.expected_redirect_uri = expected_redirect_uri
self.collection = db_client["connector_sessions"]
@staticmethod
@@ -878,9 +757,10 @@ class DBTokenStorage(TokenStorage):
if not doc or "tokens" not in doc:
return None
try:
return OAuthToken.model_validate(doc["tokens"])
tokens = OAuthToken.model_validate(doc["tokens"])
return tokens
except ValidationError as e:
logger.error("Could not load tokens: %s", e)
logging.error(f"Could not load tokens: {e}")
return None
async def set_tokens(self, tokens: OAuthToken) -> None:
@@ -890,38 +770,28 @@ class DBTokenStorage(TokenStorage):
{"$set": {"tokens": tokens.model_dump()}},
True,
)
logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
async def get_client_info(self) -> OAuthClientInformationFull | None:
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
if not doc or "client_info" not in doc:
logger.debug(
"No client_info in DB for %s", self.get_base_url(self.server_url)
)
return None
try:
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
if self.expected_redirect_uri:
stored_uris = [
str(uri).rstrip("/") for uri in client_info.redirect_uris
]
expected_uri = self.expected_redirect_uri.rstrip("/")
if expected_uri not in stored_uris:
logger.warning(
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
self.get_base_url(self.server_url),
expected_uri,
stored_uris,
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": "", "tokens": ""}},
)
return None
tokens = await self.get_tokens()
if tokens is None:
logging.debug(
"No tokens found, clearing client info to force fresh registration."
)
await asyncio.to_thread(
self.collection.update_one,
self.get_db_key(),
{"$unset": {"client_info": ""}},
)
return None
return client_info
except ValidationError as e:
logger.error("Could not load client info: %s", e)
logging.error(f"Could not load client info: {e}")
return None
def _serialize_client_info(self, info: dict) -> dict:
@@ -937,17 +807,17 @@ class DBTokenStorage(TokenStorage):
{"$set": {"client_info": serialized_info}},
True,
)
logger.info("Saved client info for %s", self.get_base_url(self.server_url))
logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
async def clear(self) -> None:
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
@classmethod
async def clear_all(cls, db_client) -> None:
collection = db_client["connector_sessions"]
await asyncio.to_thread(collection.delete_many, {})
logger.info("Cleared all OAuth client cache data.")
logging.info("Cleared all OAuth client cache data.")
class MCPOAuthManager:
@@ -986,7 +856,7 @@ class MCPOAuthManager:
return True
except Exception as e:
logger.error("Error handling OAuth callback: %s", e)
logging.error(f"Error handling OAuth callback: {e}")
return False
def get_oauth_status(self, task_id: str) -> Dict[str, Any]:

View File

@@ -116,13 +116,12 @@ class NtfyTool(Tool):
]
def get_config_requirements(self):
"""
Specify the configuration requirements.
Returns:
dict: Dictionary describing required config parameters.
"""
return {
"token": {
"type": "string",
"label": "Access Token",
"description": "Ntfy access token for authentication",
"required": True,
"secret": True,
"order": 1,
},
"token": {"type": "string", "description": "Access token for authentication"},
}

View File

@@ -1,12 +1,6 @@
import logging
import psycopg2
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class PostgresTool(Tool):
"""
PostgreSQL Database Tool
@@ -23,15 +17,17 @@ class PostgresTool(Tool):
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
@@ -39,9 +35,7 @@ class PostgresTool(Tool):
conn.commit()
if sql_query.strip().lower().startswith("select"):
column_names = (
[desc[0] for desc in cur.description] if cur.description else []
)
column_names = [desc[0] for desc in cur.description] if cur.description else []
results = []
rows = cur.fetchall()
for row in rows:
@@ -49,9 +43,7 @@ class PostgresTool(Tool):
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {
"message": f"Query executed successfully, {row_count} rows affected."
}
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
cur.close()
return {
@@ -62,27 +54,26 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL execute_sql error: %s", e)
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute(
"""
cur.execute("""
SELECT
table_name,
column_name,
@@ -96,22 +87,19 @@ class PostgresTool(Tool):
ORDER BY
table_name,
ordinal_position;
"""
)
""")
schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append(
{
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable,
}
)
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})
cur.close()
return {
@@ -122,14 +110,14 @@ class PostgresTool(Tool):
except psycopg2.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL get_schema error: %s", e)
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def get_actions_metadata(self):
@@ -170,10 +158,6 @@ class PostgresTool(Tool):
return {
"token": {
"type": "string",
"label": "Connection String",
"description": "PostgreSQL database connection string",
"required": True,
"secret": True,
"order": 1,
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
},
}
}

View File

@@ -1,11 +1,6 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class TelegramTool(Tool):
"""
@@ -23,19 +18,21 @@ class TelegramTool(Tool):
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _send_message(self, text, chat_id):
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload)
@@ -85,12 +82,5 @@ class TelegramTool(Tool):
def get_config_requirements(self):
return {
"token": {
"type": "string",
"label": "Bot Token",
"description": "Telegram bot token for authentication",
"required": True,
"secret": True,
"order": 1,
},
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -1,70 +0,0 @@
from application.agents.tools.base import Tool
# Stable identifier used to register/lookup the think pseudo-tool.
THINK_TOOL_ID = "think"
# Pre-built tool registry entry for the think tool. Mirrors the metadata
# returned by ThinkTool.get_actions_metadata(), with an extra "active" flag
# so the entry can be injected directly into a tool-selection list.
THINK_TOOL_ENTRY = {
    "name": "think",
    "actions": [
        {
            "name": "reason",
            # Instruction shown to the LLM describing when to invoke the action.
            "description": (
                "Use this tool to think through your reasoning step by step "
                "before deciding on your next action. Always reason before "
                "searching or answering."
            ),
            "active": True,
            "parameters": {
                "properties": {
                    "reasoning": {
                        "type": "string",
                        "description": "Your step-by-step reasoning and analysis",
                        # Marks the argument as LLM-generated rather than user-configured.
                        "filled_by_llm": True,
                        "required": True,
                    }
                }
            },
        }
    ],
}
class ThinkTool(Tool):
    """Pseudo-tool that captures chain-of-thought reasoning.

    Executing any action returns a short acknowledgment ("Continue.") so the
    LLM can proceed; the reasoning text itself is captured in the tool_call
    data for transparency rather than being processed here.
    """
    # Marks this tool as internal-only so tool loaders can exclude it from
    # user-facing tool listings.
    internal = True
    def __init__(self, config=None):
        # No configuration is needed; the tool is stateless.
        pass
    def execute_action(self, action_name: str, **kwargs):
        """Acknowledge the reasoning step; the action name and kwargs are ignored."""
        return "Continue."
    def get_actions_metadata(self):
        """Return the single "reason" action schema exposed to the LLM."""
        # NOTE(review): unlike THINK_TOOL_ENTRY this entry omits "active" — confirm
        # whether consumers of this metadata expect that flag.
        return [
            {
                "name": "reason",
                "description": (
                    "Use this tool to think through your reasoning step by step "
                    "before deciding on your next action. Always reason before "
                    "searching or answering."
                ),
                "parameters": {
                    "properties": {
                        "reasoning": {
                            "type": "string",
                            "description": "Your step-by-step reasoning and analysis",
                            "filled_by_llm": True,
                            "required": True,
                        }
                    }
                },
            }
        ]
    def get_config_requirements(self):
        """No configuration is required for this pseudo-tool."""
        return {}

View File

@@ -5,9 +5,8 @@ logger = logging.getLogger(__name__)
class ToolActionParser:
def __init__(self, llm_type, name_mapping=None):
def __init__(self, llm_type):
self.llm_type = llm_type
self.name_mapping = name_mapping
self.parsers = {
"OpenAILLM": self._parse_openai_llm,
"GoogleLLM": self._parse_google_llm,
@@ -17,33 +16,22 @@ class ToolActionParser:
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
return parser(call)
def _resolve_via_mapping(self, call_name):
"""Look up (tool_id, action_name) from the name mapping if available."""
if self.name_mapping and call_name in self.name_mapping:
return self.name_mapping[call_name]
return None
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
@@ -57,24 +45,19 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
# Validate that tool_id looks like a numerical ID
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."

View File

@@ -19,7 +19,7 @@ class ToolManager:
continue
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
@@ -36,7 +36,7 @@ class ToolManager:
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)

View File

@@ -2,10 +2,9 @@
from typing import Any, Dict, List, Optional, Type
from application.agents.agentic_agent import AgenticAgent
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.react_agent import ReActAgent
from application.agents.workflows.schemas import AgentType
@@ -37,8 +36,7 @@ class ToolFilterMixin:
return filtered_tools
class _WorkflowNodeMixin:
"""Common __init__ for all workflow node agents."""
class WorkflowNodeClassicAgent(ToolFilterMixin, ClassicAgent):
def __init__(
self,
@@ -59,25 +57,32 @@ class _WorkflowNodeMixin:
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
pass
class WorkflowNodeReActAgent(ToolFilterMixin, ReActAgent):
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
pass
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
pass
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
AgentType.RESEARCH: WorkflowNodeResearchAgent,
AgentType.REACT: WorkflowNodeReActAgent,
}
@classmethod

View File

@@ -18,8 +18,6 @@ class NodeType(str, Enum):
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
AGENTIC = "agentic"
RESEARCH = "research"
class ExecutionStatus(str, Enum):

View File

@@ -7,7 +7,6 @@ from application.agents.workflows.cel_evaluator import CelEvaluationError, evalu
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
@@ -19,7 +18,6 @@ from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.error import sanitize_api_error
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
@@ -101,7 +99,6 @@ class WorkflowEngine:
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
user_friendly_error = sanitize_api_error(e)
yield {
"type": "workflow_step",
"node_id": node.id,
@@ -109,9 +106,9 @@ class WorkflowEngine:
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": user_friendly_error,
"error": str(e),
}
yield {"type": "error", "error": user_friendly_error}
yield {"type": "error", "error": str(e)}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
@@ -224,32 +221,18 @@ class WorkflowEngine:
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
factory_kwargs = {
"agent_type": node_config.agent_type,
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,
"chat_history": self.agent.chat_history,
"decoded_token": self.agent.decoded_token,
"json_schema": node_json_schema,
}
# Agentic/research agents need retriever_config for on-demand search
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
factory_kwargs["retriever_config"] = {
"source": {"active_docs": node_config.sources} if node_config.sources else {},
"retriever_name": node_config.retriever or "classic",
"chunks": int(node_config.chunks) if node_config.chunks else 2,
"model_id": node_model_id,
"llm_name": node_llm_name,
"api_key": node_api_key,
"decoded_token": self.agent.decoded_token,
}
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
node_agent = WorkflowNodeAgentFactory.create(
agent_type=node_config.agent_type,
endpoint=self.agent.endpoint,
llm_name=node_llm_name,
model_id=node_model_id,
api_key=node_api_key,
tool_ids=node_config.tools,
prompt=node_config.system_prompt,
chat_history=self.agent.chat_history,
decoded_token=self.agent.decoded_token,
json_schema=node_json_schema,
)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []

View File

@@ -74,76 +74,68 @@ class AnswerResource(Resource, BaseAnswerResource):
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
)
else:
# ---- Normal mode ----
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
processor.initialize()
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
docs_together, docs_list = processor.pre_fetch_docs(
data.get("question", "")
)
tools_data = processor.pre_fetch_tools()
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
agent = processor.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
if stream_result["error"]:
return make_response({"error": stream_result["error"]}, 400)
if len(stream_result) == 7:
(
conversation_id,
response,
sources,
tool_calls,
thought,
error,
structured_info,
) = stream_result
else:
conversation_id, response, sources, tool_calls, thought, error = (
stream_result
)
structured_info = None
if error:
return make_response({"error": error}, 400)
result = {
"conversation_id": stream_result["conversation_id"],
"answer": stream_result["answer"],
"sources": stream_result["sources"],
"tool_calls": stream_result["tool_calls"],
"thought": stream_result["thought"],
"conversation_id": conversation_id,
"answer": response,
"sources": sources,
"tool_calls": tool_calls,
"thought": thought,
}
extra_info = stream_result.get("extra")
if extra_info:
result.update(extra_info)
if structured_info:
result.update(structured_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",

View File

@@ -6,7 +6,6 @@ from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
@@ -16,7 +15,6 @@ from application.core.model_utils import (
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.utils import check_required_fields
@@ -40,16 +38,7 @@ class BaseAnswerResource:
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation.
Continuation requests (``tool_actions`` present) require
``conversation_id`` but not ``question``.
"""
if data.get("tool_actions"):
# Continuation mode — question is not required
if missing := check_required_fields(data, ["conversation_id"]):
return missing
return None
"""Common request validation"""
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
@@ -187,7 +176,6 @@ class BaseAnswerResource:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
@@ -217,23 +205,9 @@ class BaseAnswerResource:
is_structured = False
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
if _continuation:
gen_iter = agent.gen_continuation(
messages=_continuation["messages"],
tools_dict=_continuation["tools_dict"],
pending_tool_calls=_continuation["pending_tool_calls"],
tool_actions=_continuation["tool_actions"],
)
else:
gen_iter = agent.gen(query=question)
for line in gen_iter:
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
for line in agent.gen(query=question):
if "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
@@ -266,21 +240,8 @@ class BaseAnswerResource:
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
@@ -290,93 +251,6 @@ class BaseAnswerResource:
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation:
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
provider = (
get_provider_from_model_id(model_id)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
)
conversation_id = (
self.conversation_service.save_conversation(
None,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
)
except Exception as e:
logger.error(
f"Failed to create conversation for continuation: {e}",
exc_info=True,
)
if conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
@@ -413,7 +287,6 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
@@ -503,7 +376,6 @@ class BaseAnswerResource:
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
@@ -540,13 +412,8 @@ class BaseAnswerResource:
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream) -> Dict[str, Any]:
"""Process the stream response for non-streaming endpoint.
Returns:
Dict with keys: conversation_id, answer, sources, tool_calls,
thought, error, and optional extra.
"""
def process_response_stream(self, stream):
"""Process the stream response for non-streaming endpoint"""
conversation_id = ""
response_full = ""
source_log_docs = []
@@ -555,7 +422,6 @@ class BaseAnswerResource:
stream_ended = False
is_structured = False
schema_info = None
pending_tool_calls = None
for line in stream:
try:
@@ -574,22 +440,11 @@ class BaseAnswerResource:
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "tool_calls_pending":
pending_tool_calls = event.get("data", {}).get(
"pending_tool_calls", []
)
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": event["error"],
}
return None, None, None, None, event["error"], None
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
@@ -597,30 +452,18 @@ class BaseAnswerResource:
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": "Stream ended unexpectedly",
}
result: Dict[str, Any] = {
"conversation_id": conversation_id,
"answer": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls,
"thought": thought,
"error": None,
}
if pending_tool_calls is not None:
result["extra"] = {"pending_tool_calls": pending_tool_calls}
return None, None, None, None, "Stream ended unexpectedly", None
result = (
conversation_id,
response_full,
source_log_docs,
tool_calls,
thought,
None,
)
if is_structured:
result["extra"] = {"structured": True, "schema": schema_info}
result = result + ({"structured": True, "schema": schema_info},)
return result
def error_stream_generate(self, err_response):

View File

@@ -79,48 +79,8 @@ class StreamResource(Resource, BaseAnswerResource):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
),
mimetype="text/event-stream",
)
# ---- Normal mode ----
agent = processor.build_agent(data["question"])
processor.initialize()
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
@@ -128,6 +88,13 @@ class StreamResource(Resource, BaseAnswerResource):
mimetype="text/event-stream",
)
docs_together, docs_list = processor.pre_fetch_docs(data["question"])
tools_data = processor.pre_fetch_tools()
agent = processor.create_agent(
docs_together=docs_together, docs=docs_list, tools_data=tools_data
)
if error := self.check_usage(processor.agent_config):
return error
return Response(

View File

@@ -1,6 +1,5 @@
"""Message reconstruction utilities for compression."""
import json
import logging
import uuid
from typing import Dict, List, Optional
@@ -50,35 +49,28 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
messages.append(
{"role": "tool", "content": [function_response_dict]}
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
@@ -188,35 +180,28 @@ class MessageBuilder:
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
rebuilt_messages.append(
{"role": "assistant", "content": [function_call_dict]}
)
rebuilt_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
rebuilt_messages.append(
{"role": "tool", "content": [function_response_dict]}
)
rebuilt_messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:

View File

@@ -1,141 +0,0 @@
"""Service for saving and restoring tool-call continuation state.
When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to MongoDB so the client can
resume later by sending tool_actions.
"""
import datetime
import logging
from typing import Any, Dict, List, Optional
from bson import ObjectId
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
def _make_serializable(obj: Any) -> Any:
    """Recursively convert MongoDB ObjectIds and other non-JSON types.

    Dict keys are stringified, ObjectIds become hex strings, and bytes are
    decoded as UTF-8 (undecodable runs replaced); everything else passes
    through unchanged.
    """
    # BSON ObjectIds serialize as their hex-string representation.
    if isinstance(obj, ObjectId):
        return str(obj)
    # Mappings: stringify every key and recurse into each value.
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[str(key)] = _make_serializable(value)
        return converted
    # Lists: recurse element-wise.
    if isinstance(obj, list):
        return [_make_serializable(item) for item in obj]
    # Raw bytes: decode, substituting the replacement character where needed.
    if isinstance(obj, bytes):
        return obj.decode("utf-8", errors="replace")
    # Anything else is assumed to be JSON-safe already.
    return obj
class ContinuationService:
    """Manages pending tool-call state in MongoDB.

    One pending-state document is kept per (conversation_id, user) pair in
    the "pending_tool_state" collection; documents expire automatically via
    a TTL index on "expires_at".
    """
    def __init__(self):
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        self.collection = db["pending_tool_state"]
        self._ensure_indexes()
    def _ensure_indexes(self):
        """Best-effort creation of the TTL and uniqueness indexes."""
        try:
            # expireAfterSeconds=0 makes Mongo delete each document at the
            # exact time stored in its "expires_at" field.
            self.collection.create_index(
                "expires_at", expireAfterSeconds=0
            )
            # Enforce at most one pending state per conversation per user.
            self.collection.create_index(
                [("conversation_id", 1), ("user", 1)], unique=True
            )
        except Exception:
            # Indexes may already exist or mongomock doesn't support TTL
            pass
    def save_state(
        self,
        conversation_id: str,
        user: str,
        messages: List[Dict],
        pending_tool_calls: List[Dict],
        tools_dict: Dict,
        tool_schemas: List[Dict],
        agent_config: Dict,
        client_tools: Optional[List[Dict]] = None,
    ) -> str:
        """Save execution state for later continuation.
        Args:
            conversation_id: The conversation this state belongs to.
            user: Owner user ID.
            messages: Full messages array at the pause point.
            pending_tool_calls: Tool calls awaiting client action.
            tools_dict: Serializable tools configuration dict.
            tool_schemas: LLM-formatted tool schemas (agent.tools).
            agent_config: Config needed to recreate the agent on resume.
            client_tools: Client-provided tool schemas for client-side execution.
        Returns:
            The string ID of the saved state document.
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        # The TTL index on "expires_at" auto-deletes the document after this time.
        expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
        # All nested payloads are sanitized so the document round-trips as JSON.
        doc = {
            "conversation_id": conversation_id,
            "user": user,
            "messages": _make_serializable(messages),
            "pending_tool_calls": _make_serializable(pending_tool_calls),
            "tools_dict": _make_serializable(tools_dict),
            "tool_schemas": _make_serializable(tool_schemas),
            "agent_config": _make_serializable(agent_config),
            "client_tools": _make_serializable(client_tools) if client_tools else None,
            "created_at": now,
            "expires_at": expires_at,
        }
        # Upsert — only one pending state per conversation per user
        result = self.collection.replace_one(
            {"conversation_id": conversation_id, "user": user},
            doc,
            upsert=True,
        )
        # On replacement (no upsert) there is no new _id; fall back to the
        # conversation_id as the caller-visible identifier.
        state_id = str(result.upserted_id) if result.upserted_id else conversation_id
        logger.info(
            f"Saved continuation state for conversation {conversation_id} "
            f"with {len(pending_tool_calls)} pending tool call(s)"
        )
        return state_id
    def load_state(
        self, conversation_id: str, user: str
    ) -> Optional[Dict[str, Any]]:
        """Load pending continuation state.
        Returns:
            The state dict (with "_id" stringified), or None if no pending
            state exists for this conversation/user pair.
        """
        doc = self.collection.find_one(
            {"conversation_id": conversation_id, "user": user}
        )
        if not doc:
            return None
        # Stringify the ObjectId so the document is JSON-serializable.
        doc["_id"] = str(doc["_id"])
        return doc
    def delete_state(self, conversation_id: str, user: str) -> bool:
        """Delete pending state after successful resumption.
        Returns:
            True if a document was deleted.
        """
        result = self.collection.delete_one(
            {"conversation_id": conversation_id, "user": user}
        )
        if result.deleted_count:
            logger.info(
                f"Deleted continuation state for conversation {conversation_id}"
            )
        return result.deleted_count > 0

View File

@@ -60,7 +60,6 @@ class ConversationService:
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save or update a conversation in the database"""
if decoded_token is None:
@@ -94,11 +93,6 @@ class ConversationService:
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
f"queries.{index}.model_id": model_id,
**(
{f"queries.{index}.metadata": metadata}
if metadata
else {}
),
}
},
)
@@ -130,7 +124,6 @@ class ConversationService:
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
**({"metadata": metadata} if metadata else {}),
}
}
},
@@ -163,24 +156,22 @@ class ConversationService:
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
query_doc = {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
if metadata:
query_doc["metadata"] = metadata
conversation_data = {
"user": user_id,
"date": current_time,
"name": completion,
"queries": [query_doc],
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
"model_id": model_id,
}
],
}
if api_key:

View File

@@ -38,23 +38,13 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
current_dir = Path(__file__).resolve().parents[3]
prompts_dir = current_dir / "prompts"
# Maps for classic agent types
CLASSIC_PRESETS = {
preset_mapping = {
"default": "chat_combine_default.txt",
"creative": "chat_combine_creative.txt",
"strict": "chat_combine_strict.txt",
"reduce": "chat_reduce_prompt.txt",
}
# Agentic counterparts — same styles, but with search tool instructions
AGENTIC_PRESETS = {
"default": "agentic/default.txt",
"creative": "agentic/creative.txt",
"strict": "agentic/strict.txt",
}
preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
if prompt_id in preset_mapping:
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
try:
@@ -101,7 +91,6 @@ class StreamProcessor:
self.is_shared_usage = False
self.shared_token = None
self.agent_id = self.data.get("agent_id")
self.agent_key = None
self.model_id: Optional[str] = None
self.conversation_service = ConversationService()
self.compression_orchestrator = CompressionOrchestrator(
@@ -112,7 +101,6 @@ class StreamProcessor:
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
self.compressed_summary: Optional[str] = None
self.compressed_summary_tokens: int = 0
self._agent_data: Optional[Dict[str, Any]] = None
def initialize(self):
"""Initialize all required components for processing"""
@@ -123,29 +111,6 @@ class StreamProcessor:
self._load_conversation_history()
self._process_attachments()
def build_agent(self, question: str):
"""One call to go from request data to a ready-to-run agent.
Combines initialize(), pre_fetch_docs(), pre_fetch_tools(), and
create_agent() into a single convenience method.
"""
self.initialize()
agent_type = self.agent_config.get("agent_type", "classic")
# Agentic/research agents skip pre-fetch — the LLM searches on-demand via tools
if agent_type in ("agentic", "research"):
tools_data = self.pre_fetch_tools()
return self.create_agent(tools_data=tools_data)
docs_together, docs_list = self.pre_fetch_docs(question)
tools_data = self.pre_fetch_tools()
return self.create_agent(
docs_together=docs_together,
docs=docs_list,
tools_data=tools_data,
)
def _load_conversation_history(self):
"""Load conversation history either from DB or request"""
if self.conversation_id and self.initial_user_id:
@@ -159,17 +124,9 @@ class StreamProcessor:
if settings.ENABLE_CONVERSATION_COMPRESSION:
self._handle_compression(conversation)
else:
# Original behavior - load all history (include metadata if present)
# Original behavior - load all history
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**(
{"metadata": query["metadata"]}
if "metadata" in query
else {}
),
}
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
else:
@@ -178,8 +135,14 @@ class StreamProcessor:
)
def _handle_compression(self, conversation: Dict[str, Any]):
"""Handle conversation compression logic using orchestrator."""
"""
Handle conversation compression logic using orchestrator.
Args:
conversation: Full conversation document
"""
try:
# Use orchestrator to handle all compression logic
result = self.compression_orchestrator.compress_if_needed(
conversation_id=self.conversation_id,
user_id=self.initial_user_id,
@@ -190,15 +153,12 @@ class StreamProcessor:
if not result.success:
logger.error(f"Compression failed: {result.error}, using full history")
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
return
# Set compressed summary if compression was performed
if result.compression_performed and result.compressed_summary:
self.compressed_summary = result.compressed_summary
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
@@ -209,27 +169,17 @@ class StreamProcessor:
f"+ {len(result.recent_queries)} recent messages"
)
# Build history from recent queries
self.history = result.as_history()
# Preserve metadata from recent queries (as_history only has prompt/response)
recent = result.recent_queries if result.recent_queries else conversation.get("queries", [])
for i, entry in enumerate(self.history):
# Match by index from the end of recent queries
offset = len(recent) - len(self.history)
qi = offset + i
if 0 <= qi < len(recent) and "metadata" in recent[qi]:
entry["metadata"] = recent[qi]["metadata"]
except Exception as e:
logger.error(
f"Error handling compression, falling back to standard history: {str(e)}",
exc_info=True,
)
# Fallback to original behavior
self.history = [
{
"prompt": query["prompt"],
"response": query["response"],
**({"metadata": query["metadata"]} if "metadata" in query else {}),
}
{"prompt": query["prompt"], "response": query["response"]}
for query in conversation.get("queries", [])
]
@@ -241,6 +191,9 @@ class StreamProcessor:
)
def _get_attachments_content(self, attachment_ids, user_id):
"""
Retrieve content from attachment documents based on their IDs.
"""
if not attachment_ids:
return []
attachments = []
@@ -249,6 +202,7 @@ class StreamProcessor:
attachment_doc = self.attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user_id}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
@@ -278,6 +232,7 @@ class StreamProcessor:
)
self.model_id = requested_model
else:
# Check if agent has a default model configured
agent_default_model = self.agent_config.get("default_model_id", "")
if agent_default_model and validate_model_id(agent_default_model):
self.model_id = agent_default_model
@@ -330,6 +285,7 @@ class StreamProcessor:
data["source"] = "default"
else:
data["source"] = None
# Handle multiple sources
sources = data.get("sources", [])
if sources and isinstance(sources, list):
@@ -355,34 +311,28 @@ class StreamProcessor:
else:
data["sources"] = []
# Preserve model configuration from agent
data["default_model_id"] = data.get("default_model_id", "")
return data
def _configure_source(self):
"""Configure the source based on agent data.
"""Configure the source based on agent data"""
api_key = self.data.get("api_key") or self.agent_key
The literal string ``"default"`` is a placeholder meaning "no
ingested source" and is normalized to an empty source so that no
retrieval is attempted.
"""
if self._agent_data:
agent_data = self._agent_data
if api_key:
agent_data = self._get_data_from_api_key(api_key)
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
source_ids = [
source["id"]
for source in agent_data["sources"]
if source.get("id") and source["id"] != "default"
source["id"] for source in agent_data["sources"] if source.get("id")
]
if source_ids:
self.source = {"active_docs": source_ids}
else:
self.source = {}
self.all_sources = [
s for s in agent_data["sources"] if s.get("id") != "default"
]
elif agent_data.get("source") and agent_data["source"] != "default":
self.all_sources = agent_data["sources"]
elif agent_data.get("source"):
self.source = {"active_docs": agent_data["source"]}
self.all_sources = [
{
@@ -395,100 +345,84 @@ class StreamProcessor:
self.all_sources = []
return
if "active_docs" in self.data:
active_docs = self.data["active_docs"]
if active_docs and active_docs != "default":
self.source = {"active_docs": active_docs}
else:
self.source = {}
self.source = {"active_docs": self.data["active_docs"]}
return
self.source = {}
self.all_sources = []
def _has_active_docs(self) -> bool:
"""Return True if a real document source is configured for retrieval."""
active_docs = self.source.get("active_docs") if self.source else None
if not active_docs:
return False
if active_docs == "default":
return False
return True
def _resolve_agent_id(self) -> Optional[str]:
"""Resolve agent_id from request, then fall back to conversation context."""
request_agent_id = self.data.get("agent_id")
if request_agent_id:
return str(request_agent_id)
if not self.conversation_id or not self.initial_user_id:
return None
try:
conversation = self.conversation_service.get_conversation(
self.conversation_id, self.initial_user_id
)
except Exception:
return None
if not conversation:
return None
conversation_agent_id = conversation.get("agent_id")
if conversation_agent_id:
return str(conversation_agent_id)
return None
def _configure_agent(self):
"""Configure the agent based on request data.
Unified flow: resolve the effective API key, then extract config once.
"""
agent_id = self._resolve_agent_id()
"""Configure the agent based on request data"""
agent_id = self.data.get("agent_id")
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
agent_id, self.initial_user_id
)
self.agent_id = str(agent_id) if agent_id else None
# Determine the effective API key (explicit > agent-derived)
effective_key = self.data.get("api_key") or self.agent_key
if effective_key:
self._agent_data = self._get_data_from_api_key(effective_key)
if self._agent_data.get("_id"):
self.agent_id = str(self._agent_data.get("_id"))
api_key = self.data.get("api_key")
if api_key:
data_key = self._get_data_from_api_key(api_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": self._agent_data.get("prompt_id", "default"),
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
"user_api_key": effective_key,
"json_schema": self._agent_data.get("json_schema"),
"default_model_id": self._agent_data.get("default_model_id", ""),
"models": self._agent_data.get("models", []),
"allow_system_prompt_override": self._agent_data.get(
"allow_system_prompt_override", False
),
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": api_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
# Set identity context
if self.data.get("api_key"):
# External API key: use the key owner's identity
self.initial_user_id = self._agent_data.get("user")
self.decoded_token = {"sub": self._agent_data.get("user")}
elif self.is_shared_usage:
# Shared agent: keep the caller's identity
pass
else:
# Owner using their own agent
self.decoded_token = {"sub": self._agent_data.get("user")}
if self._agent_data.get("workflow"):
self.agent_config["workflow"] = self._agent_data["workflow"]
self.agent_config["workflow_owner"] = self._agent_data.get("user")
self.initial_user_id = data_key.get("user")
self.decoded_token = {"sub": data_key.get("user")}
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
elif self.agent_key:
data_key = self._get_data_from_api_key(self.agent_key)
if data_key.get("_id"):
self.agent_id = str(data_key.get("_id"))
self.agent_config.update(
{
"prompt_id": data_key.get("prompt_id", "default"),
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
"user_api_key": self.agent_key,
"json_schema": data_key.get("json_schema"),
"default_model_id": data_key.get("default_model_id", ""),
}
)
self.decoded_token = (
self.decoded_token
if self.is_shared_usage
else {"sub": data_key.get("user")}
)
if data_key.get("source"):
self.source = {"active_docs": data_key["source"]}
if data_key.get("workflow"):
self.agent_config["workflow"] = data_key["workflow"]
self.agent_config["workflow_owner"] = data_key.get("user")
if data_key.get("retriever"):
self.retriever_config["retriever_name"] = data_key["retriever"]
if data_key.get("chunks") is not None:
try:
self.retriever_config["chunks"] = int(data_key["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
)
self.retriever_config["chunks"] = 2
else:
# No API key — default/workflow configuration
agent_type = settings.AGENT_NAME
if self.data.get("workflow") and isinstance(
self.data.get("workflow"), dict
@@ -509,45 +443,14 @@ class StreamProcessor:
)
def _configure_retriever(self):
"""Assemble retriever config with precedence: request > agent > default."""
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
# Start with defaults
retriever_name = "classic"
chunks = 2
# Layer agent-level config (if present)
if self._agent_data:
if self._agent_data.get("retriever"):
retriever_name = self._agent_data["retriever"]
if self._agent_data.get("chunks") is not None:
try:
chunks = int(self._agent_data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
"using default value 2"
)
# Explicit request values win over agent config
if "retriever" in self.data:
retriever_name = self.data["retriever"]
if "chunks" in self.data:
try:
chunks = int(self.data["chunks"])
except (ValueError, TypeError):
logger.warning(
f"Invalid request chunks value: {self.data['chunks']}, "
"using default value 2"
)
self.retriever_config = {
"retriever_name": retriever_name,
"chunks": chunks,
"retriever_name": self.data.get("retriever", "classic"),
"chunks": int(self.data.get("chunks", 2)),
"doc_token_limit": doc_token_limit,
}
# isNoneDoc without an API key forces no retrieval
api_key = self.data.get("api_key") or self.agent_key
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
self.retriever_config["chunks"] = 0
@@ -568,12 +471,9 @@ class StreamProcessor:
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
"""Pre-fetch documents for template rendering before agent creation"""
if self.data.get("isNoneDoc", False) and not self.agent_id:
if self.data.get("isNoneDoc", False):
logger.info("Pre-fetch skipped: isNoneDoc=True")
return None, None
if not self._has_active_docs():
logger.info("Pre-fetch skipped: no active docs configured")
return None, None
try:
retriever = self.create_retriever()
logger.info(
@@ -605,7 +505,12 @@ class StreamProcessor:
return None, None
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
"""Pre-fetch tool data for template rendering before agent creation"""
"""Pre-fetch tool data for template rendering before agent creation
Can be controlled via:
1. Global setting: ENABLE_TOOL_PREFETCH in .env
2. Per-request: disable_tool_prefetch in request data
"""
if not settings.ENABLE_TOOL_PREFETCH:
logger.info(
"Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
@@ -817,121 +722,6 @@ class StreamProcessor:
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
return None
def resume_from_tool_actions(
self,
tool_actions: list,
conversation_id: str,
):
"""Resume a paused agent from saved continuation state.
Loads the pending state from MongoDB, recreates the agent with
the saved configuration, and returns an agent ready to call
``gen_continuation()``.
Args:
tool_actions: Client-provided actions (approvals / results).
conversation_id: The conversation being resumed.
Returns:
Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).
"""
from application.api.answer.services.continuation_service import (
ContinuationService,
)
from application.agents.agent_creator import AgentCreator
from application.agents.tool_executor import ToolExecutor
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
cont_service = ContinuationService()
state = cont_service.load_state(conversation_id, self.initial_user_id)
if not state:
raise ValueError("No pending tool state found for this conversation")
messages = state["messages"]
pending_tool_calls = state["pending_tool_calls"]
tools_dict = state["tools_dict"]
tool_schemas = state.get("tool_schemas", [])
agent_config = state["agent_config"]
model_id = agent_config.get("model_id")
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
api_key = agent_config.get("api_key")
user_api_key = agent_config.get("user_api_key")
agent_id = agent_config.get("agent_id")
prompt = agent_config.get("prompt", "")
json_schema = agent_config.get("json_schema")
retriever_config = agent_config.get("retriever_config")
# Recreate dependencies
system_api_key = api_key or get_api_key_for_provider(llm_name)
llm = LLMCreator.create_llm(
llm_name,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=self.decoded_token,
model_id=model_id,
agent_id=agent_id,
)
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.initial_user_id,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = conversation_id
# Restore client tools so they stay available for subsequent LLM calls
saved_client_tools = state.get("client_tools")
if saved_client_tools:
tool_executor.client_tools = saved_client_tools
# Re-merge into tools_dict (they may have been stripped during serialization)
tool_executor.merge_client_tools(tools_dict, saved_client_tools)
agent_type = agent_config.get("agent_type", "ClassicAgent")
# Map class names back to agent creator keys
type_map = {
"ClassicAgent": "classic",
"AgenticAgent": "agentic",
"ResearchAgent": "research",
"WorkflowAgent": "workflow",
}
agent_key = type_map.get(agent_type, "classic")
agent_kwargs = {
"endpoint": "stream",
"llm_name": llm_name,
"model_id": model_id,
"api_key": system_api_key,
"agent_id": agent_id,
"user_api_key": user_api_key,
"prompt": prompt,
"chat_history": [],
"decoded_token": self.decoded_token,
"json_schema": json_schema,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
if agent_key in ("agentic", "research") and retriever_config:
agent_kwargs["retriever_config"] = retriever_config
agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
agent.conversation_id = conversation_id
agent.initial_user_id = self.initial_user_id
agent.tools = tool_schemas
# Store config for the route layer
self.model_id = model_id
self.agent_id = agent_id
self.agent_config["user_api_key"] = user_api_key
self.conversation_id = conversation_id
# Delete state so it can't be replayed
cont_service.delete_state(conversation_id, self.initial_user_id)
return agent, messages, tools_dict, pending_tool_calls, tool_actions
def create_agent(
self,
docs_together: Optional[str] = None,
@@ -939,40 +729,22 @@ class StreamProcessor:
tools_data: Optional[Dict[str, Any]] = None,
):
"""Create and return the configured agent with rendered prompt"""
agent_type = self.agent_config["agent_type"]
# For agentic agents, swap standard presets for their agentic
# counterparts (which include search tool instructions instead of
# {summaries}). Custom / user-provided prompts pass through as-is.
raw_prompt = self._get_prompt_content()
if raw_prompt is None:
prompt_id = self.agent_config.get("prompt_id", "default")
agentic_presets = {"default", "creative", "strict"}
if agent_type in ("agentic", "research") and prompt_id in agentic_presets:
raw_prompt = get_prompt(
f"agentic_{prompt_id}", self.prompts_collection
)
else:
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
raw_prompt = get_prompt(
self.agent_config["prompt_id"], self.prompts_collection
)
self._prompt_content = raw_prompt
# Allow API callers to override the system prompt when the agent
# has opted in via allow_system_prompt_override.
if (
self.agent_config.get("allow_system_prompt_override", False)
and self.data.get("system_prompt_override")
):
rendered_prompt = self.data["system_prompt_override"]
else:
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
rendered_prompt = self.prompt_renderer.render_prompt(
prompt_content=raw_prompt,
user_id=self.initial_user_id,
request_id=self.data.get("request_id"),
passthrough_data=self.data.get("passthrough"),
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
provider = (
get_provider_from_model_id(self.model_id)
@@ -981,39 +753,7 @@ class StreamProcessor:
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
# Create LLM and handler (dependency injection)
from application.llm.llm_creator import LLMCreator
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.agents.tool_executor import ToolExecutor
# Compute backup models: agent's configured models minus the active one
agent_models = self.agent_config.get("models", [])
backup_models = [m for m in agent_models if m != self.model_id]
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=self.agent_config["user_api_key"],
decoded_token=self.decoded_token,
model_id=self.model_id,
agent_id=self.agent_id,
backup_models=backup_models,
)
llm_handler = LLMHandlerCreator.create_handler(
provider if provider else "default"
)
user = self.decoded_token.get("sub") if self.decoded_token else None
tool_executor = ToolExecutor(
user_api_key=self.agent_config["user_api_key"],
user=user,
decoded_token=self.decoded_token,
)
tool_executor.conversation_id = self.conversation_id
# Pass client-side tools so they get merged in get_tools()
client_tools = self.data.get("client_tools")
if client_tools:
tool_executor.client_tools = client_tools
agent_type = self.agent_config["agent_type"]
# Base agent kwargs
agent_kwargs = {
@@ -1030,31 +770,10 @@ class StreamProcessor:
"attachments": self.attachments,
"json_schema": self.agent_config.get("json_schema"),
"compressed_summary": self.compressed_summary,
"llm": llm,
"llm_handler": llm_handler,
"tool_executor": tool_executor,
}
# Type-specific kwargs
if agent_type in ("agentic", "research"):
agent_kwargs["retriever_config"] = {
"source": self.source,
"retriever_name": self.retriever_config.get(
"retriever_name", "classic"
),
"chunks": self.retriever_config.get("chunks", 2),
"doc_token_limit": self.retriever_config.get(
"doc_token_limit", 50000
),
"model_id": self.model_id,
"user_api_key": self.agent_config["user_api_key"],
"agent_id": self.agent_id,
"llm_name": provider or settings.LLM_PROVIDER,
"api_key": system_api_key,
"decoded_token": self.decoded_token,
}
elif agent_type == "workflow":
# Workflow-specific kwargs for workflow agents
if agent_type == "workflow":
workflow_config = self.agent_config.get("workflow")
if isinstance(workflow_config, str):
agent_kwargs["workflow_id"] = workflow_config

View File

@@ -146,19 +146,20 @@ class ConnectorsCallback(Resource):
session_token = str(uuid.uuid4())
try:
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = auth.sanitize_token_info(token_info)
sanitized_token_info = {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry")
}
sessions_collection.find_one_and_update(
{"_id": ObjectId(state_object_id), "provider": provider},
@@ -200,12 +201,12 @@ class ConnectorsCallback(Resource):
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False),
"search_query": fields.String(required=False)
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
@@ -213,8 +214,11 @@ class ConnectorFiles(Resource):
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
folder_id = data.get('folder_id')
limit = data.get('limit', 10)
page_token = data.get('page_token')
search_query = data.get('search_query')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
@@ -227,12 +231,15 @@ class ConnectorFiles(Resource):
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
generic_keys = {'provider', 'session_token'}
input_config = {
k: v for k, v in data.items() if k not in generic_keys
'limit': limit,
'list_only': True,
'session_token': session_token,
'folder_id': folder_id,
'page_token': page_token
}
input_config['list_only'] = True
if search_query:
input_config['search_query'] = search_query
documents = loader.load_data(input_config)
@@ -299,7 +306,12 @@ class ConnectorValidateSession(Resource):
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
sanitized_token_info = {
"access_token": refreshed_token_info.get("access_token"),
"refresh_token": refreshed_token_info.get("refresh_token"),
"token_uri": refreshed_token_info.get("token_uri"),
"expiry": refreshed_token_info.get("expiry")
}
sessions_collection.update_one(
{"session_token": session_token},
{"$set": {"token_info": sanitized_token_info}}
@@ -316,18 +328,12 @@ class ConnectorValidateSession(Resource):
"error": "Session token has expired. Please reconnect."
}), 401)
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {
return make_response(jsonify({
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token'),
**provider_extras,
}
return make_response(jsonify(response_data), 200)
"access_token": token_info.get('access_token')
}), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)

View File

@@ -26,20 +26,12 @@ internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests.
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
"""
if not settings.INTERNAL_KEY:
logger.warning(
f"Internal API request rejected from {request.remote_addr}: "
"INTERNAL_KEY is not configured"
)
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
"""Verify INTERNAL_KEY for all internal endpoint requests."""
if settings.INTERNAL_KEY:
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])

View File

@@ -5,7 +5,7 @@ Provides virtual folder organization for agents (Google Drive-like structure).
import datetime
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask import jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from application.api import api
@@ -19,11 +19,6 @@ agents_folders_ns = Namespace(
)
def _folder_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": message}), 400)
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
@@ -45,8 +40,8 @@ class AgentFolders(Resource):
for f in folders
]
return make_response(jsonify({"folders": result}), 200)
except Exception as err:
return _folder_error_response("Failed to fetch folders", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Create a new folder")
@api.expect(
@@ -87,8 +82,8 @@ class AgentFolders(Resource):
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
201,
)
except Exception as err:
return _folder_error_response("Failed to create folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/<string:folder_id>")
@@ -122,8 +117,8 @@ class AgentFolder(Resource):
}),
200,
)
except Exception as err:
return _folder_error_response("Failed to fetch folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Update a folder")
def put(self, folder_id):
@@ -150,8 +145,8 @@ class AgentFolder(Resource):
if result.matched_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to update folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
@@ -170,8 +165,8 @@ class AgentFolder(Resource):
if result.deleted_count == 0:
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to delete folder", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/move_agent")
@@ -216,8 +211,8 @@ class MoveAgentToFolder(Resource):
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agent", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)
@agents_folders_ns.route("/bulk_move")
@@ -262,5 +257,5 @@ class BulkMoveAgents(Resource):
{"$unset": {"folder_id": ""}},
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agents", err)
except Exception as e:
return make_response(jsonify({"success": False, "message": str(e)}), 400)

View File

@@ -73,7 +73,6 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -97,7 +96,6 @@ AGENT_TYPE_SCHEMAS = {
"token_limit",
"limited_request_mode",
"request_limit",
"allow_system_prompt_override",
"createdAt",
"updatedAt",
"lastUsedAt",
@@ -106,8 +104,6 @@ AGENT_TYPE_SCHEMAS = {
}
AGENT_TYPE_SCHEMAS["react"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["agentic"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["research"] = AGENT_TYPE_SCHEMAS["classic"]
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
@@ -222,12 +218,6 @@ def build_agent_document(
base_doc["request_limit"] = int(
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
)
if "allow_system_prompt_override" in allowed_fields:
base_doc["allow_system_prompt_override"] = (
data.get("allow_system_prompt_override") == "True"
if isinstance(data.get("allow_system_prompt_override"), str)
else bool(data.get("allow_system_prompt_override", False))
)
return {k: v for k, v in base_doc.items() if k in allowed_fields}
@@ -300,9 +290,6 @@ class GetAgent(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
return make_response(jsonify(data), 200)
except Exception as e:
@@ -384,9 +371,6 @@ class GetAgents(Resource):
"default_model_id": agent.get("default_model_id", ""),
"folder_id": agent.get("folder_id"),
"workflow": agent.get("workflow"),
"allow_system_prompt_override": agent.get(
"allow_system_prompt_override", False
),
}
for agent in agents
if "source" in agent
@@ -464,10 +448,6 @@ class CreateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -509,9 +489,9 @@ class CreateAgent(Resource):
data["json_schema"] = normalize_json_schema_payload(
data.get("json_schema")
)
except JsonSchemaValidationError:
except JsonSchemaValidationError as exc:
return make_response(
jsonify({"success": False, "message": "Invalid JSON schema"}),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
if data.get("status") not in ["draft", "published"]:
@@ -692,10 +672,6 @@ class UpdateAgent(Resource):
"folder_id": fields.String(
required=False, description="Folder ID to organize the agent"
),
"allow_system_prompt_override": fields.Boolean(
required=False,
description="Allow API callers to override the system prompt via the v1 endpoint",
),
},
)
@@ -764,7 +740,13 @@ class UpdateAgent(Resource):
request, existing_agent.get("image", ""), user, storage
)
if error:
return error
current_app.logger.error(
f"Image upload error for agent {agent_id}: {error}"
)
return make_response(
jsonify({"success": False, "message": f"Image upload failed: {error}"}),
400,
)
update_fields = {}
allowed_fields = [
"name",
@@ -787,7 +769,6 @@ class UpdateAgent(Resource):
"default_model_id",
"folder_id",
"workflow",
"allow_system_prompt_override",
]
for field in allowed_fields:
@@ -895,9 +876,9 @@ class UpdateAgent(Resource):
update_fields[field] = normalize_json_schema_payload(
json_schema
)
except JsonSchemaValidationError:
except JsonSchemaValidationError as exc:
return make_response(
jsonify({"success": False, "message": "Invalid JSON schema"}),
jsonify({"success": False, "message": f"JSON schema {exc}"}),
400,
)
else:
@@ -1006,13 +987,6 @@ class UpdateAgent(Resource):
if workflow_error:
return workflow_error
update_fields[field] = workflow_id
elif field == "allow_system_prompt_override":
raw_value = data.get("allow_system_prompt_override", False)
update_fields[field] = (
raw_value == "True"
if isinstance(raw_value, str)
else bool(raw_value)
)
else:
value = data[field]
if field in ["name", "description", "prompt_id", "agent_type"]:

View File

@@ -1,36 +1,15 @@
"""File attachments and media routes."""
import os
import tempfile
from pathlib import Path
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.cache import get_redis_instance
from application.api.user.base import agents_collection, storage
from application.api.user.tasks import store_attachment
from application.core.settings import settings
from application.stt.constants import (
SUPPORTED_AUDIO_EXTENSIONS,
SUPPORTED_AUDIO_MIME_TYPES,
)
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.stt.live_session import (
apply_live_stt_hypothesis,
create_live_stt_session,
delete_live_stt_session,
finalize_live_stt_session,
get_live_stt_transcript_text,
load_live_stt_session,
save_live_stt_session,
)
from application.stt.stt_creator import STTCreator
from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename
@@ -40,74 +19,6 @@ attachments_ns = Namespace(
)
def _resolve_authenticated_user():
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
if decoded_token:
return safe_filename(decoded_token.get("sub"))
if api_key:
from application.api.user.base import agents_collection
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
return safe_filename(agent.get("user"))
return None
def _get_uploaded_file_size(file) -> int:
try:
current_position = file.stream.tell()
file.stream.seek(0, os.SEEK_END)
size_bytes = file.stream.tell()
file.stream.seek(current_position)
return size_bytes
except Exception:
return 0
def _is_supported_audio_mimetype(mimetype: str) -> bool:
if not mimetype:
return True
normalized = mimetype.split(";")[0].strip().lower()
return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
if not is_audio_filename(filename):
return
size_bytes = _get_uploaded_file_size(file)
if size_bytes:
enforce_audio_file_size_limit(size_bytes)
def _get_store_attachment_user_error(exc: Exception) -> str:
if isinstance(exc, AudioFileTooLargeError):
return build_stt_file_size_limit_message()
return "Failed to process file"
def _require_live_stt_redis():
redis_client = get_redis_instance()
if redis_client:
return redis_client
return make_response(
jsonify({"success": False, "message": "Live transcription is unavailable"}),
503,
)
def _parse_bool_form_value(value: str | None) -> bool:
if value is None:
return False
return value.strip().lower() in {"1", "true", "yes", "on"}
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
@api.expect(
@@ -125,9 +36,8 @@ class StoreAttachment(Resource):
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
)
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
files = request.files.getlist("file")
if not files:
@@ -141,16 +51,22 @@ class StoreAttachment(Resource):
400,
)
user = auth_user
if not user:
user = None
if decoded_token:
user = safe_filename(decoded_token.get("sub"))
elif api_key:
agent = agents_collection.find_one({"key": api_key})
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
user = safe_filename(agent.get("user"))
else:
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
from application.api.user.tasks import store_attachment
from application.api.user.base import storage
tasks = []
errors = []
original_file_count = len(files)
@@ -159,7 +75,6 @@ class StoreAttachment(Resource):
try:
attachment_id = ObjectId()
original_filename = safe_filename(os.path.basename(file.filename))
_enforce_uploaded_audio_size_limit(file, original_filename)
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
metadata = storage.save_file(file, relative_path)
@@ -175,31 +90,15 @@ class StoreAttachment(Resource):
"task_id": task.id,
"filename": original_filename,
"attachment_id": str(attachment_id),
"upload_index": idx,
})
except Exception as file_err:
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
errors.append({
"upload_index": idx,
"filename": file.filename,
"error": _get_store_attachment_user_error(file_err),
"error": str(file_err)
})
if not tasks:
if errors and all(
error.get("error") == build_stt_file_size_limit_message()
for error in errors
):
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
"errors": errors,
}
),
413,
)
return make_response(
jsonify({"status": "error", "message": "No valid files to upload"}),
400,
@@ -236,389 +135,11 @@ class StoreAttachment(Resource):
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
@attachments_ns.route("/stt")
class SpeechToText(Resource):
@api.expect(
api.model(
"SpeechToTextModel",
{
"file": fields.Raw(required=True, description="Audio file"),
"language": fields.String(
required=False, description="Optional transcription language hint"
),
},
)
)
@api.doc(description="Transcribe an uploaded audio file")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=request.form.get("language") or settings.STT_LANGUAGE,
timestamps=settings.STT_ENABLE_TIMESTAMPS,
diarize=settings.STT_ENABLE_DIARIZATION,
)
return make_response(jsonify({"success": True, **transcript}), 200)
except Exception as err:
current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/start")
class LiveSpeechToTextStart(Resource):
@api.doc(description="Start a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_state = create_live_stt_session(
user=auth_user,
language=payload.get("language") or settings.STT_LANGUAGE,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_state["session_id"],
"language": session_state.get("language"),
"committed_text": "",
"mutable_text": "",
"previous_hypothesis": "",
"latest_hypothesis": "",
"finalized_text": "",
"pending_text": "",
"transcript_text": "",
}
),
200,
)
@attachments_ns.route("/stt/live/chunk")
class LiveSpeechToTextChunk(Resource):
@api.expect(
api.model(
"LiveSpeechToTextChunkModel",
{
"session_id": fields.String(
required=True, description="Live transcription session ID"
),
"chunk_index": fields.Integer(
required=True, description="Sequential chunk index"
),
"is_silence": fields.Boolean(
required=False,
description="Whether the latest capture window was mostly silence",
),
"file": fields.Raw(required=True, description="Audio chunk"),
},
)
)
@api.doc(description="Transcribe a chunk for a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
session_id = request.form.get("session_id", "").strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
chunk_index_raw = request.form.get("chunk_index", "").strip()
if chunk_index_raw == "":
return make_response(
jsonify({"success": False, "message": "Missing chunk_index"}),
400,
)
try:
chunk_index = int(chunk_index_raw)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid chunk_index"}),
400,
)
is_silence = _parse_bool_form_value(request.form.get("is_silence"))
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
session_language = session_state.get("language") or settings.STT_LANGUAGE
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=session_language,
timestamps=False,
diarize=False,
)
if not session_state.get("language") and transcript.get("language"):
session_state["language"] = transcript["language"]
try:
apply_live_stt_hypothesis(
session_state,
str(transcript.get("text", "")),
chunk_index,
is_silence=is_silence,
)
except ValueError:
current_app.logger.warning(
"Invalid live transcription chunk",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"message": "Invalid live transcription chunk",
}
),
409,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"chunk_index": chunk_index,
"chunk_text": transcript.get("text", ""),
"is_silence": is_silence,
"language": session_state.get("language"),
"committed_text": session_state.get("committed_text", ""),
"mutable_text": session_state.get("mutable_text", ""),
"previous_hypothesis": session_state.get(
"previous_hypothesis", ""
),
"latest_hypothesis": session_state.get(
"latest_hypothesis", ""
),
"finalized_text": session_state.get("committed_text", ""),
"pending_text": session_state.get("mutable_text", ""),
"transcript_text": get_live_stt_transcript_text(session_state),
}
),
200,
)
except Exception as err:
current_app.logger.error(
f"Error transcribing live audio chunk: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/finish")
class LiveSpeechToTextFinish(Resource):
@api.doc(description="Finish a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
final_text = finalize_live_stt_session(session_state)
delete_live_stt_session(redis_client, session_id)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"language": session_state.get("language"),
"text": final_text,
}
),
200,
)
@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
try:
from application.api.user.base import storage
file_obj = storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"
@@ -633,10 +154,6 @@ class ServeImage(Resource):
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(

View File

@@ -145,22 +145,14 @@ def resolve_tool_details(tool_ids):
Returns:
List of tool details with id, name, and display_name
"""
valid_ids = []
for tid in tool_ids:
try:
valid_ids.append(ObjectId(tid))
except Exception:
continue
tools = user_tools_collection.find(
{"_id": {"$in": valid_ids}}
) if valid_ids else []
{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
)
return [
{
"id": str(tool["_id"]),
"name": tool.get("name", ""),
"display_name": tool.get("customName")
or tool.get("displayName")
or tool.get("name", ""),
"display_name": tool.get("displayName", tool.get("name", "")),
}
for tool in tools
]

View File

@@ -57,7 +57,7 @@ class ShareConversation(Resource):
try:
conversation = conversations_collection.find_one(
{"_id": ObjectId(conversation_id), "user": user}
{"_id": ObjectId(conversation_id)}
)
if conversation is None:
return make_response(

View File

@@ -14,14 +14,7 @@ from application.api.user.base import sources_collection
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
from application.core.settings import settings
from application.parser.connectors.connector_creator import ConnectorCreator
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
from application.storage.storage_creator import StorageCreator
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.utils import check_required_fields, safe_filename
@@ -30,12 +23,6 @@ sources_upload_ns = Namespace(
)
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
if not is_audio_filename(filename):
return
enforce_audio_file_size_limit(os.path.getsize(file_path))
@sources_upload_ns.route("/upload")
class UploadFile(Resource):
@api.expect(
@@ -91,7 +78,6 @@ class UploadFile(Resource):
with tempfile.TemporaryDirectory() as temp_dir:
temp_file_path = os.path.join(temp_dir, safe_file)
file.save(temp_file_path)
_enforce_audio_path_size_limit(temp_file_path, safe_file)
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
# which are technically zip archives but should be processed as-is
@@ -116,10 +102,6 @@ class UploadFile(Resource):
os.path.join(root, extracted_file), temp_dir
)
storage_path = f"{base_path}/{rel_path}"
_enforce_audio_path_size_limit(
os.path.join(root, extracted_file),
extracted_file,
)
with open(
os.path.join(root, extracted_file), "rb"
@@ -142,23 +124,29 @@ class UploadFile(Resource):
storage.save_file(f, file_path)
task = ingest.delay(
settings.UPLOAD_FOLDER,
list(SUPPORTED_SOURCE_EXTENSIONS),
[
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
".png",
".jpg",
".jpeg",
],
job_name,
user,
file_path=base_path,
filename=dir_name,
file_name_map=file_name_map,
)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
except Exception as err:
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
@@ -463,16 +451,6 @@ class ManageSourceFiles(Resource):
removed_files = []
map_updated = False
for file_path in file_paths:
if ".." in str(file_path) or str(file_path).startswith("/"):
return make_response(
jsonify(
{
"success": False,
"message": "Invalid file path",
}
),
400,
)
full_path = f"{source_file_path}/{file_path}"
# Remove from storage

View File

@@ -1,83 +1,21 @@
"""Tool management MCP server integration."""
import json
from urllib.parse import urlencode, urlparse
from urllib.parse import unquote, urlencode
from bson.objectid import ObjectId
from flask import current_app, jsonify, make_response, redirect, request
from flask_restx import Namespace, Resource, fields
from flask_restx import fields, Namespace, Resource
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
from application.api import api
from application.api.user.base import user_tools_collection
from application.api.user.tools.routes import transform_actions
from application.cache import get_redis_instance
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.security.encryption import encrypt_credentials
from application.utils import check_required_fields
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
_mongo = MongoDB.get_client()
_db = _mongo[settings.MONGO_DB_NAME]
_connector_sessions = _db["connector_sessions"]
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
def _sanitize_mcp_transport(config):
"""Normalise and validate the transport_type field.
Strips ``command`` / ``args`` keys that are only valid for local STDIO
transports and returns the cleaned transport type string.
"""
transport_type = (config.get("transport_type") or "auto").lower()
if transport_type not in _ALLOWED_TRANSPORTS:
raise ValueError(f"Unsupported transport_type: {transport_type}")
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
return transport_type
def _extract_auth_credentials(config):
"""Build an ``auth_credentials`` dict from the raw MCP config."""
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if config.get("api_key"):
auth_credentials["api_key"] = config["api_key"]
if config.get("api_key_header"):
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if config.get("bearer_token"):
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if config.get("username"):
auth_credentials["username"] = config["username"]
if config.get("password"):
auth_credentials["password"] = config["password"]
return auth_credentials
def _validate_mcp_server_url(config: dict) -> None:
"""Validate the server_url in an MCP config to prevent SSRF.
Raises:
ValueError: If the URL is missing or points to a blocked address.
"""
server_url = (config.get("server_url") or "").strip()
if not server_url:
raise ValueError("server_url is required")
try:
validate_url(server_url)
except SSRFError as exc:
raise ValueError(f"Invalid server URL: {exc}") from exc
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
@@ -105,63 +43,49 @@ class TestMCPServerConfig(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
_validate_mcp_server_url(config)
auth_credentials = {}
auth_type = config.get("auth_type", "none")
auth_credentials = _extract_auth_credentials(config)
if auth_type == "api_key" and "api_key" in config:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer" and "bearer_token" in config:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config:
auth_credentials["username"] = config["username"]
if "password" in config:
auth_credentials["password"] = config["password"]
test_config = config.copy()
test_config["auth_credentials"] = auth_credentials
mcp_tool = MCPTool(config=test_config, user_id=user)
result = mcp_tool.test_connection()
if result.get("requires_oauth"):
safe_result = {
k: v
for k, v in result.items()
if k in ("success", "requires_oauth", "auth_url")
}
return make_response(jsonify(safe_result), 200)
# Sanitize the response to avoid exposing internal error details
if not result.get("success") and "message" in result:
current_app.logger.error(f"MCP connection test failed: {result.get('message')}")
result["message"] = "Connection test failed"
if not result.get("success"):
current_app.logger.error(
f"MCP connection test failed: {result.get('message')}"
)
return make_response(
jsonify(
{
"success": False,
"message": "Connection test failed",
"tools_count": 0,
}
),
200,
)
safe_result = {
"success": True,
"message": result.get("message", "Connection successful"),
"tools_count": result.get("tools_count", 0),
"tools": result.get("tools", []),
}
return make_response(jsonify(safe_result), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server test request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
return make_response(jsonify(result), 200)
except Exception as e:
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
return make_response(
jsonify({"success": False, "error": "Connection test failed"}),
jsonify(
{"success": False, "error": "Connection test failed"}
),
500,
)
@@ -201,18 +125,32 @@ class MCPServerSave(Resource):
return missing_fields
try:
config = data["config"]
try:
_sanitize_mcp_transport(config)
except ValueError:
transport_type = (config.get("transport_type") or "auto").lower()
allowed_transports = {"auto", "sse", "http"}
if transport_type not in allowed_transports:
return make_response(
jsonify({"success": False, "error": "Unsupported transport_type"}),
400,
)
config.pop("command", None)
config.pop("args", None)
config["transport_type"] = transport_type
_validate_mcp_server_url(config)
auth_credentials = _extract_auth_credentials(config)
auth_credentials = {}
auth_type = config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config["api_key_header"]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
mcp_config = config.copy()
mcp_config["auth_credentials"] = auth_credentials
@@ -250,39 +188,30 @@ class MCPServerSave(Resource):
"No valid credentials provided for the selected authentication type"
)
storage_config = config.copy()
tool_id = data.get("id")
existing_encrypted = None
if tool_id:
existing_doc = user_tools_collection.find_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
)
if existing_doc:
existing_encrypted = existing_doc.get("config", {}).get(
"encrypted_credentials"
)
if auth_credentials:
if existing_encrypted:
existing_secrets = decrypt_credentials(existing_encrypted, user)
existing_secrets.update(auth_credentials)
auth_credentials = existing_secrets
storage_config["encrypted_credentials"] = encrypt_credentials(
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
elif existing_encrypted:
storage_config["encrypted_credentials"] = existing_encrypted
storage_config["encrypted_credentials"] = encrypted_credentials_string
for field in [
"api_key",
"bearer_token",
"username",
"password",
"api_key_header",
"redirect_uri",
]:
storage_config.pop(field, None)
transformed_actions = transform_actions(actions_metadata)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
tool_data = {
"name": "mcp_tool",
"displayName": data["displayName"],
@@ -294,6 +223,7 @@ class MCPServerSave(Resource):
"user": user,
}
tool_id = data.get("id")
if tool_id:
result = user_tools_collection.update_one(
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
@@ -325,16 +255,12 @@ class MCPServerSave(Resource):
"tools_count": len(transformed_actions),
}
return make_response(jsonify(response_data), 200)
except ValueError as e:
current_app.logger.warning(f"Invalid MCP server save request: {e}")
return make_response(
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
400,
)
except Exception as e:
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
return make_response(
jsonify({"success": False, "error": "Failed to save MCP server"}),
jsonify(
{"success": False, "error": "Failed to save MCP server"}
),
500,
)
@@ -365,7 +291,7 @@ class MCPOAuthCallback(Resource):
params = {
"status": "error",
"message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
"provider": "mcp_tool",
"provider": "mcp_tool"
}
return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
if not code or not state:
@@ -378,6 +304,7 @@ class MCPOAuthCallback(Resource):
return redirect(
"/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
)
code = unquote(code)
manager = MCPOAuthManager(redis_client)
success = manager.handle_oauth_callback(state, code, error)
if success:
@@ -400,6 +327,10 @@ class MCPOAuthCallback(Resource):
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
def get(self, task_id):
"""
Get current status of OAuth flow.
Frontend should poll this endpoint periodically.
"""
try:
redis_client = get_redis_instance()
status_key = f"mcp_oauth_status:{task_id}"
@@ -407,14 +338,6 @@ class MCPOAuthStatus(Resource):
if status_data:
status = json.loads(status_data)
if "tools" in status and isinstance(status["tools"], list):
status["tools"] = [
{
"name": t.get("name", "unknown"),
"description": t.get("description", ""),
}
for t in status["tools"]
]
return make_response(
jsonify({"success": True, "task_id": task_id, **status})
)
@@ -422,93 +345,17 @@ class MCPOAuthStatus(Resource):
return make_response(
jsonify(
{
"success": True,
"success": False,
"error": "Task not found or expired",
"task_id": task_id,
"status": "pending",
"message": "Waiting for OAuth to start...",
}
),
200,
404,
)
except Exception as e:
current_app.logger.error(
f"Error getting OAuth status for task {task_id}: {str(e)}",
exc_info=True,
f"Error getting OAuth status for task {task_id}: {str(e)}", exc_info=True
)
return make_response(
jsonify(
{
"success": False,
"error": "Failed to get OAuth status",
"task_id": task_id,
}
),
500,
)
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
@api.doc(
description="Batch check auth status for all MCP tools. "
"Lightweight DB-only check — no network calls to MCP servers."
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
mcp_tools = list(
user_tools_collection.find(
{"user": user, "name": "mcp_tool"},
{"_id": 1, "config": 1},
)
)
if not mcp_tools:
return make_response(jsonify({"success": True, "statuses": {}}), 200)
oauth_server_urls = {}
statuses = {}
for tool in mcp_tools:
tool_id = str(tool["_id"])
config = tool.get("config", {})
auth_type = config.get("auth_type", "none")
if auth_type == "oauth":
server_url = config.get("server_url", "")
if server_url:
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
oauth_server_urls[tool_id] = base_url
else:
statuses[tool_id] = "needs_auth"
else:
statuses[tool_id] = "configured"
if oauth_server_urls:
unique_urls = list(set(oauth_server_urls.values()))
sessions = list(
_connector_sessions.find(
{"user_id": user, "server_url": {"$in": unique_urls}},
{"server_url": 1, "tokens": 1},
)
)
url_has_tokens = {
doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
for doc in sessions
}
for tool_id, base_url in oauth_server_urls.items():
if url_has_tokens.get(base_url):
statuses[tool_id] = "connected"
else:
statuses[tool_id] = "needs_auth"
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
except Exception as e:
current_app.logger.error(
"Error checking MCP auth status: %s", e, exc_info=True
)
return make_response(
jsonify({"success": False, "error": "Failed to check auth status"}),
500,
jsonify({"success": False, "error": "Failed to get OAuth status", "task_id": task_id}), 500
)

View File

@@ -8,7 +8,6 @@ from application.agents.tools.spec_parser import parse_spec
from application.agents.tools.tool_manager import ToolManager
from application.api import api
from application.api.user.base import user_tools_collection
from application.core.url_validation import SSRFError, validate_url
from application.security.encryption import decrypt_credentials, encrypt_credentials
from application.utils import check_required_fields, validate_function_name
@@ -16,114 +15,6 @@ tool_config = {}
tool_manager = ToolManager(config=tool_config)
def _encrypt_secret_fields(config, config_requirements, user_id):
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret") and key in config and config[key]
]
if not secret_keys:
return config
storage_config = config.copy()
secret_values = {k: config[k] for k in secret_keys}
storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id)
for key in secret_keys:
storage_config.pop(key, None)
return storage_config
def _validate_config(config, config_requirements, has_existing_secrets=False):
errors = {}
for key, spec in config_requirements.items():
depends_on = spec.get("depends_on")
if depends_on:
if not all(config.get(dk) == dv for dk, dv in depends_on.items()):
continue
if spec.get("required") and not config.get(key):
if has_existing_secrets and spec.get("secret"):
continue
errors[key] = f"{spec.get('label', key)} is required"
value = config.get(key)
if value is not None and value != "":
if spec.get("type") == "number":
try:
num = float(value)
if key == "timeout" and (num < 1 or num > 300):
errors[key] = "Timeout must be between 1 and 300"
except (ValueError, TypeError):
errors[key] = f"{spec.get('label', key)} must be a number"
if spec.get("enum") and value not in spec["enum"]:
errors[key] = f"Invalid value for {spec.get('label', key)}"
return errors
def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id):
"""Merge incoming config with existing encrypted secrets and re-encrypt.
For updates, the client may omit unchanged secret values. This helper
decrypts any previously stored secrets, overlays whatever the client *did*
send, strips plain-text secrets from the stored config, and re-encrypts
the merged result.
Returns the final ``config`` dict ready for persistence.
"""
secret_keys = [
key for key, spec in config_requirements.items()
if spec.get("secret")
]
if not secret_keys:
return new_config
existing_secrets = {}
if "encrypted_credentials" in existing_config:
existing_secrets = decrypt_credentials(
existing_config["encrypted_credentials"], user_id
)
merged_secrets = existing_secrets.copy()
for key in secret_keys:
if key in new_config and new_config[key]:
merged_secrets[key] = new_config[key]
# Start from existing non-secret values, then overlay incoming non-secrets
storage_config = {
k: v for k, v in existing_config.items()
if k not in secret_keys and k != "encrypted_credentials"
}
storage_config.update(
{k: v for k, v in new_config.items() if k not in secret_keys}
)
if merged_secrets:
storage_config["encrypted_credentials"] = encrypt_credentials(
merged_secrets, user_id
)
else:
storage_config.pop("encrypted_credentials", None)
storage_config.pop("has_encrypted_credentials", None)
return storage_config
def transform_actions(actions_metadata):
    """Set default flags on action metadata for storage.

    Marks every action active and flags each parameter property as
    LLM-fillable with an empty preset value. Shared by the generic
    create_tool route and the MCP save routes. Actions are mutated in
    place and returned as a new list.
    """
    prepared = []
    for action in actions_metadata:
        action["active"] = True
        if "parameters" in action:
            for prop in action["parameters"].get("properties", {}).values():
                prop["filled_by_llm"] = True
                prop["value"] = ""
        prepared.append(action)
    return prepared
tools_ns = Namespace("tools", description="Tool management operations", path="/api")
@@ -131,8 +22,6 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
class AvailableTools(Resource):
@api.doc(description="Get available tools for a user")
def get(self):
if not request.decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
tools_metadata = []
for tool_name, tool_instance in tool_manager.tools.items():
@@ -140,15 +29,12 @@ class AvailableTools(Resource):
lines = doc.split("\n", 1)
name = lines[0].strip()
description = lines[1].strip() if len(lines) > 1 else ""
config_req = tool_instance.get_config_requirements()
actions = tool_instance.get_actions_metadata()
tools_metadata.append(
{
"name": tool_name,
"displayName": name,
"description": description,
"configRequirements": config_req,
"actions": actions,
"configRequirements": tool_instance.get_config_requirements(),
}
)
except Exception as err:
@@ -174,21 +60,6 @@ class GetTools(Resource):
tool_copy = {**tool}
tool_copy["id"] = str(tool["_id"])
tool_copy.pop("_id", None)
config_req = tool_copy.get("configRequirements", {})
if not config_req:
tool_instance = tool_manager.tools.get(tool_copy.get("name"))
if tool_instance:
config_req = tool_instance.get_config_requirements()
tool_copy["configRequirements"] = config_req
has_secrets = any(
spec.get("secret") for spec in config_req.values()
) if config_req else False
if has_secrets and "encrypted_credentials" in tool_copy.get("config", {}):
tool_copy["config"]["has_encrypted_credentials"] = True
tool_copy["config"].pop("encrypted_credentials", None)
user_tools.append(tool_copy)
except Exception as err:
current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
@@ -239,48 +110,29 @@ class CreateTool(Resource):
if missing_fields:
return missing_fields
try:
if data["name"] == "mcp_tool":
server_url = (data.get("config", {}).get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(data["name"])
if not tool_instance:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
actions_metadata = tool_instance.get_actions_metadata()
transformed_actions = transform_actions(actions_metadata)
transformed_actions = []
for action in actions_metadata:
action["active"] = True
if "parameters" in action:
if "properties" in action["parameters"]:
for param_name, param_details in action["parameters"][
"properties"
].items():
param_details["filled_by_llm"] = True
param_details["value"] = ""
transformed_actions.append(action)
except Exception as err:
current_app.logger.error(
f"Error getting tool actions: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
try:
config_requirements = tool_instance.get_config_requirements()
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements
)
if validation_errors:
return make_response(
jsonify(
{
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}
),
400,
)
storage_config = _encrypt_secret_fields(
data["config"], config_requirements, user
)
new_tool = {
"user": user,
"name": data["name"],
@@ -288,8 +140,7 @@ class CreateTool(Resource):
"description": data["description"],
"customName": data.get("customName", ""),
"actions": transformed_actions,
"config": storage_config,
"configRequirements": config_requirements,
"config": data["config"],
"status": data["status"],
}
resp = user_tools_collection.insert_one(new_tool)
@@ -359,37 +210,57 @@ class UpdateTool(Resource):
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(
jsonify({"success": False, "message": "Tool not found"}),
404,
)
tool_name = tool_doc.get("name", data.get("name"))
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if tool_doc and tool_doc.get("name") == "mcp_tool":
config = data["config"]
existing_config = tool_doc.get("config", {})
storage_config = existing_config.copy()
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
storage_config.update(config)
existing_credentials = {}
if "encrypted_credentials" in existing_config:
existing_credentials = decrypt_credentials(
existing_config["encrypted_credentials"], user
)
update_data["config"] = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
auth_credentials = existing_credentials.copy()
auth_type = storage_config.get("auth_type", "none")
if auth_type == "api_key":
if "api_key" in config and config["api_key"]:
auth_credentials["api_key"] = config["api_key"]
if "api_key_header" in config:
auth_credentials["api_key_header"] = config[
"api_key_header"
]
elif auth_type == "bearer":
if "bearer_token" in config and config["bearer_token"]:
auth_credentials["bearer_token"] = config["bearer_token"]
elif "encrypted_token" in config and config["encrypted_token"]:
auth_credentials["bearer_token"] = config["encrypted_token"]
elif auth_type == "basic":
if "username" in config and config["username"]:
auth_credentials["username"] = config["username"]
if "password" in config and config["password"]:
auth_credentials["password"] = config["password"]
if auth_type != "none" and auth_credentials:
encrypted_credentials_string = encrypt_credentials(
auth_credentials, user
)
storage_config["encrypted_credentials"] = (
encrypted_credentials_string
)
elif auth_type == "none":
storage_config.pop("encrypted_credentials", None)
for field in [
"api_key",
"bearer_token",
"encrypted_token",
"username",
"password",
"api_key_header",
]:
storage_config.pop(field, None)
update_data["config"] = storage_config
else:
update_data["config"] = data["config"]
if "status" in data:
update_data["status"] = data["status"]
user_tools_collection.update_one(
@@ -427,52 +298,9 @@ class UpdateToolConfig(Resource):
if missing_fields:
return missing_fields
try:
tool_doc = user_tools_collection.find_one(
{"_id": ObjectId(data["id"]), "user": user}
)
if not tool_doc:
return make_response(jsonify({"success": False}), 404)
tool_name = tool_doc.get("name")
if tool_name == "mcp_tool":
server_url = (data["config"].get("server_url") or "").strip()
if server_url:
try:
validate_url(server_url)
except SSRFError:
return make_response(
jsonify({"success": False, "message": "Invalid server URL"}),
400,
)
tool_instance = tool_manager.tools.get(tool_name)
config_requirements = (
tool_instance.get_config_requirements() if tool_instance else {}
)
existing_config = tool_doc.get("config", {})
has_existing_secrets = "encrypted_credentials" in existing_config
if config_requirements:
validation_errors = _validate_config(
data["config"], config_requirements,
has_existing_secrets=has_existing_secrets,
)
if validation_errors:
return make_response(
jsonify({
"success": False,
"message": "Validation failed",
"errors": validation_errors,
}),
400,
)
final_config = _merge_secrets_on_update(
data["config"], existing_config, config_requirements, user
)
user_tools_collection.update_one(
{"_id": ObjectId(data["id"]), "user": user},
{"$set": {"config": final_config}},
{"$set": {"config": data["config"]}},
)
except Exception as err:
current_app.logger.error(
@@ -582,13 +410,11 @@ class DeleteTool(Resource):
{"_id": ObjectId(data["id"]), "user": user}
)
if result.deleted_count == 0:
return make_response(
jsonify({"success": False, "message": "Tool not found"}), 404
)
return {"success": False, "message": "Tool not found"}, 404
except Exception as err:
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
return {"success": False}, 400
return {"success": True}, 200
@tools_ns.route("/parse_spec")
@@ -685,6 +511,7 @@ class GetArtifact(Resource):
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
if todo_doc:
tool_id = todo_doc.get("tool_id")
# Return all todos for the tool
query = {"user_id": user_id, "tool_id": tool_id}
all_todos = list(db["todos"].find(query))
items = []

View File

@@ -5,14 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
from bson.errors import InvalidId
from bson.objectid import ObjectId
from flask import (
Response,
current_app,
has_app_context,
jsonify,
make_response,
request,
)
from flask import jsonify, make_response, request, Response
from pymongo.collection import Collection
@@ -326,10 +319,8 @@ def safe_db_operation(
try:
result = operation()
return result, None
except Exception as err:
if has_app_context():
current_app.logger.error(f"{error_message}: {err}", exc_info=True)
return None, error_response(error_message)
except Exception as e:
return None, error_response(f"{error_message}: {str(e)}")
def validate_enum(

View File

@@ -30,11 +30,6 @@ from application.api.user.utils import (
workflows_ns = Namespace("workflows", path="/api")
def _workflow_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return error_response(message)
def serialize_workflow(w: Dict) -> Dict:
"""Serialize workflow document to API response format."""
return {
@@ -401,11 +396,11 @@ class WorkflowList(Resource):
try:
create_workflow_nodes(workflow_id, nodes_data, 1)
create_workflow_edges(workflow_id, edges_data, 1)
except Exception as err:
except Exception as e:
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": result.inserted_id})
return _workflow_error_response("Failed to create workflow structure", err)
return error_response(f"Failed to create workflow structure: {str(e)}")
return success_response({"id": workflow_id}, 201)
@@ -475,14 +470,14 @@ class WorkflowDetail(Resource):
try:
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
create_workflow_edges(workflow_id, edges_data, next_graph_version)
except Exception as err:
except Exception as e:
workflow_nodes_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
workflow_edges_collection.delete_many(
{"workflow_id": workflow_id, "graph_version": next_graph_version}
)
return _workflow_error_response("Failed to update workflow structure", err)
return error_response(f"Failed to update workflow structure: {str(e)}")
now = datetime.now(timezone.utc)
_, error = safe_db_operation(
@@ -540,7 +535,7 @@ class WorkflowDetail(Resource):
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
except Exception as err:
return _workflow_error_response("Failed to delete workflow", err)
except Exception as e:
return error_response(f"Failed to delete workflow: {str(e)}")
return success_response()

View File

@@ -1,3 +0,0 @@
from application.api.v1.routes import v1_bp
__all__ = ["v1_bp"]

View File

@@ -1,333 +0,0 @@
"""Standard chat completions API routes.
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
follow the widely-adopted chat completions protocol so external tools
(opencode, continue, etc.) can connect to DocsGPT agents.
"""
import json
import logging
import time
import traceback
from typing import Any, Dict, Generator, Optional
from flask import Blueprint, jsonify, make_response, request, Response
from application.api.answer.routes.base import BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
from application.api.v1.translator import (
translate_request,
translate_response,
translate_stream_event,
)
from application.core.mongo_db import MongoDB
from application.core.settings import settings
logger = logging.getLogger(__name__)
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
def _extract_bearer_token() -> Optional[str]:
    """Return the API key from an ``Authorization: Bearer`` header, or None."""
    header = request.headers.get("Authorization", "")
    prefix = "Bearer "
    if not header.startswith(prefix):
        return None
    return header[len(prefix):].strip()
def _lookup_agent(api_key: str) -> Optional[Dict]:
    """Fetch the agent document matching *api_key*, or None on any failure."""
    try:
        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
        return db["agents"].find_one({"key": api_key})
    except Exception:
        # Lookup failure is non-fatal: callers fall back to the raw key.
        logger.warning("Failed to look up agent for API key", exc_info=True)
        return None
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
"""Return agent name for display as model name."""
if agent:
return agent.get("name", api_key)
return api_key
class _V1AnswerHelper(BaseAnswerResource):
    """Thin wrapper to access complete_stream / process_response_stream.

    Instantiated per-request by the v1 route handlers so they can reuse
    BaseAnswerResource's streaming helpers outside a RESTX resource.
    """

    pass
@v1_bp.route("/chat/completions", methods=["POST"])
def chat_completions():
    """Handle POST /v1/chat/completions.

    Authenticates via the bearer API key, translates the request into
    DocsGPT's internal format, and dispatches to the streaming or the
    non-streaming response path. Requests whose messages end with
    ``role: "tool"`` results resume a previous turn (tool continuation)
    instead of starting a new one.
    """
    api_key = _extract_bearer_token()
    if not api_key:
        return make_response(
            jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
            401,
        )
    data = request.get_json()
    if not data or not data.get("messages"):
        return make_response(
            jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
            400,
        )
    is_stream = data.get("stream", False)
    agent_doc = _lookup_agent(api_key)
    model_name = _get_model_name(agent_doc, api_key)
    try:
        internal_data = translate_request(data, api_key)
    except Exception as e:
        logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
        return make_response(
            jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
            400,
        )
    # Link decoded_token to the agent's owner so continuation state,
    # logs, and tool execution use the correct user identity.
    agent_user = agent_doc.get("user") if agent_doc else None
    decoded_token = {"sub": agent_user or "api_key_user"}
    try:
        processor = StreamProcessor(internal_data, decoded_token)
        if internal_data.get("tool_actions"):
            # Continuation mode: client sent tool results for a paused turn.
            conversation_id = internal_data.get("conversation_id")
            if not conversation_id:
                return make_response(
                    jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
                    400,
                )
            (
                agent,
                messages,
                tools_dict,
                pending_tool_calls,
                tool_actions,
            ) = processor.resume_from_tool_actions(
                internal_data["tool_actions"], conversation_id
            )
            continuation = {
                "messages": messages,
                "tools_dict": tools_dict,
                "pending_tool_calls": pending_tool_calls,
                "tool_actions": tool_actions,
            }
            question = ""
        else:
            # Normal mode: build a fresh agent for the extracted question.
            question = internal_data.get("question", "")
            agent = processor.build_agent(question)
            continuation = None
        if not processor.decoded_token:
            return make_response(
                jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
                401,
            )
        helper = _V1AnswerHelper()
        usage_error = helper.check_usage(processor.agent_config)
        if usage_error:
            return usage_error
        should_save_conversation = bool(internal_data.get("save_conversation", False))
        if is_stream:
            return Response(
                _stream_response(
                    helper,
                    question,
                    agent,
                    processor,
                    model_name,
                    continuation,
                    should_save_conversation,
                ),
                mimetype="text/event-stream",
                headers={
                    # Disable proxy/server buffering so SSE chunks flush
                    # to the client immediately.
                    "Cache-Control": "no-cache",
                    "X-Accel-Buffering": "no",
                },
            )
        else:
            return _non_stream_response(
                helper,
                question,
                agent,
                processor,
                model_name,
                continuation,
                should_save_conversation,
            )
    except ValueError as e:
        logger.error(
            f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
            extra={"error": str(e)},
        )
        return make_response(
            jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
            400,
        )
    except Exception as e:
        logger.error(
            f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
            extra={"error": str(e)},
        )
        return make_response(
            jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
            500,
        )
def _stream_response(
    helper: _V1AnswerHelper,
    question: str,
    agent: Any,
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
    should_save_conversation: bool,
) -> Generator[str, None, None]:
    """Generate translated SSE chunks for streaming response.

    Consumes DocsGPT's internal SSE stream, rewrites each event into the
    standard chat-completions chunk format, and re-emits it.

    Args:
        helper: Resource wrapper exposing ``complete_stream``.
        question: User question ("" in tool-continuation mode).
        agent: Agent instance produced by the processor.
        processor: Request-scoped stream processor holding config/identity.
        model_name: Model identifier echoed in every chunk.
        continuation: Tool-continuation state, or None for a fresh turn.
        should_save_conversation: Whether the backend persists the turn.

    Yields:
        Standard SSE chunk strings.
    """
    completion_id = f"chatcmpl-{int(time.time())}"
    internal_stream = helper.complete_stream(
        question=question,
        agent=agent,
        conversation_id=processor.conversation_id,
        user_api_key=processor.agent_config.get("user_api_key"),
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )
    for line in internal_stream:
        if not line.strip():
            continue
        # Parse the internal SSE event. Strip only the leading "data: "
        # framing prefix — the previous str.replace() removed EVERY
        # occurrence of the substring, corrupting any JSON payload that
        # itself contained "data: " (e.g. inside answer text).
        event_str = line.strip()
        if event_str.startswith("data: "):
            event_str = event_str[len("data: "):]
        try:
            event_data = json.loads(event_str)
        except (json.JSONDecodeError, TypeError):
            # Non-JSON keep-alive/malformed lines are silently skipped.
            continue
        # Once the backend announces the conversation id, pin the
        # completion id to it so all subsequent chunks match.
        if event_data.get("type") == "id":
            conv_id = event_data.get("id", "")
            if conv_id:
                completion_id = f"chatcmpl-{conv_id}"
        # Translate to the standard chunk format (may emit several chunks).
        translated = translate_stream_event(event_data, completion_id, model_name)
        for chunk in translated:
            yield chunk
def _non_stream_response(
    helper: _V1AnswerHelper,
    question: str,
    agent: Any,
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
    should_save_conversation: bool,
) -> Response:
    """Collect full response and return as single JSON.

    Drains the internal stream completely via ``process_response_stream``,
    then translates the aggregate result into one standard
    chat-completion response. Returns a 500 when the stream reported an
    error.
    """
    stream = helper.complete_stream(
        question=question,
        agent=agent,
        conversation_id=processor.conversation_id,
        user_api_key=processor.agent_config.get("user_api_key"),
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )
    result = helper.process_response_stream(stream)
    if result["error"]:
        return make_response(
            jsonify({"error": {"message": result["error"], "type": "server_error"}}),
            500,
        )
    # Pending client-side tool calls (if any) pause the turn; they are
    # surfaced to the caller as standard tool_calls by the translator.
    extra = result.get("extra")
    pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
    response = translate_response(
        conversation_id=result["conversation_id"],
        answer=result["answer"] or "",
        sources=result["sources"],
        tool_calls=result["tool_calls"],
        thought=result["thought"] or "",
        model_name=model_name,
        pending_tool_calls=pending,
    )
    return make_response(jsonify(response), 200)
@v1_bp.route("/models", methods=["GET"])
def list_models():
    """Handle GET /v1/models — return agents as models.

    Authenticates the bearer API key against the agents collection, then
    lists every agent belonging to that agent's owner in the standard
    model-list format.
    """
    api_key = _extract_bearer_token()
    if not api_key:
        return make_response(
            jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
            401,
        )
    try:
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        agents_collection = db["agents"]
        # Find the agent for this api_key; an unknown key is a 401.
        agent = agents_collection.find_one({"key": api_key})
        if not agent:
            return make_response(
                jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
                401,
            )
        user = agent.get("user")
        # Return all agents belonging to this user
        user_agents = list(agents_collection.find({"user": user}))
        models = []
        for ag in user_agents:
            created = ag.get("createdAt")
            # Fall back to "now" when the agent has no creation timestamp.
            created_ts = int(created.timestamp()) if created else int(time.time())
            model_id = str(ag.get("_id") or ag.get("id") or "")
            models.append({
                "id": model_id,
                "object": "model",
                "created": created_ts,
                "owned_by": "docsgpt",
                "name": ag.get("name", ""),
                "description": ag.get("description", ""),
            })
        return make_response(
            jsonify({"object": "list", "data": models}),
            200,
        )
    except Exception as e:
        logger.error(f"/v1/models error: {e}", exc_info=True)
        return make_response(
            jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
            500,
        )

View File

@@ -1,433 +0,0 @@
"""Translate between standard chat completions format and DocsGPT internals.
This module handles:
- Request translation (chat completions -> DocsGPT internal format)
- Response translation (DocsGPT response -> chat completions format)
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
"""
import json
import time
from typing import Any, Dict, List, Optional
def _get_client_tool_name(tc: Dict) -> str:
"""Return the original tool name for client-facing responses.
For client-side tools the ``tool_name`` field carries the name the
client originally registered. Fall back to ``action_name`` (which
is now the clean LLM-visible name) or ``name``.
"""
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
# ---------------------------------------------------------------------------
# Request translation
# ---------------------------------------------------------------------------
def is_continuation(messages: List[Dict]) -> bool:
    """Check whether *messages* represent a tool-call continuation.

    True when the trailing ``role: "tool"`` messages directly follow an
    assistant message that carried ``tool_calls``.
    """
    if not messages:
        return False
    # Skip the trailing run of tool messages, then inspect what precedes it.
    idx = len(messages) - 1
    while idx >= 0 and messages[idx].get("role") == "tool":
        idx -= 1
    if idx < 0:
        return False
    preceding = messages[idx]
    return preceding.get("role") == "assistant" and bool(preceding.get("tool_calls"))
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
    """Extract trailing tool-message results for a continuation turn.

    Returns ``tool_actions`` dicts (``call_id`` + ``result``) preserving
    their original order. String contents are JSON-decoded when possible,
    otherwise passed through verbatim.
    """
    collected = []
    for msg in reversed(messages):
        if msg.get("role") != "tool":
            break
        payload = msg.get("content", "")
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except (json.JSONDecodeError, TypeError):
                pass
        collected.append({"call_id": msg.get("tool_call_id", ""), "result": payload})
    collected.reverse()
    return collected
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
    """Pull a conversation_id stored on the most recent assistant message.

    Previous response cycles may stash it under the ``docsgpt`` extension
    field; returns None when no assistant message exists or the field is
    absent.
    """
    last_assistant = next(
        (msg for msg in reversed(messages) if msg.get("role") == "assistant"),
        None,
    )
    if last_assistant is None:
        return None
    return last_assistant.get("docsgpt", {}).get("conversation_id")
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
    """Return the content of the first system message, or None if absent."""
    return next(
        (msg.get("content", "") for msg in messages if msg.get("role") == "system"),
        None,
    )
def convert_history(messages: List[Dict]) -> List[Dict]:
    """Convert a chat-completions messages array to DocsGPT history.

    History is a list of ``{prompt, response}`` dicts built from adjacent
    user/assistant pairs. System messages are ignored, and a trailing user
    message without an assistant reply is excluded (it becomes the
    question).
    """
    history: List[Dict] = []
    total = len(messages)
    cursor = 0
    while cursor < total:
        current = messages[cursor]
        is_pair = (
            current.get("role") == "user"
            and cursor + 1 < total
            and messages[cursor + 1].get("role") == "assistant"
        )
        if is_pair:
            # Null assistant content (e.g. pure tool-call turns) becomes "".
            reply = messages[cursor + 1].get("content") or ""
            history.append({"prompt": current.get("content", ""), "response": reply})
            cursor += 2
        else:
            cursor += 1
    return history
def translate_request(
    data: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
    """Translate a chat completions request to DocsGPT internal format.

    Args:
        data: The incoming request body.
        api_key: Agent API key from the Authorization header.

    Returns:
        Dict suitable for passing to ``StreamProcessor``. Continuation
        requests yield ``tool_actions`` + ``conversation_id``; normal
        requests yield ``question`` + serialized ``history``.
    """
    messages = data.get("messages", [])
    # Check for continuation (tool results after assistant tool_calls)
    if is_continuation(messages):
        tool_actions = extract_tool_results(messages)
        conversation_id = extract_conversation_id(messages)
        if not conversation_id:
            # Fall back to an explicit top-level conversation_id when the
            # assistant message carries no docsgpt extension field.
            conversation_id = data.get("conversation_id")
        result = {
            "conversation_id": conversation_id,
            "tool_actions": tool_actions,
            "api_key": api_key,
        }
        # Carry tools forward for next iteration
        if data.get("tools"):
            result["client_tools"] = data["tools"]
        return result
    # Normal request — extract question from last user message
    question = ""
    for msg in reversed(messages):
        if msg.get("role") == "user":
            question = msg.get("content", "")
            break
    history = convert_history(messages)
    system_prompt_override = extract_system_prompt(messages)
    docsgpt = data.get("docsgpt", {})
    result = {
        "question": question,
        "api_key": api_key,
        "history": json.dumps(history),
        # Conversations are NOT persisted by default on the v1 endpoint.
        # Callers opt in via ``docsgpt.save_conversation: true``.
        "save_conversation": bool(docsgpt.get("save_conversation", False)),
    }
    if system_prompt_override is not None:
        result["system_prompt_override"] = system_prompt_override
    # Client tools
    if data.get("tools"):
        result["client_tools"] = data["tools"]
    # DocsGPT extensions
    if docsgpt.get("attachments"):
        result["attachments"] = docsgpt["attachments"]
    return result
# ---------------------------------------------------------------------------
# Response translation (non-streaming)
# ---------------------------------------------------------------------------
def translate_response(
    conversation_id: str,
    answer: str,
    sources: Optional[List[Dict]],
    tool_calls: Optional[List[Dict]],
    thought: str,
    model_name: str,
    pending_tool_calls: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
    """Translate DocsGPT response to chat completions format.

    Args:
        conversation_id: The DocsGPT conversation ID.
        answer: The assistant's text response.
        sources: RAG retrieval sources.
        tool_calls: Completed tool call results.
        thought: Reasoning/thinking tokens.
        model_name: Model/agent identifier.
        pending_tool_calls: Pending client-side tool calls (if paused).

    Returns:
        Dict in the standard chat completions response format.
    """
    created = int(time.time())
    # Fall back to the timestamp when no conversation ID is available.
    id_suffix = conversation_id if conversation_id else created
    completion_id = f"chatcmpl-{id_suffix}"

    message: Dict[str, Any] = {"role": "assistant"}
    if pending_tool_calls:
        # The run paused on client-side tools: surface them as standard
        # tool_calls so the caller can execute and resume.
        finish_reason = "tool_calls"
        message["content"] = None
        calls: List[Dict[str, Any]] = []
        for tc in pending_tool_calls:
            raw_args = tc.get("arguments")
            calls.append(
                {
                    "id": tc.get("call_id", ""),
                    "type": "function",
                    "function": {
                        "name": _get_client_tool_name(tc),
                        "arguments": (
                            json.dumps(raw_args)
                            if isinstance(raw_args, dict)
                            else tc.get("arguments", "{}")
                        ),
                    },
                }
            )
        message["tool_calls"] = calls
    else:
        finish_reason = "stop"
        message["content"] = answer
        if thought:
            message["reasoning_content"] = thought

    result: Dict[str, Any] = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": message,
                "finish_reason": finish_reason,
            }
        ],
        # Token accounting is not tracked on this path.
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }

    # DocsGPT-specific extras travel under a vendor extension key, only
    # added when at least one field is present.
    extension: Dict[str, Any] = {}
    if conversation_id:
        extension["conversation_id"] = conversation_id
    if sources:
        extension["sources"] = sources
    if tool_calls:
        extension["tool_calls"] = tool_calls
    if extension:
        result["docsgpt"] = extension
    return result
# ---------------------------------------------------------------------------
# Streaming event translation
# ---------------------------------------------------------------------------
def _make_chunk(
completion_id: str,
model_name: str,
delta: Dict[str, Any],
finish_reason: Optional[str] = None,
) -> str:
"""Build a single SSE chunk in the standard streaming format."""
chunk = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(chunk)}\n\n"
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
"""Build a DocsGPT extension SSE chunk."""
return f"data: {json.dumps({'docsgpt': data})}\n\n"
def translate_stream_event(
    event_data: Dict[str, Any],
    completion_id: str,
    model_name: str,
) -> List[str]:
    """Translate a DocsGPT SSE event dict to standard streaming chunks.

    May return 0, 1, or 2 chunks per input event. For example, a completed
    tool call produces both a docsgpt extension chunk and nothing on the
    standard side (since server-side tool calls aren't surfaced in standard
    format).

    Args:
        event_data: Parsed DocsGPT event dict.
        completion_id: The completion ID for this response.
        model_name: Model/agent identifier.

    Returns:
        List of SSE-formatted strings to send to the client.
    """
    kind = event_data.get("type")

    if kind == "answer":
        delta = {"content": event_data.get("answer", "")}
        return [_make_chunk(completion_id, model_name, delta)]

    if kind == "thought":
        delta = {"reasoning_content": event_data.get("thought", "")}
        return [_make_chunk(completion_id, model_name, delta)]

    if kind == "source":
        return [
            _make_docsgpt_chunk(
                {"type": "source", "sources": event_data.get("source", [])}
            )
        ]

    if kind == "tool_call":
        payload = event_data.get("data", {})
        status = payload.get("status")
        if status == "requires_client_execution":
            # Client-side tool: emit as a standard tool_calls delta.
            raw_args = payload.get("arguments", {})
            serialized = (
                json.dumps(raw_args) if isinstance(raw_args, dict) else str(raw_args)
            )
            delta = {
                "tool_calls": [
                    {
                        "index": 0,
                        "id": payload.get("call_id", ""),
                        "type": "function",
                        "function": {
                            "name": _get_client_tool_name(payload),
                            "arguments": serialized,
                        },
                    }
                ],
            }
            return [_make_chunk(completion_id, model_name, delta)]
        if status in (
            "awaiting_approval",
            "completed",
            "pending",
            "error",
            "denied",
            "skipped",
        ):
            # Approval requests and progress updates go out as extensions.
            return [_make_docsgpt_chunk({"type": "tool_call", "data": payload})]
        return []

    if kind == "tool_calls_pending":
        pending = event_data.get("data", {}).get("pending_tool_calls", [])
        # Standard finish_reason plus the docsgpt extension payload.
        return [
            _make_chunk(completion_id, model_name, {}, finish_reason="tool_calls"),
            _make_docsgpt_chunk(
                {"type": "tool_calls_pending", "pending_tool_calls": pending}
            ),
        ]

    if kind == "end":
        return [
            _make_chunk(completion_id, model_name, {}, finish_reason="stop"),
            "data: [DONE]\n\n",
        ]

    if kind == "id":
        return [
            _make_docsgpt_chunk(
                {"type": "id", "conversation_id": event_data.get("id", "")}
            )
        ]

    if kind == "error":
        # Emit as standard error (non-standard but widely supported)
        err = {
            "error": {
                "message": event_data.get("error", "An error occurred"),
                "type": "server_error",
            }
        }
        return [f"data: {json.dumps(err)}\n\n"]

    if kind == "structured_answer":
        delta = {"content": event_data.get("answer", "")}
        return [_make_chunk(completion_id, model_name, delta)]

    # Skip: tool_calls (redundant), research_plan, research_progress
    return []

View File

@@ -17,13 +17,8 @@ from application.api.answer import answer # noqa: E402
from application.api.internal.routes import internal # noqa: E402
from application.api.user.routes import user # noqa: E402
from application.api.connector.routes import connector # noqa: E402
from application.api.v1 import v1_bp # noqa: E402
from application.celery_init import celery # noqa: E402
from application.core.settings import settings # noqa: E402
from application.stt.upload_limits import ( # noqa: E402
build_stt_file_size_limit_message,
should_reject_stt_request,
)
if platform.system() == "Windows":
@@ -37,7 +32,6 @@ app.register_blueprint(user)
app.register_blueprint(answer)
app.register_blueprint(internal)
app.register_blueprint(connector)
app.register_blueprint(v1_bp)
app.config.update(
UPLOAD_FOLDER="inputs",
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
@@ -74,11 +68,6 @@ def home():
return "Welcome to DocsGPT Backend!"
@app.route("/api/health")
def health():
return jsonify({"status": "ok"})
@app.route("/api/config")
def get_config():
response = {
@@ -99,23 +88,6 @@ def generate_token():
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
@app.before_request
def enforce_stt_request_size_limits():
if request.method == "OPTIONS":
return None
if should_reject_stt_request(request.path, request.content_length):
return (
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
return None
@app.before_request
def authenticate_request():
if request.method == "OPTIONS":

View File

@@ -27,8 +27,6 @@ ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS
OPENAI_MODELS = [
AvailableModel(
@@ -195,46 +193,6 @@ OPENROUTER_MODELS = [
),
]
NOVITA_MODELS = [
AvailableModel(
id="moonshotai/kimi-k2.5",
provider=ModelProvider.NOVITA,
display_name="Kimi K2.5",
description="MoE model with function calling, structured output, reasoning, and vision",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=NOVITA_ATTACHMENTS,
context_window=262144,
),
),
AvailableModel(
id="zai-org/glm-5",
provider=ModelProvider.NOVITA,
display_name="GLM-5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=202800,
),
),
AvailableModel(
id="minimax/minimax-m2.5",
provider=ModelProvider.NOVITA,
display_name="MiniMax M2.5",
description="MoE model with function calling, structured output, and reasoning",
capabilities=ModelCapabilities(
supports_tools=True,
supports_structured_output=True,
supported_attachment_types=[],
context_window=204800,
),
),
]
AZURE_OPENAI_MODELS = [
AvailableModel(
id="azure-gpt-4",

View File

@@ -114,10 +114,6 @@ class ModelRegistry:
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
):
self._add_openrouter_models(settings)
if settings.NOVITA_API_KEY or (
settings.LLM_PROVIDER == "novita" and settings.API_KEY
):
self._add_novita_models(settings)
if settings.HUGGINGFACE_API_KEY or (
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
):
@@ -249,21 +245,6 @@ class ModelRegistry:
for model in OPENROUTER_MODELS:
self.models[model.id] = model
def _add_novita_models(self, settings):
from application.core.model_configs import NOVITA_MODELS
if settings.NOVITA_API_KEY:
for model in NOVITA_MODELS:
self.models[model.id] = model
return
if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
for model in NOVITA_MODELS:
if model.id == settings.LLM_NAME:
self.models[model.id] = model
return
for model in NOVITA_MODELS:
self.models[model.id] = model
def _add_docsgpt_models(self, settings):
model_id = "docsgpt-local"
model = AvailableModel(

View File

@@ -10,7 +10,6 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:
provider_key_map = {
"openai": settings.OPENAI_API_KEY,
"openrouter": settings.OPEN_ROUTER_API_KEY,
"novita": settings.NOVITA_API_KEY,
"anthropic": settings.ANTHROPIC_API_KEY,
"google": settings.GOOGLE_API_KEY,
"groq": settings.GROQ_API_KEY,

View File

@@ -5,7 +5,9 @@ from typing import Optional
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
class Settings(BaseSettings):
@@ -13,11 +15,15 @@ class Settings(BaseSettings):
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
LLM_PROVIDER: str = "docsgpt"
LLM_NAME: Optional[str] = None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
LLM_NAME: Optional[str] = (
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
)
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
EMBEDDINGS_KEY: Optional[str] = (
None # api key for embeddings (if using openai, just copy API_KEY)
)
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
@@ -39,7 +45,9 @@ class Settings(BaseSettings):
PARSE_IMAGE_REMOTE: bool = False
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
VECTOR_STORE: str = (
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
)
RETRIEVERS_ENABLED: list = ["classic_rag"]
AGENT_NAME: str = "classic"
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
@@ -47,18 +55,16 @@ class Settings(BaseSettings):
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
# Google Drive integration
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
GOOGLE_CLIENT_SECRET: Optional[str] = None # Replace with your actual Google OAuth client secret
GOOGLE_CLIENT_ID: Optional[str] = (
None # Replace with your actual Google OAuth client ID
)
GOOGLE_CLIENT_SECRET: Optional[str] = (
None # Replace with your actual Google OAuth client secret
)
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
)
# Microsoft Entra ID (Azure AD) integration
MICROSOFT_CLIENT_ID: Optional[str] = None # Azure AD Application (client) ID
MICROSOFT_CLIENT_SECRET: Optional[str] = None # Azure AD Application client secret
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
# GitHub source
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
@@ -66,7 +72,6 @@ class Settings(BaseSettings):
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
API_URL: str = "http://localhost:7091" # backend url for celery worker
MCP_OAUTH_REDIRECT_URI: Optional[str] = None # public callback URL for MCP OAuth
INTERNAL_KEY: Optional[str] = None # internal api key for worker-to-backend auth
API_KEY: Optional[str] = None # LLM api key (used by LLM_PROVIDER)
@@ -78,13 +83,16 @@ class Settings(BaseSettings):
GROQ_API_KEY: Optional[str] = None
HUGGINGFACE_API_KEY: Optional[str] = None
OPEN_ROUTER_API_KEY: Optional[str] = None
NOVITA_API_KEY: Optional[str] = None
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
None # azure deployment name for embeddings
)
OPENAI_BASE_URL: Optional[str] = (
None # openai base url for open ai compatable models
)
# elasticsearch
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
@@ -126,7 +134,9 @@ class Settings(BaseSettings):
# LanceDB vectorstore config
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
LANCEDB_TABLE_NAME: Optional[str] = (
"docsgpts" # Name of the table to use for storing vectors
)
FLASK_DEBUG_MODE: bool = False
STORAGE_TYPE: str = "local" # local or s3
@@ -139,12 +149,6 @@ class Settings(BaseSettings):
TTS_PROVIDER: str = "google_tts" # google_tts or elevenlabs
ELEVENLABS_API_KEY: Optional[str] = None
STT_PROVIDER: str = "openai" # openai or faster_whisper
OPENAI_STT_MODEL: str = "gpt-4o-mini-transcribe"
STT_LANGUAGE: Optional[str] = None
STT_MAX_FILE_SIZE_MB: int = 50
STT_ENABLE_TIMESTAMPS: bool = False
STT_ENABLE_DIARIZATION: bool = False
# Tool pre-fetch settings
ENABLE_TOOL_PREFETCH: bool = True
@@ -163,7 +167,6 @@ class Settings(BaseSettings):
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"HUGGINGFACE_API_KEY",
"NOVITA_API_KEY",
"EMBEDDINGS_KEY",
"FALLBACK_LLM_API_KEY",
"QDRANT_API_KEY",

View File

@@ -13,25 +13,3 @@ def response_error(code_status, message=None):
def bad_request(status_code=400, message=''):
return response_error(code_status=status_code, message=message)
def sanitize_api_error(error) -> str:
"""
Convert technical API errors to user-friendly messages.
Works with both Exception objects and error message strings.
"""
error_str = str(error).lower()
if "503" in error_str or "unavailable" in error_str or "high demand" in error_str:
return "The AI service is temporarily unavailable due to high demand. Please try again in a moment."
if "429" in error_str or "rate limit" in error_str or "quota" in error_str:
return "Rate limit exceeded. Please wait a moment before trying again."
if "401" in error_str or "unauthorized" in error_str or "invalid api key" in error_str:
return "Authentication error. Please check your API configuration."
if "timeout" in error_str or "timed out" in error_str:
return "The request timed out. Please try again."
if "connection" in error_str or "network" in error_str:
return "Network error. Please check your connection and try again."
original = str(error)
if len(original) > 200 or "{" in original or "traceback" in error_str:
return "An error occurred while processing your request. Please try again later."
return original

View File

@@ -16,61 +16,22 @@ class BaseLLM(ABC):
agent_id=None,
model_id=None,
base_url=None,
backup_models=None,
):
self.decoded_token = decoded_token
self.agent_id = str(agent_id) if agent_id else None
self.model_id = model_id
self.base_url = base_url
self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0}
self._backup_models = backup_models or []
self._fallback_llm = None
self._fallback_sequence_index = 0
@property
def fallback_llm(self):
"""Lazy-loaded fallback LLM: tries per-agent backup models first,
then the global FALLBACK_* settings."""
if self._fallback_llm is not None:
return self._fallback_llm
from application.llm.llm_creator import LLMCreator
from application.core.model_utils import (
get_provider_from_model_id,
get_api_key_for_provider,
)
# Try per-agent backup models first
for backup_model_id in self._backup_models:
"""Lazy-loaded fallback LLM from FALLBACK_* settings."""
if self._fallback_llm is None and settings.FALLBACK_LLM_PROVIDER:
try:
provider = get_provider_from_model_id(backup_model_id)
if not provider:
logger.warning(
f"Could not resolve provider for backup model: {backup_model_id}"
)
continue
api_key = get_api_key_for_provider(provider)
self._fallback_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=getattr(self, "user_api_key", None),
decoded_token=self.decoded_token,
model_id=backup_model_id,
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized from agent backup model: "
f"{provider}/{backup_model_id}"
)
return self._fallback_llm
except Exception as e:
logger.warning(
f"Failed to initialize backup model {backup_model_id}: {str(e)}"
)
continue
from application.llm.llm_creator import LLMCreator
# Fall back to global FALLBACK_* settings
if settings.FALLBACK_LLM_PROVIDER:
try:
self._fallback_llm = LLMCreator.create_llm(
settings.FALLBACK_LLM_PROVIDER,
api_key=settings.FALLBACK_LLM_API_KEY or settings.API_KEY,
@@ -80,14 +41,12 @@ class BaseLLM(ABC):
agent_id=self.agent_id,
)
logger.info(
f"Fallback LLM initialized from global settings: "
f"{settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
f"Fallback LLM initialized: {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}"
)
except Exception as e:
logger.error(
f"Failed to initialize fallback LLM: {str(e)}", exc_info=True
)
return self._fallback_llm
@staticmethod
@@ -115,60 +74,20 @@ class BaseLLM(ABC):
method = decorator(method)
return method(self, *args, **kwargs)
is_stream = "stream" in method_name
if is_stream:
return self._stream_with_fallback(
decorated_method, method_name, *args, **kwargs
)
try:
return decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(f"Primary LLM failed and no fallback configured: {str(e)}")
raise
fallback = self.fallback_llm
logger.warning(
f"Primary LLM failed. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
f"Primary LLM failed. Falling back to {settings.FALLBACK_LLM_PROVIDER}/{settings.FALLBACK_LLM_NAME}. Error: {str(e)}"
)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
self.fallback_llm, method_name.replace("_raw_", "")
)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
return fallback_method(*args, **fallback_kwargs)
def _stream_with_fallback(
self, decorated_method, method_name, *args, **kwargs
):
"""
Wrapper generator that catches mid-stream errors and falls back.
Unlike non-streaming calls where exceptions are raised immediately,
streaming generators raise exceptions during iteration. This wrapper
ensures that if the primary LLM fails at any point during streaming
(creation or mid-stream), we fall back to the backup model.
"""
try:
yield from decorated_method()
except Exception as e:
if not self.fallback_llm:
logger.error(
f"Primary LLM failed and no fallback configured: {str(e)}"
)
raise
fallback = self.fallback_llm
logger.warning(
f"Primary LLM failed mid-stream. Falling back to "
f"{fallback.model_id}. Error: {str(e)}"
)
fallback_method = getattr(
fallback, method_name.replace("_raw_", "")
)
fallback_kwargs = {**kwargs, "model": fallback.model_id}
yield from fallback_method(*args, **fallback_kwargs)
return fallback_method(*args, **kwargs)
def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
decorators = [gen_token_usage, gen_cache]

View File

@@ -158,17 +158,11 @@ class GoogleLLM(BaseLLM):
if isinstance(content, list):
parts = []
for item in content:
if (
isinstance(item, dict)
and "text" in item
and item["text"] is not None
):
if isinstance(item, dict) and "text" in item and item["text"] is not None:
parts.append(item["text"])
return "\n".join(parts)
return ""
import json as _json
for message in messages:
role = message.get("role")
content = message.get("content")
@@ -182,66 +176,9 @@ class GoogleLLM(BaseLLM):
if role == "assistant":
role = "model"
parts = []
# Standard format: assistant message with tool_calls array
msg_tool_calls = message.get("tool_calls")
if msg_tool_calls and role == "model":
for tc in msg_tool_calls:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, str):
try:
args = _json.loads(args)
except (_json.JSONDecodeError, TypeError):
args = {}
cleaned_args = self._remove_null_values(args)
thought_sig = tc.get("thought_signature")
if thought_sig:
parts.append(
types.Part(
functionCall=types.FunctionCall(
name=func.get("name", ""),
args=cleaned_args,
),
thoughtSignature=thought_sig,
)
)
else:
parts.append(
types.Part.from_function_call(
name=func.get("name", ""),
args=cleaned_args,
)
)
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
continue
# Standard format: tool message with tool_call_id
tool_call_id = message.get("tool_call_id")
if role == "tool" and tool_call_id is not None:
result_content = content
if isinstance(result_content, str):
try:
result_content = _json.loads(result_content)
except (_json.JSONDecodeError, TypeError):
pass
# Google expects function_response name — extract from tool_call_id context
# We use a placeholder name since Google API doesn't require exact match
parts.append(
types.Part.from_function_response(
name="tool_result",
response={"result": result_content},
)
)
cleaned_messages.append(types.Content(role="model", parts=parts))
continue
if role == "tool":
elif role == "tool":
role = "model"
parts = []
if role and content is not None:
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
@@ -250,11 +187,15 @@ class GoogleLLM(BaseLLM):
if "text" in item:
parts.append(types.Part.from_text(text=item["text"]))
elif "function_call" in item:
# Legacy format support
# Remove null values from args to avoid API errors
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
# Create function call part with thought_signature if present
# For Gemini 3 models, we need to include thought_signature
if "thought_signature" in item:
# Use Part constructor with functionCall and thoughtSignature
parts.append(
types.Part(
functionCall=types.FunctionCall(
@@ -265,6 +206,7 @@ class GoogleLLM(BaseLLM):
)
)
else:
# Use helper method when no thought_signature
parts.append(
types.Part.from_function_call(
name=item["function_call"]["name"],
@@ -294,9 +236,7 @@ class GoogleLLM(BaseLLM):
raise ValueError(f"Unexpected content type: {type(content)}")
if parts:
cleaned_messages.append(types.Content(role=role, parts=parts))
system_instruction = (
"\n\n".join(system_instructions) if system_instructions else None
)
system_instruction = "\n\n".join(system_instructions) if system_instructions else None
return cleaned_messages, system_instruction
def _clean_schema(self, schema_obj):
@@ -396,10 +336,7 @@ class GoogleLLM(BaseLLM):
return f"function_call:{name}"
function_response = getattr(part, "function_response", None)
if function_response:
name = (
getattr(function_response, "name", "")
or "function_response"
)
name = getattr(function_response, "name", "") or "function_response"
return f"function_response:{name}"
if isinstance(message, dict):
content = message.get("content")
@@ -570,9 +507,6 @@ class GoogleLLM(BaseLLM):
yield {"type": "thought", "thought": chunk_text}
else:
yield chunk_text
except Exception as e:
logging.error(f"GoogleLLM: Stream error: {e}", exc_info=True)
raise
finally:
if hasattr(response, "close"):
response.close()

View File

@@ -1,4 +1,3 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
@@ -316,34 +315,10 @@ class LLMHandler(ABC):
current_prompt = self._extract_text_from_content(content)
elif role in {"assistant", "model"}:
# Standard format: tool_calls array on assistant message
msg_tool_calls = message.get("tool_calls")
if msg_tool_calls:
for tc in msg_tool_calls:
call_id = tc.get("id") or str(uuid.uuid4())
func = tc.get("function", {})
args = func.get("arguments")
if isinstance(args, str):
try:
args = json.loads(args)
except (json.JSONDecodeError, TypeError):
pass
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": func.get("name"),
"arguments": args,
"result": None,
"status": "called",
"call_id": call_id,
}
continue
# Legacy format: function_call/function_response in content list
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
if isinstance(content, list):
has_fc = False
for item in content:
if "function_call" in item:
has_fc = True
fc = item["function_call"]
call_id = fc.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
@@ -354,30 +329,37 @@ class LLMHandler(ABC):
"status": "called",
"call_id": call_id,
}
if has_fc:
continue
elif "function_response" in item:
fr = item["function_response"]
call_id = fr.get("call_id") or str(uuid.uuid4())
current_tool_calls[call_id] = {
"tool_name": "unknown_tool",
"action_name": fr.get("name"),
"arguments": None,
"result": fr.get("response", {}).get("result"),
"status": "completed",
"call_id": call_id,
}
# No direct assistant text here; continue to next message
continue
response_text = self._extract_text_from_content(content)
_commit_query(response_text)
elif role == "tool":
# Standard format: tool_call_id on tool message
call_id = message.get("tool_call_id")
# Attach tool outputs to the latest pending tool call if possible
tool_text = self._extract_text_from_content(content)
# Attempt to parse function_response style
call_id = None
if isinstance(content, list):
for item in content:
if "function_response" in item and item["function_response"].get("call_id"):
call_id = item["function_response"]["call_id"]
break
if call_id and call_id in current_tool_calls:
current_tool_calls[call_id]["result"] = tool_text
current_tool_calls[call_id]["status"] = "completed"
# Legacy: function_response in content list
elif isinstance(content, list):
for item in content:
if "function_response" in item:
legacy_id = item["function_response"].get("call_id")
if legacy_id and legacy_id in current_tool_calls:
current_tool_calls[legacy_id]["result"] = tool_text
current_tool_calls[legacy_id]["status"] = "completed"
break
elif call_id is None and queries:
elif queries:
queries[-1].setdefault("tool_calls", []).append(
{
"tool_name": "unknown_tool",
@@ -666,13 +648,6 @@ class LLMHandler(ABC):
"""
Execute tool calls and update conversation history.
When a tool requires approval or client-side execution, it is
collected as a pending action instead of being executed. The
generator returns ``(updated_messages, pending_actions)`` where
*pending_actions* is ``None`` when every tool was executed
normally, or a list of dicts describing actions the client must
resolve before the LLM loop can continue.
Args:
agent: The agent instance
tool_calls: List of tool calls to execute
@@ -680,11 +655,9 @@ class LLMHandler(ABC):
messages: Current conversation history
Returns:
Tuple of (updated_messages, pending_actions).
pending_actions is None if all tools executed, otherwise a list.
Updated messages list
"""
updated_messages = messages.copy()
pending_actions: List[Dict] = []
for i, call in enumerate(tool_calls):
# Check context limit before executing tool call
@@ -790,29 +763,6 @@ class LLMHandler(ABC):
# Set flag on agent
agent.context_limit_reached = True
break
# ---- Pause check: approval / client-side execution ----
llm_class = agent.llm.__class__.__name__
pause_info = agent.tool_executor.check_pause(
tools_dict, call, llm_class
)
if pause_info:
# Yield pause event so the client knows this tool is waiting
yield {
"type": "tool_call",
"data": {
"tool_name": pause_info["tool_name"],
"call_id": pause_info["call_id"],
"action_name": pause_info.get("llm_name", pause_info["name"]),
"arguments": pause_info["arguments"],
"status": pause_info["pause_type"],
},
}
pending_actions.append(pause_info)
# Do NOT add messages for pending tools here.
# They will be added on resume to keep call/result pairs together.
continue
try:
self.tool_calls.append(call)
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
@@ -822,30 +772,25 @@ class LLMHandler(ABC):
except StopIteration as e:
tool_response, call_id = e.value
break
# Standard internal format: assistant message with tool_calls array
args_str = (
json.dumps(call.arguments)
if isinstance(call.arguments, dict)
else call.arguments
)
tool_call_obj = {
"id": call_id,
"type": "function",
"function": {
function_call_content = {
"function_call": {
"name": call.name,
"arguments": args_str,
},
"args": call.arguments,
"call_id": call_id,
}
}
# Preserve thought_signature for Google Gemini 3 models
# Include thought_signature for Google Gemini 3 models
# It should be at the same level as function_call, not inside it
if call.thought_signature:
tool_call_obj["thought_signature"] = call.thought_signature
function_call_content["thought_signature"] = call.thought_signature
updated_messages.append(
{
"role": "assistant",
"content": [function_call_content],
}
)
updated_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [tool_call_obj],
})
updated_messages.append(self.create_tool_message(call, tool_response))
except Exception as e:
@@ -857,15 +802,16 @@ class LLMHandler(ABC):
error_message = self.create_tool_message(error_call, error_response)
updated_messages.append(error_message)
mapping = agent.tool_executor._name_to_tool
if call.name in mapping:
resolved_tool_id, _ = mapping[call.name]
tool_name = tools_dict.get(resolved_tool_id, {}).get(
"name", "unknown_tool"
)
call_parts = call.name.split("_")
if len(call_parts) >= 2:
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
action_name = "_".join(call_parts[:-1])
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
full_action_name = f"{action_name}_{tool_id}"
else:
tool_name = "unknown_tool"
full_action_name = call.name
action_name = call.name
full_action_name = call.name
yield {
"type": "tool_call",
"data": {
@@ -877,7 +823,7 @@ class LLMHandler(ABC):
"status": "error",
},
}
return updated_messages, pending_actions if pending_actions else None
return updated_messages
def handle_non_streaming(
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
@@ -905,22 +851,8 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages, pending_actions = e.value
messages = e.value
break
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return ""
response = agent.llm.gen(
model=agent.model_id, messages=messages, tools=agent.tools
)
@@ -981,23 +913,10 @@ class LLMHandler(ABC):
try:
yield next(tool_handler_gen)
except StopIteration as e:
messages, pending_actions = e.value
messages = e.value
break
tool_calls = {}
# If tools need approval or client execution, pause the loop
if pending_actions:
agent._pending_continuation = {
"messages": messages,
"pending_tool_calls": pending_actions,
"tools_dict": tools_dict,
}
yield {
"type": "tool_calls_pending",
"data": {"pending_tool_calls": pending_actions},
}
return
# Check if context limit was reached during tool execution
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
# Add system message warning about context limit

View File

@@ -67,18 +67,18 @@ class GoogleLLMHandler(LLMHandler):
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create a tool result message in the standard internal format."""
import json as _json
"""Create Google-style tool message."""
content = (
_json.dumps(result)
if not isinstance(result, str)
else result
)
return {
"role": "tool",
"tool_call_id": tool_call.id,
"content": content,
"role": "model",
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
}
}
],
}
def _iterate_stream(self, response: Any) -> Generator:

View File

@@ -7,7 +7,6 @@ class LLMHandlerCreator:
handlers = {
"openai": OpenAILLMHandler,
"google": GoogleLLMHandler,
"novita": OpenAILLMHandler, # Novita uses OpenAI-compatible API
"default": OpenAILLMHandler,
}

View File

@@ -37,18 +37,18 @@ class OpenAILLMHandler(LLMHandler):
)
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
"""Create a tool result message in the standard internal format."""
import json as _json
content = (
_json.dumps(result)
if not isinstance(result, str)
else result
)
"""Create OpenAI-style tool message."""
return {
"role": "tool",
"tool_call_id": tool_call.id,
"content": content,
"content": [
{
"function_response": {
"name": tool_call.name,
"response": {"result": result},
"call_id": tool_call.id,
}
}
],
}
def _iterate_stream(self, response: Any) -> Generator:

View File

@@ -38,7 +38,6 @@ class LLMCreator:
decoded_token,
model_id=None,
agent_id=None,
backup_models=None,
*args,
**kwargs,
):
@@ -60,7 +59,6 @@ class LLMCreator:
model_id=model_id,
agent_id=agent_id,
base_url=base_url,
backup_models=backup_models,
*args,
**kwargs,
)

View File

@@ -1,13 +1,13 @@
from application.core.settings import settings
from application.llm.openai import OpenAILLM
NOVITA_BASE_URL = "https://api.novita.ai/openai"
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
class NovitaLLM(OpenAILLM):
def __init__(self, api_key=None, user_api_key=None, base_url=None, *args, **kwargs):
super().__init__(
api_key=api_key or settings.NOVITA_API_KEY or settings.API_KEY,
api_key=api_key or settings.API_KEY,
user_api_key=user_api_key,
base_url=base_url or NOVITA_BASE_URL,
*args,

View File

@@ -91,59 +91,19 @@ class OpenAILLM(BaseLLM):
if role == "model":
role = "assistant"
# Standard format: assistant message with tool_calls (passthrough)
tool_calls = message.get("tool_calls")
if tool_calls and role == "assistant":
cleaned_tcs = []
for tc in tool_calls:
func = tc.get("function", {})
args = func.get("arguments", "{}")
if isinstance(args, dict):
args = json.dumps(self._remove_null_values(args))
elif isinstance(args, str):
try:
parsed = json.loads(args)
args = json.dumps(self._remove_null_values(parsed))
except (json.JSONDecodeError, TypeError):
pass
cleaned_tcs.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": func.get("name", ""), "arguments": args},
})
cleaned_messages.append({
"role": "assistant",
"content": None,
"tool_calls": cleaned_tcs,
})
continue
# Standard format: tool message with tool_call_id (passthrough)
tool_call_id = message.get("tool_call_id")
if role == "tool" and tool_call_id is not None:
cleaned_messages.append({
"role": "tool",
"tool_call_id": tool_call_id,
"content": content if isinstance(content, str) else json.dumps(content),
})
continue
if role and content is not None:
if isinstance(content, str):
cleaned_messages.append({"role": role, "content": content})
elif isinstance(content, list):
# Collect all content parts into a single message
content_parts = []
for item in content:
# Legacy format support: function_call / function_response
if "function_call" in item:
args = item["function_call"]["args"]
if isinstance(args, str):
try:
args = json.loads(args)
except (json.JSONDecodeError, TypeError):
pass
cleaned_args = self._remove_null_values(args)
# Function calls need their own message
cleaned_args = self._remove_null_values(
item["function_call"]["args"]
)
tool_call = {
"id": item["function_call"]["call_id"],
"type": "function",
@@ -152,20 +112,28 @@ class OpenAILLM(BaseLLM):
"arguments": json.dumps(cleaned_args),
},
}
cleaned_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
})
cleaned_messages.append(
{
"role": "assistant",
"content": None,
"tool_calls": [tool_call],
}
)
elif "function_response" in item:
cleaned_messages.append({
"role": "tool",
"tool_call_id": item["function_response"]["call_id"],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
})
# Function responses need their own message
cleaned_messages.append(
{
"role": "tool",
"tool_call_id": item["function_response"][
"call_id"
],
"content": json.dumps(
item["function_response"]["response"]["result"]
),
}
)
elif isinstance(item, dict):
# Collect content parts (text, images, files) into a single message
if "type" in item and item["type"] == "text" and "text" in item:
content_parts.append(item)
elif "type" in item and item["type"] == "file" and "file" in item:
@@ -173,7 +141,10 @@ class OpenAILLM(BaseLLM):
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
content_parts.append(item)
elif "text" in item and "type" not in item:
# Legacy format: {"text": "..."} without type
content_parts.append({"type": "text", "text": item["text"]})
# Add the collected content parts as a single message
if content_parts:
cleaned_messages.append({"role": role, "content": content_parts})
else:

View File

@@ -62,26 +62,15 @@ class BaseConnectorAuth(ABC):
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
"""
Check if a token is expired.
Args:
token_info: Token information dictionary
Returns:
True if token is expired, False otherwise
"""
pass
def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
"""Extract the fields safe to persist in the session store.
"""
return {
"access_token": token_info.get("access_token"),
"refresh_token": token_info.get("refresh_token"),
"token_uri": token_info.get("token_uri"),
"expiry": token_info.get("expiry"),
**extra_fields,
}
class BaseConnectorLoader(ABC):
"""

View File

@@ -1,7 +1,5 @@
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.connectors.share_point.loader import SharePointLoader
class ConnectorCreator:
@@ -14,12 +12,10 @@ class ConnectorCreator:
connectors = {
"google_drive": GoogleDriveLoader,
"share_point": SharePointLoader,
}
auth_providers = {
"google_drive": GoogleDriveAuth,
"share_point": SharePointAuth,
}
@classmethod

View File

@@ -232,6 +232,10 @@ class GoogleDriveAuth(BaseConnectorAuth):
if missing_fields:
raise ValueError(f"Missing required token fields: {missing_fields}")
if 'client_id' not in token_info:
token_info['client_id'] = settings.GOOGLE_CLIENT_ID
if 'client_secret' not in token_info:
token_info['client_secret'] = settings.GOOGLE_CLIENT_SECRET
if 'token_uri' not in token_info:
token_info['token_uri'] = 'https://oauth2.googleapis.com/token'

View File

@@ -327,10 +327,15 @@ class GoogleDriveLoader(BaseConnectorLoader):
content_bytes = file_io.getvalue()
try:
return content_bytes.decode('utf-8')
content = content_bytes.decode('utf-8')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
try:
content = content_bytes.decode('latin-1')
except UnicodeDecodeError:
logging.error(f"Could not decode file {file_id} as text")
return None
return content
except HttpError as e:
logging.error(f"HTTP error downloading file {file_id}: {e.resp.status} - {e.content}")

View File

@@ -1,10 +0,0 @@
"""
Share Point connector package for DocsGPT.
This module provides authentication and document loading capabilities for Share Point.
"""
from .auth import SharePointAuth
from .loader import SharePointLoader
__all__ = ['SharePointAuth', 'SharePointLoader']

View File

@@ -1,152 +0,0 @@
import datetime
import logging
from typing import Optional, Dict, Any
from msal import ConfidentialClientApplication
from application.core.settings import settings
from application.parser.connectors.base import BaseConnectorAuth
logger = logging.getLogger(__name__)
class SharePointAuth(BaseConnectorAuth):
    """
    Handles Microsoft OAuth 2.0 authentication for SharePoint/OneDrive.
    Note: Files.Read scope allows access to files the user has granted access to,
    similar to Google Drive's drive.file scope.
    """

    # Delegated Microsoft Graph permission scopes requested during the OAuth flow.
    SCOPES = [
        "Files.Read",
        "Sites.Read.All",
        "User.Read",
    ]

    def __init__(self):
        """Build the MSAL confidential client from settings.

        Raises:
            ValueError: if MICROSOFT_CLIENT_ID or MICROSOFT_CLIENT_SECRET is unset.
        """
        self.client_id = settings.MICROSOFT_CLIENT_ID
        self.client_secret = settings.MICROSOFT_CLIENT_SECRET
        if not self.client_id:
            raise ValueError(
                "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_ID in settings."
            )
        if not self.client_secret:
            raise ValueError(
                "Microsoft OAuth credentials not configured. Please set MICROSOFT_CLIENT_SECRET in settings."
            )
        self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
        self.tenant_id = settings.MICROSOFT_TENANT_ID
        # An explicit MICROSOFT_AUTHORITY setting overrides the tenant-derived default URL.
        self.authority = getattr(settings, "MICROSOFT_AUTHORITY", f"https://login.microsoftonline.com/{self.tenant_id}")
        self.auth_app = ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_secret,
            authority=self.authority
        )

    def get_authorization_url(self, state: Optional[str] = None) -> str:
        """Return the Microsoft login URL the user should be redirected to."""
        return self.auth_app.get_authorization_request_url(
            scopes=self.SCOPES, state=state, redirect_uri=self.redirect_uri
        )

    def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
        """Exchange an OAuth authorization code for tokens.

        Raises:
            ValueError: if MSAL reports an error in the token response.
        """
        result = self.auth_app.acquire_token_by_authorization_code(
            code=authorization_code,
            scopes=self.SCOPES,
            redirect_uri=self.redirect_uri
        )
        if "error" in result:
            logger.error("Token exchange failed: %s", result.get("error_description"))
            raise ValueError(f"Error acquiring token: {result.get('error_description')}")
        return self.map_token_response(result)

    def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
        """Obtain a fresh access token from a refresh token.

        Raises:
            ValueError: if MSAL reports an error in the refresh response.
        """
        result = self.auth_app.acquire_token_by_refresh_token(refresh_token=refresh_token, scopes=self.SCOPES)
        if "error" in result:
            logger.error("Token refresh failed: %s", result.get("error_description"))
            raise ValueError(f"Error refreshing token: {result.get('error_description')}")
        return self.map_token_response(result)

    def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
        """Look up and validate the stored token info for a connector session.

        Raises:
            ValueError: if the session is unknown, has no token info, or the
                stored token is missing access/refresh tokens. Any underlying
                error is also wrapped in a ValueError.
        """
        try:
            # Imported locally to keep Mongo out of this module's import-time dependencies.
            from application.core.mongo_db import MongoDB
            from application.core.settings import settings
            mongo = MongoDB.get_client()
            db = mongo[settings.MONGO_DB_NAME]
            sessions_collection = db["connector_sessions"]
            session = sessions_collection.find_one({"session_token": session_token})
            if not session:
                raise ValueError(f"Invalid session token: {session_token}")
            if "token_info" not in session:
                raise ValueError("Session missing token information")
            token_info = session["token_info"]
            if not token_info:
                raise ValueError("Invalid token information")
            required_fields = ["access_token", "refresh_token"]
            missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
            if missing_fields:
                raise ValueError(f"Missing required token fields: {missing_fields}")
            if 'token_uri' not in token_info:
                # Default to the tenant's v2.0 token endpoint when none was stored.
                token_info['token_uri'] = f"https://login.microsoftonline.com/{settings.MICROSOFT_TENANT_ID}/oauth2/v2.0/token"
            return token_info
        except Exception as e:
            logger.error("Failed to retrieve token from session: %s", e)
            raise ValueError(f"Failed to retrieve SharePoint token information: {str(e)}")

    def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
        """Return True when the token is missing, has no expiry, or expires within 60 seconds."""
        if not token_info:
            return True
        expiry_timestamp = token_info.get("expiry")
        if expiry_timestamp is None:
            return True
        current_timestamp = int(datetime.datetime.now().timestamp())
        # Treat tokens expiring in under a minute as already expired so a
        # request never starts with a token that dies mid-flight.
        return (expiry_timestamp - current_timestamp) < 60

    def sanitize_token_info(self, token_info: Dict[str, Any], **extra_fields) -> Dict[str, Any]:
        """Persist the base whitelist plus whether this account can see shared content."""
        return super().sanitize_token_info(
            token_info,
            allows_shared_content=token_info.get("allows_shared_content", False),
            **extra_fields,
        )

    # Well-known tenant ID Microsoft assigns to personal (consumer) accounts.
    PERSONAL_ACCOUNT_TENANT_ID = "9188040d-6c67-4c5b-b112-36a304b66dad"

    def _allows_shared_content(self, id_token_claims: Dict[str, Any]) -> bool:
        """Return True when the account is a work/school tenant that can access SharePoint shared content."""
        tid = id_token_claims.get("tid", "")
        return bool(tid) and tid != self.PERSONAL_ACCOUNT_TENANT_ID

    def map_token_response(self, result) -> Dict[str, Any]:
        """Normalize an MSAL token response into the connector's token_info dict.

        NOTE(review): "token_uri" is taken from the id-token issuer claim and
        "expiry" from the id-token "exp" claim rather than the access token's
        own lifetime — confirm this is intended.
        """
        claims = result.get("id_token_claims", {})
        return {
            "access_token": result.get("access_token"),
            "refresh_token": result.get("refresh_token"),
            "token_uri": claims.get("iss"),
            "scopes": result.get("scope"),
            "expiry": claims.get("exp"),
            "allows_shared_content": self._allows_shared_content(claims),
            "user_info": {
                "name": claims.get("name"),
                "email": claims.get("preferred_username"),
            },
        }

View File

@@ -1,649 +0,0 @@
"""
SharePoint/OneDrive loader for DocsGPT.
Loads documents from SharePoint/OneDrive using Microsoft Graph API.
"""
import functools
import logging
import os
from typing import List, Dict, Any, Optional, Tuple
from urllib.parse import quote
import requests
from application.parser.connectors.base import BaseConnectorLoader
from application.parser.connectors.share_point.auth import SharePointAuth
from application.parser.schema.base import Document
def _retry_on_auth_failure(func):
"""Retry once after refreshing the access token on 401/403 responses."""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code in (401, 403):
logging.info(f"Auth failure in {func.__name__}, refreshing token and retrying")
try:
new_token_info = self.auth.refresh_access_token(self.refresh_token)
self.access_token = new_token_info.get('access_token')
except Exception as refresh_error:
raise ValueError(
f"Authentication failed and could not be refreshed: {refresh_error}"
) from e
return func(self, *args, **kwargs)
raise
return wrapper
class SharePointLoader(BaseConnectorLoader):
    """Loads documents and files from SharePoint/OneDrive via the Microsoft Graph API."""

    # Graph mimeType -> file extension for the formats this loader can ingest.
    SUPPORTED_MIME_TYPES = {
        'application/pdf': '.pdf',
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
        'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
        'application/msword': '.doc',
        'application/vnd.ms-powerpoint': '.ppt',
        'application/vnd.ms-excel': '.xls',
        'text/plain': '.txt',
        'text/csv': '.csv',
        'text/html': '.html',
        'text/markdown': '.md',
        'text/x-rst': '.rst',
        'application/json': '.json',
        'application/epub+zip': '.epub',
        'application/rtf': '.rtf',
        'image/jpeg': '.jpg',
        'image/png': '.png',
    }
    # Reverse map used to infer a mime type from a file extension.
    EXTENSION_TO_MIME = {v: k for k, v in SUPPORTED_MIME_TYPES.items()}

    GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"

    def __init__(self, session_token: str):
        """Resolve tokens for *session_token* and prepare the loader.

        Raises:
            ValueError: if the session holds no access token.
        """
        self.auth = SharePointAuth()
        self.session_token = session_token
        token_info = self.auth.get_token_info_from_session(session_token)
        self.access_token = token_info.get('access_token')
        self.refresh_token = token_info.get('refresh_token')
        self.allows_shared_content = token_info.get('allows_shared_content', False)
        if not self.access_token:
            raise ValueError("No access token found in session")
        # Cursor exposed to callers for continuing a paged listing.
        self.next_page_token = None

    def _get_headers(self) -> Dict[str, str]:
        """Return the bearer-auth headers for Graph API requests."""
        return {
            'Authorization': f'Bearer {self.access_token}',
            'Accept': 'application/json'
        }

    def _ensure_valid_token(self):
        """Refresh the access token if it appears expired.

        NOTE(review): expiry is passed as None here, and SharePointAuth's
        is_token_expired treats a missing expiry as expired — so this path
        refreshes the token on every call. Confirm whether that is intended.
        """
        if not self.access_token:
            raise ValueError("No access token available")
        token_info = {'access_token': self.access_token, 'expiry': None}
        if self.auth.is_token_expired(token_info):
            logging.info("Token expired, attempting refresh")
            try:
                new_token_info = self.auth.refresh_access_token(self.refresh_token)
                self.access_token = new_token_info.get('access_token')
            except Exception:
                raise ValueError("Failed to refresh access token")

    def _get_item_url(self, item_ref: str) -> str:
        """Build the Graph URL for an item reference.

        *item_ref* is either a bare item ID (resolved against the user's own
        drive) or a "driveId:itemId" pair for items on another drive.
        """
        if ':' in item_ref:
            drive_id, item_id = item_ref.split(':', 1)
            return f"{self.GRAPH_API_BASE}/drives/{drive_id}/items/{item_id}"
        return f"{self.GRAPH_API_BASE}/me/drive/items/{item_ref}"

    def _process_file(self, file_metadata: Dict[str, Any], load_content: bool = True) -> Optional[Document]:
        """Convert a Graph driveItem into a Document.

        Returns None for unsupported mime types, failed downloads, or any
        processing error (errors are logged, not raised).
        """
        try:
            drive_item_id = file_metadata.get('id')
            file_name = file_metadata.get('name', 'Unknown')
            file_data = file_metadata.get('file', {})
            mime_type = file_data.get('mimeType', 'application/octet-stream')
            if mime_type not in self.SUPPORTED_MIME_TYPES:
                logging.info(f"Skipping unsupported file type: {mime_type} for file {file_name}")
                return None
            doc_metadata = {
                'file_name': file_name,
                'mime_type': mime_type,
                'size': file_metadata.get('size'),
                'created_time': file_metadata.get('createdDateTime'),
                'modified_time': file_metadata.get('lastModifiedDateTime'),
                'source': 'share_point'
            }
            if not load_content:
                # Listing mode: return an empty-text Document carrying metadata only.
                return Document(
                    text="",
                    doc_id=drive_item_id,
                    extra_info=doc_metadata
                )
            content = self._download_file_content(drive_item_id)
            if content is None:
                logging.warning(f"Could not load content for file {file_name} ({drive_item_id})")
                return None
            return Document(
                text=content,
                doc_id=drive_item_id,
                extra_info=doc_metadata
            )
        except Exception as e:
            logging.error(f"Error processing file: {e}")
            return None

    def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
        """Load documents per *inputs*.

        Recognized keys: file_ids (explicit files, takes precedence), shared
        (list items shared with the user), folder_id (defaults to drive root),
        limit, list_only (skip content download), page_token, search_query.
        Sets self.next_page_token for folder/shared listings.
        """
        try:
            documents: List[Document] = []
            folder_id = inputs.get('folder_id')
            file_ids = inputs.get('file_ids', [])
            limit = inputs.get('limit', 100)
            list_only = inputs.get('list_only', False)
            load_content = not list_only
            page_token = inputs.get('page_token')
            search_query = inputs.get('search_query')
            self.next_page_token = None
            shared = inputs.get('shared', False)
            if file_ids:
                for file_id in file_ids:
                    try:
                        doc = self._load_file_by_id(file_id, load_content=load_content)
                        if doc:
                            # Explicit file IDs are filtered locally by name match.
                            if not search_query or (
                                search_query.lower() in doc.extra_info.get('file_name', '').lower()
                            ):
                                documents.append(doc)
                    except Exception as e:
                        logging.error(f"Error loading file {file_id}: {e}")
                        continue
            elif shared:
                if not self.allows_shared_content:
                    # Personal Microsoft accounts cannot query shared SharePoint content.
                    logging.warning("Shared content is only available for work/school Microsoft accounts")
                    return []
                documents = self._list_shared_items(
                    limit=limit,
                    load_content=load_content,
                    page_token=page_token,
                    search_query=search_query
                )
            else:
                parent_id = folder_id if folder_id else 'root'
                documents = self._list_items_in_parent(
                    parent_id,
                    limit=limit,
                    load_content=load_content,
                    page_token=page_token,
                    search_query=search_query
                )
            logging.info(f"Loaded {len(documents)} documents from SharePoint/OneDrive")
            return documents
        except Exception as e:
            logging.error(f"Error loading data from SharePoint/OneDrive: {e}", exc_info=True)
            raise

    @_retry_on_auth_failure
    def _load_file_by_id(self, file_id: str, load_content: bool = True) -> Optional[Document]:
        """Fetch one file's metadata (and optionally content) as a Document.

        HTTP errors propagate so the retry decorator can refresh on 401/403;
        other errors are logged and yield None.
        """
        self._ensure_valid_token()
        try:
            url = self._get_item_url(file_id)
            params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
            response = requests.get(url, headers=self._get_headers(), params=params)
            response.raise_for_status()
            file_metadata = response.json()
            return self._process_file(file_metadata, load_content=load_content)
        except requests.exceptions.HTTPError:
            raise
        except Exception as e:
            logging.error(f"Error loading file {file_id}: {e}")
            return None

    @_retry_on_auth_failure
    def _list_items_in_parent(self, parent_id: str, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
        """List children of *parent_id* (or search results when *search_query* is set).

        Folders become empty-text Documents flagged is_folder; files go through
        _process_file. Updates self.next_page_token from the $skiptoken of the
        response's @odata.nextLink. Errors are logged and partial results returned.
        """
        self._ensure_valid_token()
        documents: List[Document] = []
        try:
            url = f"{self._get_item_url(parent_id)}/children"
            params = {'$top': min(100, limit) if limit else 100, '$select': 'id,name,file,folder,createdDateTime,lastModifiedDateTime,size'}
            if page_token:
                params['$skipToken'] = page_token
            if search_query:
                # Search uses a different endpoint; scope it to the parent's drive when known.
                encoded_query = quote(search_query, safe='')
                if ':' in parent_id:
                    drive_id = parent_id.split(':', 1)[0]
                    search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
                else:
                    search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
                response = requests.get(search_url, headers=self._get_headers(), params=params)
            else:
                response = requests.get(url, headers=self._get_headers(), params=params)
            response.raise_for_status()
            results = response.json()
            items = results.get('value', [])
            for item in items:
                if 'folder' in item:
                    doc_metadata = {
                        'file_name': item.get('name', 'Unknown'),
                        'mime_type': 'folder',
                        'size': item.get('size'),
                        'created_time': item.get('createdDateTime'),
                        'modified_time': item.get('lastModifiedDateTime'),
                        'source': 'share_point',
                        'is_folder': True
                    }
                    documents.append(Document(text="", doc_id=item.get('id'), extra_info=doc_metadata))
                else:
                    doc = self._process_file(item, load_content=load_content)
                    if doc:
                        documents.append(doc)
                if limit and len(documents) >= limit:
                    break
            next_link = results.get('@odata.nextLink')
            if next_link:
                # Extract just the $skiptoken so callers can pass it back as page_token.
                from urllib.parse import urlparse, parse_qs
                parsed = urlparse(next_link)
                query_params = parse_qs(parsed.query)
                skiptoken_list = query_params.get('$skiptoken')
                if skiptoken_list:
                    self.next_page_token = skiptoken_list[0]
                else:
                    self.next_page_token = None
            else:
                self.next_page_token = None
            return documents
        except Exception as e:
            logging.error(f"Error listing items under parent {parent_id}: {e}")
            return documents

    def _resolve_mime_type(self, resource: Dict[str, Any]) -> Tuple[str, bool]:
        """Resolve mime type from resource, falling back to file extension.

        Returns (mime_type, supported) where supported indicates the type is
        one this loader can ingest.
        """
        file_data = resource.get('file', {})
        mime_type = file_data.get('mimeType') if file_data else None
        if mime_type and mime_type in self.SUPPORTED_MIME_TYPES:
            return mime_type, True
        name = resource.get('name', '')
        ext = os.path.splitext(name)[1].lower()
        if ext in self.EXTENSION_TO_MIME:
            return self.EXTENSION_TO_MIME[ext], True
        return mime_type or 'application/octet-stream', False

    def _get_user_drive_web_url(self) -> Optional[str]:
        """Fetch the current user's OneDrive web URL for KQL path exclusion."""
        try:
            response = requests.get(
                f"{self.GRAPH_API_BASE}/me/drive",
                headers=self._get_headers(),
                params={'$select': 'webUrl'}
            )
            response.raise_for_status()
            return response.json().get('webUrl')
        except Exception as e:
            logging.warning(f"Could not fetch user drive web URL: {e}")
            return None

    def _build_shared_kql_query(self, search_query: Optional[str], user_drive_url: Optional[str]) -> str:
        """Build KQL query string that excludes the user's own drive items."""
        base_query = search_query if search_query else "*"
        if user_drive_url:
            return f'{base_query} AND -path:"{user_drive_url}"'
        return base_query

    def _list_shared_items(self, limit: int = 100, load_content: bool = False, page_token: Optional[str] = None, search_query: Optional[str] = None) -> List[Document]:
        """Fetch shared drive items using Microsoft Graph Search API with local offset paging.
        We always fetch up to a fixed maximum number of hits from Graph (single request),
        then page through that array locally using `page_token` as a simple integer offset.
        This avoids relying on buggy or inconsistent remote `from`/`size` semantics.
        """
        self._ensure_valid_token()
        documents: List[Document] = []
        try:
            user_drive_url = self._get_user_drive_web_url()
            query_text = self._build_shared_kql_query(search_query, user_drive_url)
            url = f"{self.GRAPH_API_BASE}/search/query"
            page_size = 500  # maximum number of hits we care about for selection
            body = {
                "requests": [
                    {
                        "entityTypes": ["driveItem"],
                        "query": {"queryString": query_text},
                        "from": 0,
                        "size": page_size,
                    }
                ]
            }
            headers = self._get_headers()
            headers["Content-Type"] = "application/json"
            response = requests.post(url, headers=headers, json=body)
            response.raise_for_status()
            results = response.json()
            search_response = results.get("value", [])
            if not search_response:
                logging.warning("Search API returned empty value array")
                self.next_page_token = None
                return documents
            hits_containers = search_response[0].get("hitsContainers", [])
            if not hits_containers:
                logging.warning("Search API returned no hitsContainers")
                self.next_page_token = None
                return documents
            container = hits_containers[0]
            total = container.get("total", 0)
            raw_hits = container.get("hits", [])
            # Deduplicate by effective item ID (driveId:itemId) to avoid the same
            # resource appearing multiple times across the result set.
            deduped_hits = []
            seen_ids = set()
            for hit in raw_hits:
                resource = hit.get("resource", {})
                item_id = resource.get("id")
                drive_id = resource.get("parentReference", {}).get("driveId")
                effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
                if not effective_id or effective_id in seen_ids:
                    continue
                seen_ids.add(effective_id)
                deduped_hits.append(hit)
            hits = deduped_hits
            logging.info(
                f"Search API returned {total} total results, {len(raw_hits)} raw hits, {len(hits)} unique hits in this batch"
            )
            # page_token is a local integer offset into the deduped hit array.
            try:
                offset = int(page_token) if page_token is not None else 0
            except (TypeError, ValueError):
                logging.warning(
                    f"Invalid page_token '{page_token}' for shared items search, defaulting to 0"
                )
                offset = 0
            if offset < 0:
                offset = 0
            if offset >= len(hits):
                # Past the end of the fetched window: nothing further to page.
                self.next_page_token = None
                return documents
            end_index = offset + limit if limit else len(hits)
            end_index = min(end_index, len(hits))
            for hit in hits[offset:end_index]:
                resource = hit.get("resource", {})
                item_name = resource.get("name", "Unknown")
                item_id = resource.get("id")
                drive_id = resource.get("parentReference", {}).get("driveId")
                effective_id = f"{drive_id}:{item_id}" if drive_id and item_id else item_id
                is_folder = "folder" in resource
                if is_folder:
                    doc_metadata = {
                        "file_name": item_name,
                        "mime_type": "folder",
                        "size": resource.get("size"),
                        "created_time": resource.get("createdDateTime"),
                        "modified_time": resource.get("lastModifiedDateTime"),
                        "source": "share_point",
                        "is_folder": True,
                    }
                    documents.append(
                        Document(text="", doc_id=effective_id, extra_info=doc_metadata)
                    )
                else:
                    mime_type, supported = self._resolve_mime_type(resource)
                    if not supported:
                        logging.info(
                            f"Skipping unsupported shared file: {item_name} (mime: {mime_type})"
                        )
                        continue
                    doc_metadata = {
                        "file_name": item_name,
                        "mime_type": mime_type,
                        "size": resource.get("size"),
                        "created_time": resource.get("createdDateTime"),
                        "modified_time": resource.get("lastModifiedDateTime"),
                        "source": "share_point",
                    }
                    content = ""
                    if load_content:
                        content = self._download_file_content(effective_id) or ""
                    documents.append(
                        Document(text=content, doc_id=effective_id, extra_info=doc_metadata)
                    )
            if limit and end_index < len(hits):
                self.next_page_token = str(end_index)
            else:
                self.next_page_token = None
            return documents
        except Exception as e:
            logging.error(f"Error listing shared items via search API: {e}", exc_info=True)
            return documents

    @_retry_on_auth_failure
    def _download_file_content(self, file_id: str) -> Optional[str]:
        """Download a file's raw content and decode it as UTF-8 text.

        Returns None when the bytes are not valid UTF-8 or on non-HTTP errors.
        HTTP errors propagate so the retry decorator can refresh on 401/403.
        """
        self._ensure_valid_token()
        try:
            url = f"{self._get_item_url(file_id)}/content"
            response = requests.get(url, headers=self._get_headers())
            response.raise_for_status()
            try:
                return response.content.decode('utf-8')
            except UnicodeDecodeError:
                logging.error(f"Could not decode file {file_id} as text")
                return None
        except requests.exceptions.HTTPError:
            raise
        except Exception as e:
            logging.error(f"Error downloading file {file_id}: {e}")
            return None

    def _download_single_file(self, file_id: str, local_dir: str) -> bool:
        """Download one supported file into *local_dir*; returns True on success."""
        try:
            url = self._get_item_url(file_id)
            params = {'$select': 'id,name,file'}
            response = requests.get(url, headers=self._get_headers(), params=params)
            response.raise_for_status()
            metadata = response.json()
            file_name = metadata.get('name', 'unknown')
            file_data = metadata.get('file', {})
            mime_type = file_data.get('mimeType', 'application/octet-stream')
            if mime_type not in self.SUPPORTED_MIME_TYPES:
                logging.info(f"Skipping unsupported file type: {mime_type}")
                return False
            os.makedirs(local_dir, exist_ok=True)
            full_path = os.path.join(local_dir, file_name)
            download_url = f"{self._get_item_url(file_id)}/content"
            download_response = requests.get(download_url, headers=self._get_headers())
            download_response.raise_for_status()
            # Write bytes verbatim; no text decoding is attempted here.
            with open(full_path, 'wb') as f:
                f.write(download_response.content)
            return True
        except Exception as e:
            logging.error(f"Error in _download_single_file: {e}")
            return False

    def _download_folder_recursive(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
        """Download a folder's files into *local_dir*; returns the file count.

        NOTE(review): the same `params` are re-sent when following
        @odata.nextLink, whose URL already carries its own query — confirm
        Graph tolerates this.
        """
        files_downloaded = 0
        try:
            os.makedirs(local_dir, exist_ok=True)
            url = f"{self._get_item_url(folder_id)}/children"
            params = {'$top': 1000}
            while url:
                response = requests.get(url, headers=self._get_headers(), params=params)
                response.raise_for_status()
                results = response.json()
                items = results.get('value', [])
                logging.info(f"Found {len(items)} items in folder {folder_id}")
                for item in items:
                    item_name = item.get('name', 'unknown')
                    item_id = item.get('id')
                    if 'folder' in item:
                        if recursive:
                            subfolder_path = os.path.join(local_dir, item_name)
                            os.makedirs(subfolder_path, exist_ok=True)
                            subfolder_files = self._download_folder_recursive(
                                item_id,
                                subfolder_path,
                                recursive
                            )
                            files_downloaded += subfolder_files
                            logging.info(f"Downloaded {subfolder_files} files from subfolder {item_name}")
                    else:
                        success = self._download_single_file(item_id, local_dir)
                        if success:
                            files_downloaded += 1
                            logging.info(f"Downloaded file: {item_name}")
                        else:
                            logging.warning(f"Failed to download file: {item_name}")
                # Follow server-side pagination until exhausted.
                url = results.get('@odata.nextLink')
            return files_downloaded
        except Exception as e:
            logging.error(f"Error in _download_folder_recursive for folder {folder_id}: {e}", exc_info=True)
            return files_downloaded

    def _download_folder_contents(self, folder_id: str, local_dir: str, recursive: bool = True) -> int:
        """Token-checked wrapper around _download_folder_recursive; returns file count (0 on error)."""
        try:
            self._ensure_valid_token()
            return self._download_folder_recursive(folder_id, local_dir, recursive)
        except Exception as e:
            logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
            return 0

    def _download_file_to_directory(self, file_id: str, local_dir: str) -> bool:
        """Token-checked wrapper around _download_single_file; returns False on error."""
        try:
            self._ensure_valid_token()
            return self._download_single_file(file_id, local_dir)
        except Exception as e:
            logging.error(f"Error downloading file {file_id}: {e}", exc_info=True)
            return False

    def download_to_directory(self, local_dir: str, source_config: Dict[str, Any] = None) -> Dict[str, Any]:
        """Download the files/folders named in *source_config* into *local_dir*.

        Config keys: file_ids and/or folder_ids (string or list), recursive
        (default True). Returns a summary dict; failures are reported via the
        "error" key rather than raised.
        """
        if source_config is None:
            source_config = {}
        config = source_config if source_config else getattr(self, 'config', {})
        files_downloaded = 0
        try:
            folder_ids = config.get('folder_ids', [])
            file_ids = config.get('file_ids', [])
            recursive = config.get('recursive', True)
            if file_ids:
                if isinstance(file_ids, str):
                    file_ids = [file_ids]
                for file_id in file_ids:
                    if self._download_file_to_directory(file_id, local_dir):
                        files_downloaded += 1
            if folder_ids:
                if isinstance(folder_ids, str):
                    folder_ids = [folder_ids]
                for folder_id in folder_ids:
                    try:
                        # Mirror the remote folder name locally before descending.
                        url = self._get_item_url(folder_id)
                        params = {'$select': 'id,name'}
                        response = requests.get(url, headers=self._get_headers(), params=params)
                        response.raise_for_status()
                        folder_metadata = response.json()
                        folder_name = folder_metadata.get('name', '')
                        folder_path = os.path.join(local_dir, folder_name)
                        os.makedirs(folder_path, exist_ok=True)
                        folder_files = self._download_folder_recursive(
                            folder_id,
                            folder_path,
                            recursive
                        )
                        files_downloaded += folder_files
                        logging.info(f"Downloaded {folder_files} files from folder {folder_name}")
                    except Exception as e:
                        logging.error(f"Error downloading folder {folder_id}: {e}", exc_info=True)
            if not file_ids and not folder_ids:
                raise ValueError("No folder_ids or file_ids provided for download")
            return {
                "files_downloaded": files_downloaded,
                "directory_path": local_dir,
                "empty_result": files_downloaded == 0,
                "source_type": "share_point",
                "config_used": config
            }
        except Exception as e:
            return {
                "files_downloaded": files_downloaded,
                "directory_path": local_dir,
                "empty_result": True,
                "source_type": "share_point",
                "config_used": config,
                "error": str(e)
            }

View File

@@ -1,48 +0,0 @@
from pathlib import Path
from typing import Dict, Union
from application.core.settings import settings
from application.parser.file.base_parser import BaseParser
from application.stt.stt_creator import STTCreator
from application.stt.upload_limits import enforce_audio_file_size_limit
class AudioParser(BaseParser):
    """Parser that transcribes audio files through the configured STT provider."""

    def __init__(self, parser_config=None):
        super().__init__(parser_config=parser_config)
        # Transcript metadata per parsed file, keyed by str(path).
        self._transcript_metadata: Dict[str, Dict] = {}

    def _init_parser(self) -> Dict:
        return {}

    def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, list[str]]:
        _ = errors
        try:
            enforce_audio_file_size_limit(file.stat().st_size)
        except OSError:
            # Size could not be determined; let transcription proceed anyway.
            pass
        stt = STTCreator.create_stt(settings.STT_PROVIDER)
        result = stt.transcribe(
            file,
            language=settings.STT_LANGUAGE,
            timestamps=settings.STT_ENABLE_TIMESTAMPS,
            diarize=settings.STT_ENABLE_DIARIZATION,
        )
        metadata = {
            "transcript_duration_s": result.get("duration_s"),
            "transcript_language": result.get("language"),
            "transcript_provider": result.get("provider"),
        }
        segments = result.get("segments")
        if segments:
            metadata["transcript_segments"] = segments
        # Store only populated fields so empty values never pollute metadata.
        self._transcript_metadata[str(file)] = {
            name: value for name, value in metadata.items() if value not in (None, [], {})
        }
        return result.get("text", "")

    def get_file_metadata(self, file: Path) -> Dict:
        return self._transcript_metadata.get(str(file), {})

View File

@@ -36,8 +36,3 @@ class BaseParser:
@abstractmethod
def parse_file(self, file: Path, errors: str = "ignore") -> Union[str, List[str]]:
"""Parse file."""
def get_file_metadata(self, file: Path) -> Dict:
    """Return parser-specific metadata for the most recently parsed file.

    The base implementation ignores *file* and reports nothing; subclasses
    override this when parsing produces extra per-file fields.
    """
    del file  # unused in the base implementation
    return {}

View File

@@ -14,17 +14,11 @@ from application.parser.file.tabular_parser import PandasCSVParser, ExcelParser
from application.parser.file.json_parser import JSONParser
from application.parser.file.pptx_parser import PPTXParser
from application.parser.file.image_parser import ImageParser
from application.parser.file.audio_parser import AudioParser
from application.parser.schema.base import Document
from application.stt.constants import SUPPORTED_AUDIO_EXTENSIONS
from application.utils import num_tokens_from_string
from application.core.settings import settings
def _build_audio_parser_mapping() -> Dict[str, BaseParser]:
    """Map every supported audio extension to a fresh AudioParser instance."""
    mapping: Dict[str, BaseParser] = {}
    for extension in SUPPORTED_AUDIO_EXTENSIONS:
        mapping[extension] = AudioParser()
    return mapping
def get_default_file_extractor(
ocr_enabled: Optional[bool] = None,
) -> Dict[str, BaseParser]:
@@ -76,7 +70,6 @@ def get_default_file_extractor(
".webp": DoclingImageParser(ocr_enabled=ocr_enabled) if ocr_enabled else ImageParser(),
# Media/subtitles
".vtt": DoclingVTTParser(),
**_build_audio_parser_mapping(),
# Specialized XML formats
".xml": DoclingXMLParser(),
# Formats docling doesn't support - use standard parsers
@@ -103,7 +96,6 @@ def get_default_file_extractor(
".png": ImageParser(),
".jpg": ImageParser(),
".jpeg": ImageParser(),
**_build_audio_parser_mapping(),
}
@@ -229,13 +221,11 @@ class SimpleDirectoryReader(BaseReader):
for input_file in self.input_files:
suffix_lower = input_file.suffix.lower()
parser_metadata = {}
if suffix_lower in self.file_extractor:
parser = self.file_extractor[suffix_lower]
if not parser.parser_config_set:
parser.init_parser()
data = parser.parse_file(input_file, errors=self.errors)
parser_metadata = parser.get_file_metadata(input_file)
else:
# do standard read
with open(input_file, "r", errors=self.errors) as f:
@@ -254,8 +244,6 @@ class SimpleDirectoryReader(BaseReader):
'title': input_file.name,
'token_count': file_tokens,
}
if parser_metadata:
base_metadata.update(parser_metadata)
if hasattr(self, 'input_dir'):
try:

View File

@@ -1,27 +0,0 @@
"""Shared file-extension constants for parsing and ingestion flows."""
from application.stt.constants import SUPPORTED_AUDIO_EXTENSIONS
SUPPORTED_SOURCE_DOCUMENT_EXTENSIONS = (
".rst",
".md",
".pdf",
".txt",
".docx",
".csv",
".epub",
".html",
".mdx",
".json",
".xlsx",
".pptx",
)
SUPPORTED_SOURCE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")
SUPPORTED_SOURCE_EXTENSIONS = (
*SUPPORTED_SOURCE_DOCUMENT_EXTENSIONS,
*SUPPORTED_SOURCE_IMAGE_EXTENSIONS,
*SUPPORTED_AUDIO_EXTENSIONS,
)

View File

@@ -1,16 +0,0 @@
You are a helpful AI assistant, DocsGPT. You are proactive and helpful. Try to use tools, if they are available to you,
be proactive and fill in missing information.
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
If appropriate, your answers can include code examples, formatted as follows:
```(language)
(code)
```
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses.
Try to respond with mermaid charts if visualization helps with users' queries.
You effectively utilize chat history, ensuring relevant and tailored responses.
You have access to a search tool that searches the user's uploaded documents and knowledge base.
Use the search_internal tool to find relevant information before answering questions.
You may search multiple times with different queries if needed.
Do not guess when documents are available — search first, then answer based on what you find.
If no relevant documents are found, use your general knowledge and tool capabilities.
Allow yourself to be very creative and use your imagination.

View File

@@ -1,15 +0,0 @@
You are a helpful AI assistant, DocsGPT. You are proactive and helpful. Try to use tools, if they are available to you,
be proactive and fill in missing information.
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
If appropriate, your answers can include code examples, formatted as follows:
```(language)
(code)
```
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses.
Try to respond with mermaid charts if visualization helps with users' queries.
You effectively utilize chat history, ensuring relevant and tailored responses.
You have access to a search tool that searches the user's uploaded documents and knowledge base.
Use the search_internal tool to find relevant information before answering questions.
You may search multiple times with different queries if needed.
Do not guess when documents are available — search first, then answer based on what you find.
If no relevant documents are found, use your general knowledge and tool capabilities.

View File

@@ -1,16 +0,0 @@
You are a helpful AI assistant, DocsGPT. You are proactive and helpful. Try to use tools, if they are available to you,
be proactive and fill in missing information.
Users can Upload documents for your context as attachments or sources via UI using the Conversation input box.
If appropriate, your answers can include code examples, formatted as follows:
```(language)
(code)
```
Users are also able to see charts and diagrams if you use them with valid mermaid syntax in your responses.
Try to respond with mermaid charts if visualization helps with users queries.
You effectively utilize chat history, ensuring relevant and tailored responses.
You have access to a search tool that searches the user's uploaded documents and knowledge base.
Use the search_internal tool to find relevant information before answering questions.
You may search multiple times with different queries if needed.
You MUST search before answering any factual question. Do not guess or use general knowledge when documents are available.
If you don't have enough information from the search results or tools, answer "I don't know" or "I don't have enough information".
Never make up information or provide false information!

View File

@@ -0,0 +1,3 @@
Query: {query}
Observations: {observations}
Now, using the insights from the observations, formulate a well-structured and precise final answer.

View File

@@ -0,0 +1,13 @@
You are an AI assistant and talk like you're thinking out loud. Given the following query, outline a concise thought process that includes key steps and considerations necessary for effective analysis and response. Avoid pointwise formatting. The goal is to break down the query into manageable components without excessive detail, focusing on clarity and logical progression.
Include the following elements in your thought and execution process:
1. Identify the main objective of the query.
2. Determine any relevant context or background information needed to understand the query.
3. List potential approaches or methods to address the query.
4. Highlight any critical factors or constraints that may influence the outcome.
5. Plan with available tools to help you with the analysis but don't execute them. Tools will be executed by another AI.
Query: {query}
Summaries: {summaries}
Prompt: {prompt}
Observations(potentially previous tool calls): {observations}

View File

@@ -1,23 +0,0 @@
You are a research assistant evaluating whether a user's question is clear enough to begin in-depth research.
Decide whether the question can be researched as-is, or whether you need to ask clarifying questions first.
Proceed WITHOUT clarification if:
- The question has a clear topic and intent
- There is enough context to form a research plan
- The question is broad but can be broken into sub-topics
- Documents/sources are available to search
Ask for clarification ONLY if:
- The question is critically ambiguous (e.g., "review the thing" with no context about what "thing" means)
- Key information is missing that would make research impossible (not just imperfect)
- The user seems to reference something specific but hasn't provided it
Err on the side of proceeding. Most questions are clear enough to start researching.
You MUST respond with ONLY a valid JSON object (no markdown, no code fences):
{
"needs_clarification": true or false,
"questions": ["question 1", "question 2"] (only if needs_clarification is true, max 3 questions),
"reason": "brief explanation of why clarification is or isn't needed"
}

View File

@@ -1,23 +0,0 @@
You are a research planner. Your job is to analyze the user's question and create a focused research plan.
IMPORTANT: Every step must be a concrete research action — something you can search for and find information about. Never generate steps that ask the user for more information or request documents. Work with what you have.
Assess the question's complexity:
- "simple": Can be answered with 1-2 focused searches (e.g., "What is our refund policy?")
- "moderate": Needs 3-4 searches across different topics (e.g., "Compare our pricing across product lines")
- "complex": Requires 5-6 searches with synthesis (e.g., "Audit our compliance docs against regulation X")
Break the question into sub-questions appropriate to its complexity. Don't over-decompose simple questions.
Consider:
- What distinct aspects of the question need separate investigation?
- What order should they be investigated in?
- Are there any dependencies between sub-questions?
You MUST respond with ONLY a valid JSON object in this exact format (no markdown, no code fences):
{
"complexity": "simple|moderate|complex",
"steps": [
{"query": "specific sub-question to investigate", "rationale": "why this step is needed"}
]
}

View File

@@ -1,13 +0,0 @@
You are a research agent investigating a specific topic. Your goal is to find comprehensive, accurate information.
Your current research task: {step_query}
Instructions:
1. Use the available tools to search for and gather relevant information.
2. If you have a search_internal tool, use it to search the knowledge base. You may search multiple times with different queries.
3. If you have a reason_think tool, use it to analyze what you've found before drawing conclusions.
4. After gathering sufficient information, provide a detailed summary of your findings.
5. Cite specific documents and passages you found. Reference sources by their titles or filenames.
6. If you cannot find relevant information through tools, use your general knowledge but clearly indicate this.
Be thorough — prefer completeness over brevity. Include all relevant details you find.

View File

@@ -1,22 +0,0 @@
You are compiling a comprehensive research report based on multiple research steps.
Original question: {question}
Research plan:
{plan_summary}
Intermediate findings:
{findings}
Write a well-structured, thorough report that:
1. Directly addresses the user's original question
2. Synthesizes findings from all research steps into a coherent narrative
3. Uses inline citations [N] for every factual claim, where N maps to the source list below
4. Highlights key insights and connections across different research steps
5. Notes any gaps or areas where information was insufficient
6. Includes a "References" section at the end listing all cited sources
Available sources for citation:
{references}
Format the report with clear headings and sections. Be comprehensive but well-organized.

View File

@@ -1,94 +1,98 @@
anthropic==0.88.0
boto3==1.42.83
anthropic==0.75.0
boto3==1.42.17
beautifulsoup4==4.14.3
cel-python==0.5.0
celery==5.6.3
cryptography==46.0.6
celery==5.6.0
cryptography==46.0.3
dataclasses-json==0.6.7
defusedxml==0.7.1
docling>=2.16.0
rapidocr>=1.4.0
onnxruntime>=1.19.0
docx2txt==0.9
ddgs>=8.0.0
duckduckgo-search==8.1.1
ebooklib==0.20
elevenlabs==2.41.0
Flask==3.1.3
escodegen==1.0.11
esprima==4.0.1
esutils==1.0.1
elevenlabs==2.27.0
Flask==3.1.2
faiss-cpu==1.13.2
fastmcp==3.2.0
fastmcp==2.14.1
flask-restx==1.3.2
google-genai==1.69.0
google-api-python-client==2.193.0
google-auth-httplib2==0.3.1
google-auth-oauthlib==1.3.1
google-genai==1.54.0
google-api-python-client==2.187.0
google-auth-httplib2==0.3.0
google-auth-oauthlib==1.2.3
gTTS==2.5.4
gunicorn==25.3.0
gunicorn==23.0.0
html2text==2025.4.15
javalang==0.13.0
jinja2==3.1.6
jiter==0.13.0
jmespath==1.1.0
jiter==0.12.0
jmespath==1.0.1
joblib==1.5.3
jsonpatch==1.33
jsonpointer==3.0.0
kombu==5.6.2
langchain==1.2.3
kombu==5.6.1
langchain==1.2.0
langchain-community==0.4.1
langchain-core==1.2.23
langchain-openai==1.1.12
langchain-text-splitters==1.1.1
langsmith==0.7.23
langchain-core==1.2.5
langchain-openai==1.1.6
langchain-text-splitters==1.1.0
langsmith==0.5.1
lazy-object-proxy==1.12.0
lxml==6.0.2
markupsafe==3.0.3
marshmallow>=3.18.0,<5.0.0
mpmath==1.3.0
multidict==6.7.1
msal==1.35.1
multidict==6.7.0
mypy-extensions==1.1.0
networkx==3.6.1
numpy==2.4.4
openai==2.30.0
numpy==2.4.0
openai==2.14.0
openapi3-parser==1.1.22
orjson==3.11.7
packaging==26.0
pandas==3.0.2
orjson==3.11.5
packaging==24.2
pandas==2.3.3
openpyxl==3.1.5
pathable==0.5.0
pathable==0.4.4
pdf2image>=1.17.0
pillow
portalocker>=2.7.0,<4.0.0
portalocker>=2.7.0,<3.0.0
prance==25.4.8.0
prompt-toolkit==3.0.52
protobuf==7.34.1
protobuf==6.33.2
psycopg2-binary==2.9.11
py==1.11.0
pydantic
pydantic-core
pydantic-settings
pymongo==4.16.0
pypdf==6.9.2
pymongo==4.15.5
pypdf==6.5.0
python-dateutil==2.9.0.post0
python-dotenv
python-jose==3.5.0
python-pptx==1.0.2
redis==7.4.0
redis==7.1.0
referencing>=0.28.0,<0.38.0
regex==2026.4.4
requests==2.33.1
regex==2025.11.3
requests==2.32.5
retry==0.9.2
sentence-transformers==5.3.0
sentence-transformers==5.2.0
tiktoken==0.12.0
tokenizers==0.22.2
torch==2.11.0
tqdm==4.67.3
transformers==5.4.0
tokenizers==0.22.1
torch==2.9.1
tqdm==4.67.1
transformers==4.57.3
typing-extensions==4.15.0
typing-inspect==0.9.0
tzdata==2025.3
urllib3==2.6.3
vine==5.1.0
wcwidth==0.6.0
wcwidth==0.2.14
werkzeug>=3.1.0
yarl==1.23.0
yarl==1.22.0
markdownify==1.2.2
tldextract==5.3.1
websockets==16.0
tldextract==5.3.0
websockets==15.0.1

View File

@@ -18,7 +18,7 @@ agents:
- name: "Researcher"
description: "A specialized research agent that performs deep dives into subjects."
image: "https://d3dg1063dc54p9.cloudfront.net/imgs/agents/agent-researcher.svg"
agent_type: "classic"
agent_type: "react"
prompt:
name: "Researcher-Agent"
content: |

View File

@@ -21,19 +21,10 @@ class LocalStorage(BaseStorage):
)
def _get_full_path(self, path: str) -> str:
"""Get absolute path by combining base_dir and path.
Raises:
ValueError: If the resolved path escapes base_dir (path traversal).
"""
"""Get absolute path by combining base_dir and path."""
if os.path.isabs(path):
resolved = os.path.realpath(path)
else:
resolved = os.path.realpath(os.path.join(self.base_dir, path))
base = os.path.realpath(self.base_dir)
if not resolved.startswith(base + os.sep) and resolved != base:
raise ValueError(f"Path traversal detected: {path}")
return resolved
return path
return os.path.join(self.base_dir, path)
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
"""Save a file to local storage."""

View File

@@ -2,7 +2,6 @@
import io
import os
import posixpath
from typing import BinaryIO, Callable, List
import boto3
@@ -15,20 +14,6 @@ from botocore.exceptions import ClientError
class S3Storage(BaseStorage):
"""AWS S3 storage implementation."""
@staticmethod
def _validate_path(path: str) -> str:
"""Validate and normalize an S3 key to prevent path traversal.
Raises:
ValueError: If the path contains traversal sequences or is absolute.
"""
if "\x00" in path:
raise ValueError(f"Null byte in path: {path}")
normalized = posixpath.normpath(path)
if normalized.startswith("/") or normalized.startswith(".."):
raise ValueError(f"Path traversal detected: {path}")
return normalized
def __init__(self, bucket_name=None):
"""
Initialize S3 storage.
@@ -61,7 +46,6 @@ class S3Storage(BaseStorage):
**kwargs,
) -> dict:
"""Save a file to S3 storage."""
path = self._validate_path(path)
self.s3.upload_fileobj(
file_data, self.bucket_name, path, ExtraArgs={"StorageClass": storage_class}
)
@@ -77,7 +61,6 @@ class S3Storage(BaseStorage):
def get_file(self, path: str) -> BinaryIO:
"""Get a file from S3 storage."""
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found: {path}")
file_obj = io.BytesIO()
@@ -87,7 +70,6 @@ class S3Storage(BaseStorage):
def delete_file(self, path: str) -> bool:
"""Delete a file from S3 storage."""
path = self._validate_path(path)
try:
self.s3.delete_object(Bucket=self.bucket_name, Key=path)
return True
@@ -96,7 +78,6 @@ class S3Storage(BaseStorage):
def file_exists(self, path: str) -> bool:
"""Check if a file exists in S3 storage."""
path = self._validate_path(path)
try:
self.s3.head_object(Bucket=self.bucket_name, Key=path)
return True
@@ -134,7 +115,6 @@ class S3Storage(BaseStorage):
import logging
import tempfile
path = self._validate_path(path)
if not self.file_exists(path):
raise FileNotFoundError(f"File not found in S3: {path}")
with tempfile.NamedTemporaryFile(

View File

@@ -1 +0,0 @@
"""Speech-to-text providers."""

View File

@@ -1,15 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional
class BaseSTT(ABC):
    """Abstract interface that every speech-to-text provider implements."""

    @abstractmethod
    def transcribe(
        self,
        file_path: Path,
        language: Optional[str] = None,
        timestamps: bool = False,
        diarize: bool = False,
    ) -> Dict[str, Any]:
        """Transcribe the audio file at *file_path*.

        Args:
            file_path: Path to the audio file on disk.
            language: Optional language hint for the recognizer.
            timestamps: When True, include per-segment timing information.
            diarize: When True, request speaker diarization (providers may
                ignore this if unsupported).

        Returns:
            A dict of transcription results; concrete providers in this
            package populate keys such as ``text``, ``language``,
            ``duration_s``, ``segments`` and ``provider``.
        """
        pass

View File

@@ -1,15 +0,0 @@
# Audio file extensions accepted for transcription.
SUPPORTED_AUDIO_EXTENSIONS = (".wav", ".mp3", ".m4a", ".ogg", ".webm")

# MIME types accepted for audio uploads. Includes common browser/container
# aliases (e.g. "video/webm" for audio-only WebM recordings and the
# "audio/x-*" legacy forms some clients still send).
SUPPORTED_AUDIO_MIME_TYPES = {
    "application/ogg",
    "audio/aac",
    "audio/mp3",
    "audio/mp4",
    "audio/mpeg",
    "audio/ogg",
    "audio/wav",
    "audio/webm",
    "audio/x-m4a",
    "audio/x-wav",
    "video/webm",
}

View File

@@ -1,70 +0,0 @@
from pathlib import Path
from typing import Dict, Optional
from application.stt.base import BaseSTT
class FasterWhisperSTT(BaseSTT):
    """Local speech-to-text backed by the optional ``faster-whisper`` package.

    The Whisper model is loaded lazily on first use, so importing this module
    never pays the model-load cost and never requires the package installed.
    """

    def __init__(
        self,
        model_size: str = "base",
        device: str = "auto",
        compute_type: str = "int8",
    ):
        # Model selection knobs are only stored here; nothing is loaded yet.
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type
        self._model = None  # cached WhisperModel, created by _get_model()

    def _get_model(self):
        """Return the cached WhisperModel, importing and loading it on first call.

        Raises:
            ImportError: If the ``faster-whisper`` package is not installed.
        """
        if self._model is None:
            try:
                from faster_whisper import WhisperModel
            except ImportError as exc:
                raise ImportError(
                    "faster-whisper is required to use the faster_whisper STT provider."
                ) from exc
            self._model = WhisperModel(
                self.model_size,
                device=self.device,
                compute_type=self.compute_type,
            )
        return self._model

    def transcribe(
        self,
        file_path: Path,
        language: Optional[str] = None,
        timestamps: bool = False,
        diarize: bool = False,
    ) -> Dict[str, object]:
        """Transcribe *file_path* and return text plus transcript metadata.

        ``diarize`` is accepted for interface parity but ignored — this
        provider does not support speaker diarization.
        """
        _ = diarize
        model = self._get_model()
        segments_iter, info = model.transcribe(
            str(file_path),
            language=language,
            word_timestamps=timestamps,
        )
        segments = []
        text_parts = []
        for segment in segments_iter:
            segment_text = getattr(segment, "text", "").strip()
            if segment_text:
                # Only non-empty segments contribute to the transcript.
                text_parts.append(segment_text)
                segments.append(
                    {
                        "start": getattr(segment, "start", None),
                        "end": getattr(segment, "end", None),
                        "text": segment_text,
                    }
                )
        return {
            "text": " ".join(text_parts).strip(),
            "language": getattr(info, "language", language),
            "duration_s": getattr(info, "duration", None),
            # Segment details are only exposed when timestamps were requested.
            "segments": segments if timestamps else [],
            "provider": "faster_whisper",
        }

View File

@@ -1,201 +0,0 @@
import json
import re
import uuid
from typing import Dict, Optional
# Redis key prefix under which live-STT session state is stored.
LIVE_STT_SESSION_PREFIX = "stt_live_session:"
# Sessions expire after 15 minutes of inactivity.
LIVE_STT_SESSION_TTL_SECONDS = 15 * 60
# Trailing words of a hypothesis kept mutable (the recognizer may revise them).
LIVE_STT_MUTABLE_TAIL_WORDS = 8
# Shorter mutable tail during silence, when the hypothesis is more settled.
LIVE_STT_SILENCE_MUTABLE_TAIL_WORDS = 2
# Minimum word overlap required to match committed text against a hypothesis.
LIVE_STT_MIN_COMMITTED_OVERLAP_WORDS = 2
def normalize_transcript_text(text: str) -> str:
    """Collapse whitespace runs in *text* to single spaces and trim the ends."""
    words = (text or "").split()
    return " ".join(words)
def join_transcript_parts(*parts: str) -> str:
    """Join the non-empty normalized *parts* with single spaces."""
    normalized = (normalize_transcript_text(part) for part in parts)
    return " ".join(piece for piece in normalized if piece)
def _normalize_word(word: str) -> str:
normalized = re.sub(r"[^\w]+", "", word.casefold(), flags=re.UNICODE)
return normalized or word.casefold()
def _split_words(text: str) -> list[str]:
    """Split *text* into whitespace-delimited words after normalization."""
    cleaned = normalize_transcript_text(text)
    if not cleaned:
        return []
    return cleaned.split()
def _common_prefix_length(left_words: list[str], right_words: list[str]) -> int:
    """Count how many leading words the two lists share.

    Words are compared through _normalize_word, so case and punctuation
    differences do not break the prefix.
    """
    count = 0
    for left, right in zip(left_words, right_words):
        if _normalize_word(left) != _normalize_word(right):
            break
        count += 1
    return count
def _find_suffix_prefix_overlap(
    left_words: list[str], right_words: list[str], min_overlap: int
) -> int:
    """Return the longest overlap (>= *min_overlap*) where a suffix of
    *left_words* matches a prefix of *right_words*, or 0 when none exists."""
    limit = min(len(left_words), len(right_words))
    if limit < min_overlap:
        return 0
    left_keys = [_normalize_word(word) for word in left_words]
    right_keys = [_normalize_word(word) for word in right_words]
    size = limit
    while size >= min_overlap:
        if left_keys[-size:] == right_keys[:size]:
            return size
        size -= 1
    return 0
def strip_committed_prefix(committed_text: str, hypothesis_text: str) -> str:
    """Return *hypothesis_text* with already-committed leading words removed.

    Tries, in order: (1) the full committed text as an exact word prefix of
    the hypothesis, (2) a suffix of the committed text overlapping a prefix
    of the hypothesis, (3) no overlap — the normalized hypothesis is
    returned whole.
    """
    committed_words = _split_words(committed_text)
    hypothesis_words = _split_words(hypothesis_text)
    if not committed_words or not hypothesis_words:
        # Nothing to strip; just normalize whitespace.
        return normalize_transcript_text(hypothesis_text)
    full_prefix_length = _common_prefix_length(committed_words, hypothesis_words)
    if full_prefix_length == len(committed_words):
        # The hypothesis repeats everything committed so far; keep the rest.
        return " ".join(hypothesis_words[full_prefix_length:])
    overlap_size = _find_suffix_prefix_overlap(
        committed_words,
        hypothesis_words,
        LIVE_STT_MIN_COMMITTED_OVERLAP_WORDS,
    )
    if overlap_size:
        # Partial overlap: drop only the overlapping words.
        return " ".join(hypothesis_words[overlap_size:])
    return " ".join(hypothesis_words)
def _calculate_commit_count(
    previous_hypothesis: str, current_hypothesis: str, is_silence: bool
) -> int:
    """Return how many leading words of *current_hypothesis* are safe to commit.

    Words are committable when they were already present at the same position
    in the previous hypothesis (stable across updates), minus a mutable tail
    the recognizer may still revise. Silence shrinks that tail, letting more
    words through.
    """
    previous_words = _split_words(previous_hypothesis)
    current_words = _split_words(current_hypothesis)
    if not current_words:
        return 0
    if not previous_words:
        if is_silence:
            # First hypothesis followed by silence: commit all but a short tail.
            return max(0, len(current_words) - LIVE_STT_SILENCE_MUTABLE_TAIL_WORDS)
        return 0
    stable_prefix_length = _common_prefix_length(previous_words, current_words)
    if not stable_prefix_length:
        return 0
    mutable_tail_words = (
        LIVE_STT_SILENCE_MUTABLE_TAIL_WORDS
        if is_silence
        else LIVE_STT_MUTABLE_TAIL_WORDS
    )
    # Never commit into the mutable tail, however stable the prefix is.
    max_committable_by_tail = max(0, len(current_words) - mutable_tail_words)
    return min(stable_prefix_length, max_committable_by_tail)
def create_live_stt_session(
    user: str, language: Optional[str] = None
) -> Dict[str, object]:
    """Create the initial state dict for a live transcription session."""
    session: Dict[str, object] = {}
    session["session_id"] = str(uuid.uuid4())
    session["user"] = user
    session["language"] = language
    session["committed_text"] = ""
    session["mutable_text"] = ""
    session["previous_hypothesis"] = ""
    session["latest_hypothesis"] = ""
    session["last_chunk_index"] = -1
    return session
def get_live_stt_session_key(session_id: str) -> str:
    """Build the Redis key under which a live-STT session is stored."""
    return "%s%s" % (LIVE_STT_SESSION_PREFIX, session_id)
def save_live_stt_session(redis_client, session_state: Dict[str, object]) -> None:
    """Persist *session_state* in Redis under the session key with the TTL."""
    key = get_live_stt_session_key(str(session_state["session_id"]))
    payload = json.dumps(session_state)
    redis_client.setex(key, LIVE_STT_SESSION_TTL_SECONDS, payload)


def load_live_stt_session(redis_client, session_id: str) -> Optional[Dict[str, object]]:
    """Fetch a stored session, decoding bytes payloads; None when absent."""
    raw = redis_client.get(get_live_stt_session_key(session_id))
    if not raw:
        return None
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    return json.loads(raw)


def delete_live_stt_session(redis_client, session_id: str) -> None:
    """Remove a stored session from Redis (no-op if already gone)."""
    redis_client.delete(get_live_stt_session_key(session_id))
def apply_live_stt_hypothesis(
    session_state: Dict[str, object],
    hypothesis_text: str,
    chunk_index: int,
    is_silence: bool = False,
) -> Dict[str, object]:
    """Fold a new streaming hypothesis into *session_state*, in place.

    Words that stayed stable across consecutive hypotheses (minus a mutable
    tail) are promoted to ``committed_text``; the rest remains mutable.
    Processing is idempotent per ``chunk_index`` and chunks must arrive in
    non-decreasing order.

    Returns:
        The same (mutated) *session_state* dict.

    Raises:
        ValueError: If *chunk_index* is negative or older than the last
            processed chunk.
    """
    last_chunk_index = int(session_state.get("last_chunk_index", -1))
    if chunk_index < 0:
        raise ValueError("chunk_index must be non-negative")
    if chunk_index < last_chunk_index:
        raise ValueError("chunk_index is older than the last processed chunk")
    if chunk_index == last_chunk_index:
        # Duplicate delivery of the same chunk: state is already up to date.
        return session_state
    committed_text = normalize_transcript_text(str(session_state.get("committed_text", "")))
    previous_hypothesis = normalize_transcript_text(
        str(session_state.get("latest_hypothesis", ""))
    )
    # Drop from the hypothesis whatever has already been committed.
    current_hypothesis = strip_committed_prefix(committed_text, hypothesis_text)
    if not current_hypothesis and is_silence and previous_hypothesis:
        # Silence with no new words: flush the pending hypothesis as final.
        committed_text = join_transcript_parts(committed_text, previous_hypothesis)
        previous_hypothesis = ""
    commit_count = _calculate_commit_count(
        previous_hypothesis,
        current_hypothesis,
        is_silence=is_silence,
    )
    current_words = _split_words(current_hypothesis)
    if commit_count:
        # Promote the stable leading words; the tail stays mutable.
        committed_text = join_transcript_parts(
            committed_text,
            " ".join(current_words[:commit_count]),
        )
        current_hypothesis = " ".join(current_words[commit_count:])
    session_state["committed_text"] = committed_text
    session_state["mutable_text"] = normalize_transcript_text(current_hypothesis)
    session_state["previous_hypothesis"] = previous_hypothesis
    session_state["latest_hypothesis"] = normalize_transcript_text(current_hypothesis)
    session_state["last_chunk_index"] = chunk_index
    return session_state
def get_live_stt_transcript_text(session_state: Dict[str, object]) -> str:
    """Current transcript: committed words followed by the mutable tail."""
    committed = str(session_state.get("committed_text", ""))
    mutable = str(session_state.get("mutable_text", ""))
    return join_transcript_parts(committed, mutable)


def finalize_live_stt_session(session_state: Dict[str, object]) -> str:
    """Final transcript: committed words plus the last full hypothesis."""
    committed = str(session_state.get("committed_text", ""))
    latest = str(session_state.get("latest_hypothesis", ""))
    return join_transcript_parts(committed, latest)

View File

@@ -1,60 +0,0 @@
from pathlib import Path
from typing import Any, Dict, Optional
from openai import OpenAI
from application.core.settings import settings
from application.stt.base import BaseSTT
class OpenAISTT(BaseSTT):
    """Speech-to-text via the OpenAI (or OpenAI-compatible) transcription API."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
    ):
        # Explicit arguments win; otherwise fall back to app-level settings.
        self.api_key = api_key or settings.OPENAI_API_KEY or settings.API_KEY
        self.base_url = base_url or settings.OPENAI_BASE_URL or "https://api.openai.com/v1"
        self.model = model or settings.OPENAI_STT_MODEL
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def transcribe(
        self,
        file_path: Path,
        language: Optional[str] = None,
        timestamps: bool = False,
        diarize: bool = False,
    ) -> Dict[str, Any]:
        """Transcribe *file_path* via the API and normalize the response.

        ``diarize`` is accepted for interface parity but ignored — the
        transcription endpoint used here does not expose diarization.
        """
        _ = diarize
        request: Dict[str, Any] = {
            # Placeholder value; replaced below with the open file handle.
            "file": file_path,
            "model": self.model,
            "response_format": "verbose_json",
        }
        if language:
            request["language"] = language
        if timestamps:
            request["timestamp_granularities"] = ["segment"]
        with open(file_path, "rb") as audio_file:
            request["file"] = audio_file
            response = self.client.audio.transcriptions.create(**request)
        response_dict = self._to_dict(response)
        segments = response_dict.get("segments") or []
        return {
            "text": response_dict.get("text", ""),
            "language": response_dict.get("language") or language,
            "duration_s": response_dict.get("duration"),
            "segments": [self._to_dict(segment) for segment in segments],
            "provider": "openai",
        }

    @staticmethod
    def _to_dict(value: Any) -> Dict[str, Any]:
        """Coerce a pydantic-style response object (or plain dict) to a dict."""
        if hasattr(value, "model_dump"):
            return value.model_dump()
        if isinstance(value, dict):
            return value
        return {}

View File

@@ -1,17 +0,0 @@
from application.stt.base import BaseSTT
from application.stt.faster_whisper_stt import FasterWhisperSTT
from application.stt.openai_stt import OpenAISTT
class STTCreator:
    """Factory for speech-to-text provider instances."""

    # Registry mapping provider name (lowercase) -> implementation class.
    stt_providers = {
        "openai": OpenAISTT,
        "faster_whisper": FasterWhisperSTT,
    }

    @classmethod
    def create_stt(cls, stt_type, *args, **kwargs) -> BaseSTT:
        """Instantiate the provider registered under *stt_type* (case-insensitive).

        Raises:
            ValueError: If *stt_type* names no registered provider.
        """
        try:
            stt_class = cls.stt_providers[stt_type.lower()]
        except KeyError:
            raise ValueError(f"No stt class found for type {stt_type}") from None
        return stt_class(*args, **kwargs)

Some files were not shown because too many files have changed in this diff Show More