mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 14:34:32 +00:00
Compare commits
1 Commits
tests-util
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5db3d2019 |
@@ -1,36 +1,9 @@
|
||||
API_KEY=<LLM api key (for example, open ai key)>
|
||||
LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
INTERNAL_KEY=<internal key for worker-to-backend authentication>
|
||||
|
||||
# Provider-specific API keys (optional - use these to enable multiple providers)
|
||||
# OPENAI_API_KEY=<your-openai-api-key>
|
||||
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
|
||||
# GOOGLE_API_KEY=<your-google-api-key>
|
||||
# GROQ_API_KEY=<your-groq-api-key>
|
||||
# NOVITA_API_KEY=<your-novita-api-key>
|
||||
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
|
||||
|
||||
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
|
||||
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
|
||||
EMBEDDINGS_BASE_URL=
|
||||
EMBEDDINGS_KEY=
|
||||
|
||||
#For Azure (you can delete it if you don't use Azure)
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
|
||||
#Azure AD Application (client) ID
|
||||
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
|
||||
#Azure AD Application client secret
|
||||
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
|
||||
#Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#If you are using a Microsoft Entra ID tenant,
|
||||
#configure the AUTHORITY variable as
|
||||
#"https://login.microsoftonline.com/TENANT_GUID"
|
||||
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
6
.github/dependabot.yml
vendored
6
.github/dependabot.yml
vendored
@@ -13,11 +13,7 @@ updates:
|
||||
directory: "/frontend" # Location of package manifests
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/extensions/react-widget"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
interval: "daily"
|
||||
|
||||
11
.github/styles/DocsGPT/Spelling.yml
vendored
11
.github/styles/DocsGPT/Spelling.yml
vendored
@@ -1,11 +0,0 @@
|
||||
extends: spelling
|
||||
level: warning
|
||||
message: "Did you really mean '%s'?"
|
||||
ignore:
|
||||
- "**/node_modules/**"
|
||||
- "**/dist/**"
|
||||
- "**/build/**"
|
||||
- "**/coverage/**"
|
||||
- "**/public/**"
|
||||
- "**/static/**"
|
||||
vocab: DocsGPT
|
||||
@@ -1,46 +0,0 @@
|
||||
Ollama
|
||||
Qdrant
|
||||
Milvus
|
||||
Chatwoot
|
||||
Nextra
|
||||
VSCode
|
||||
npm
|
||||
LLMs
|
||||
APIs
|
||||
Groq
|
||||
SGLang
|
||||
LMDeploy
|
||||
OAuth
|
||||
Vite
|
||||
LLM
|
||||
JSONPath
|
||||
UIs
|
||||
configs
|
||||
uncomment
|
||||
qdrant
|
||||
vectorstore
|
||||
docsgpt
|
||||
llm
|
||||
GPUs
|
||||
kubectl
|
||||
Lightsail
|
||||
enqueues
|
||||
chatbot
|
||||
VSCode's
|
||||
Shareability
|
||||
feedbacks
|
||||
automations
|
||||
Premade
|
||||
Signup
|
||||
Repo
|
||||
repo
|
||||
env
|
||||
URl
|
||||
agentic
|
||||
llama_cpp
|
||||
parsable
|
||||
SDKs
|
||||
boolean
|
||||
bool
|
||||
hardcode
|
||||
EOL
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
sync-labels: true
|
||||
|
||||
3
.github/workflows/lint.yml
vendored
3
.github/workflows/lint.yml
vendored
@@ -7,9 +7,6 @@ on:
|
||||
pull_request:
|
||||
types: [ opened, synchronize ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
114
.github/workflows/npm-publish.yml
vendored
114
.github/workflows/npm-publish.yml
vendored
@@ -1,114 +0,0 @@
|
||||
name: Publish npm libraries
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: >
|
||||
Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
|
||||
Applies to both docsgpt and docsgpt-react.
|
||||
required: true
|
||||
default: patch
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
environment: npm-release
|
||||
defaults:
|
||||
run:
|
||||
working-directory: extensions/react-widget
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
registry-url: https://registry.npmjs.org
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
# ── docsgpt (HTML embedding bundle) ──────────────────────────────────
|
||||
# Uses the `build` script (parcel build src/browser.tsx) and keeps
|
||||
# the `targets` field so Parcel produces browser-optimised bundles.
|
||||
|
||||
- name: Set package name → docsgpt
|
||||
run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
|
||||
|
||||
- name: Bump version (docsgpt)
|
||||
id: version_docsgpt
|
||||
run: |
|
||||
VERSION="${{ github.event.inputs.version }}"
|
||||
NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
|
||||
echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build docsgpt
|
||||
run: npm run build
|
||||
|
||||
- name: Publish docsgpt
|
||||
run: npm publish --verbose
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
# ── docsgpt-react (React library bundle) ─────────────────────────────
|
||||
# Uses `build:react` script (parcel build src/index.ts) and strips
|
||||
# the `targets` field so Parcel treats the output as a plain library
|
||||
# without browser-specific target resolution, producing a smaller bundle.
|
||||
|
||||
- name: Reset package.json from source control
|
||||
run: git checkout -- package.json
|
||||
|
||||
- name: Set package name → docsgpt-react
|
||||
run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
|
||||
|
||||
- name: Remove targets field (react library build)
|
||||
run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json
|
||||
|
||||
- name: Bump version (docsgpt-react) to match docsgpt
|
||||
run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version
|
||||
|
||||
- name: Clean dist before react build
|
||||
run: rm -rf dist
|
||||
|
||||
- name: Build docsgpt-react
|
||||
run: npm run build:react
|
||||
|
||||
- name: Publish docsgpt-react
|
||||
run: npm publish --verbose
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
# ── Commit the bumped version back to the repository ─────────────────
|
||||
|
||||
- name: Reset package.json and write final version
|
||||
run: |
|
||||
git checkout -- package.json
|
||||
jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
|
||||
package.json > _tmp.json && mv _tmp.json package.json
|
||||
npm install --package-lock-only
|
||||
|
||||
- name: Commit version bump and create PR
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
|
||||
git checkout -b "$BRANCH"
|
||||
git add package.json package-lock.json
|
||||
git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
|
||||
git push origin "$BRANCH"
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Create PR
|
||||
run: |
|
||||
gh pr create \
|
||||
--title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
|
||||
--body "Automated version bump after npm publish." \
|
||||
--base main
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
10
.github/workflows/pytest.yml
vendored
10
.github/workflows/pytest.yml
vendored
@@ -1,9 +1,5 @@
|
||||
name: Run python tests with pytest
|
||||
on: [push, pull_request]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pytest_and_coverage:
|
||||
name: Run tests and count coverage
|
||||
@@ -20,15 +16,15 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-cov
|
||||
cd application
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
cd ../tests
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
- name: Test with pytest and generate coverage report
|
||||
run: |
|
||||
python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
|
||||
python -m pytest --cov=application --cov-report=xml
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
|
||||
uses: codecov/codecov-action@v5
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
|
||||
34
.github/workflows/react-widget-build.yml
vendored
34
.github/workflows/react-widget-build.yml
vendored
@@ -1,34 +0,0 @@
|
||||
name: React Widget Build
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'extensions/react-widget/**'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'extensions/react-widget/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: extensions/react-widget
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 20
|
||||
cache: npm
|
||||
cache-dependency-path: extensions/react-widget/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Build
|
||||
run: npm run build
|
||||
30
.github/workflows/vale.yml
vendored
30
.github/workflows/vale.yml
vendored
@@ -1,30 +0,0 @@
|
||||
name: Vale Documentation Linter
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/**/*.md'
|
||||
- 'docs/**/*.mdx'
|
||||
- '**/*.md'
|
||||
- '.vale.ini'
|
||||
- '.github/styles/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
vale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Vale linter
|
||||
uses: errata-ai/vale-action@v2
|
||||
with:
|
||||
files: docs
|
||||
fail_on_error: false
|
||||
version: 3.0.5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -2,10 +2,7 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
results.txt
|
||||
experiments/
|
||||
|
||||
experiments
|
||||
# C extensions
|
||||
*.so
|
||||
*.next
|
||||
@@ -72,7 +69,6 @@ instance/
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/public/_pagefind/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
@@ -149,10 +145,6 @@ frontend/yarn-error.log*
|
||||
frontend/pnpm-debug.log*
|
||||
frontend/lerna-debug.log*
|
||||
|
||||
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
|
||||
!frontend/src/lib/
|
||||
!frontend/src/lib/**
|
||||
|
||||
frontend/node_modules
|
||||
frontend/dist
|
||||
frontend/dist-ssr
|
||||
|
||||
@@ -1,6 +1,2 @@
|
||||
# Allow lines to be as long as 120 characters.
|
||||
line-length = 120
|
||||
|
||||
[lint.per-file-ignores]
|
||||
# Integration tests use sys.path.insert() before imports for standalone execution
|
||||
"tests/integration/*.py" = ["E402"]
|
||||
line-length = 120
|
||||
@@ -1,5 +0,0 @@
|
||||
MinAlertLevel = warning
|
||||
StylesPath = .github/styles
|
||||
|
||||
[*.{md,mdx}]
|
||||
BasedOnStyles = DocsGPT
|
||||
33
.vscode/launch.json
vendored
33
.vscode/launch.json
vendored
@@ -2,11 +2,15 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Frontend Debug (npm)",
|
||||
"type": "node-terminal",
|
||||
"name": "Docker Debug Frontend",
|
||||
"request": "launch",
|
||||
"command": "npm run dev",
|
||||
"cwd": "${workspaceFolder}/frontend"
|
||||
"type": "chrome",
|
||||
"preLaunchTask": "docker-compose: debug:frontend",
|
||||
"url": "http://127.0.0.1:5173",
|
||||
"webRoot": "${workspaceFolder}/frontend",
|
||||
"skipFiles": [
|
||||
"<node_internals>/**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Flask Debugger",
|
||||
@@ -45,27 +49,6 @@
|
||||
"--pool=solo"
|
||||
],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"name": "Dev Containers (Mongo + Redis)",
|
||||
"type": "node-terminal",
|
||||
"request": "launch",
|
||||
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
],
|
||||
"compounds": [
|
||||
{
|
||||
"name": "DocsGPT: Full Stack",
|
||||
"configurations": [
|
||||
"Frontend Debug (npm)",
|
||||
"Flask Debugger",
|
||||
"Celery Debugger"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "DocsGPT",
|
||||
"order": 1
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
21
.vscode/tasks.json
vendored
Normal file
21
.vscode/tasks.json
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "docker-compose",
|
||||
"label": "docker-compose: debug:frontend",
|
||||
"dockerCompose": {
|
||||
"up": {
|
||||
"detached": true,
|
||||
"services": [
|
||||
"frontend"
|
||||
],
|
||||
"build": true
|
||||
},
|
||||
"files": [
|
||||
"${workspaceFolder}/docker-compose.yaml"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
134
AGENTS.md
134
AGENTS.md
@@ -1,134 +0,0 @@
|
||||
# AGENTS.md
|
||||
|
||||
- Read `CONTRIBUTING.md` before making non-trivial changes.
|
||||
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
|
||||
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users usually configure `.env` themselves.
|
||||
- Try to follow red/green TDD
|
||||
|
||||
### Check existing dev prerequisites first
|
||||
|
||||
For feature work, do **not** assume the environment needs to be recreated.
|
||||
|
||||
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
|
||||
- Check whether MongoDB is already running.
|
||||
- Check whether Redis is already running.
|
||||
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
|
||||
|
||||
## Normal local development commands
|
||||
|
||||
Use these commands once the dev prerequisites above are satisfied.
|
||||
|
||||
### Backend
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # macOS/Linux
|
||||
uv pip install -r application/requirements.txt # or: pip install -r application/requirements.txt
|
||||
```
|
||||
|
||||
Run the Flask API (if needed):
|
||||
|
||||
```bash
|
||||
flask --app application/app.py run --host=0.0.0.0 --port=7091
|
||||
```
|
||||
|
||||
Run the Celery worker in a separate terminal (if needed):
|
||||
|
||||
```bash
|
||||
celery -A application.app.celery worker -l INFO
|
||||
```
|
||||
|
||||
On macOS, prefer the solo pool for Celery:
|
||||
|
||||
```bash
|
||||
python -m celery -A application.app.celery worker -l INFO --pool=solo
|
||||
```
|
||||
|
||||
### Frontend
|
||||
|
||||
Install dependencies only when needed, then run the dev server:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install --include=dev
|
||||
npm run dev
|
||||
```
|
||||
|
||||
### Docs site
|
||||
|
||||
```bash
|
||||
cd docs
|
||||
npm install
|
||||
```
|
||||
|
||||
### Python / backend changes validation
|
||||
|
||||
```bash
|
||||
ruff check .
|
||||
python -m pytest
|
||||
```
|
||||
|
||||
### Frontend changes
|
||||
|
||||
```bash
|
||||
cd frontend && npm run lint
|
||||
cd frontend && npm run build
|
||||
```
|
||||
|
||||
### Documentation changes
|
||||
|
||||
```bash
|
||||
cd docs && npm run build
|
||||
```
|
||||
|
||||
If Vale is installed locally and you edited prose, also run:
|
||||
|
||||
```bash
|
||||
vale .
|
||||
```
|
||||
|
||||
## Repository map
|
||||
|
||||
- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
|
||||
- `tests/`: backend unit/integration tests and test-only Python dependencies.
|
||||
- `frontend/`: Vite + React + TypeScript application.
|
||||
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
|
||||
- `docs/`: separate documentation site built with Next.js/Nextra.
|
||||
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
|
||||
- `deployment/`: Docker Compose variants and Kubernetes manifests.
|
||||
|
||||
## Coding rules
|
||||
|
||||
### Backend
|
||||
|
||||
- Follow PEP 8 and keep Python line length at or under 120 characters.
|
||||
- Use type hints for function arguments and return values.
|
||||
- Add Google-style docstrings to new or substantially changed functions and classes.
|
||||
- Add or update tests under `tests/` for backend behavior changes.
|
||||
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.
|
||||
|
||||
### Backend Abstractions
|
||||
|
||||
- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
|
||||
- Vector stores are abstracted in `application/vectorstore/`.
|
||||
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
|
||||
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
|
||||
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
|
||||
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
|
||||
|
||||
### Frontend
|
||||
|
||||
- Follow the existing ESLint + Prettier setup.
|
||||
- Prefer small, reusable functional components and hooks.
|
||||
- If shared state must be added, use Redux rather than introducing a new global state library.
|
||||
- Avoid broad UI refactors unless the task explicitly asks for them.
|
||||
- Do not re-create components if we already have some in the app.
|
||||
|
||||
## PR readiness
|
||||
|
||||
Before opening a PR:
|
||||
|
||||
- run the relevant validation commands above
|
||||
- confirm backend changes still work end-to-end after ingesting sample data when applicable
|
||||
- clearly summarize user-visible behavior changes
|
||||
- mention any config, dependency, or deployment implications
|
||||
- Ask your user to attach a screenshot or a video to it
|
||||
@@ -22,11 +22,6 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!
|
||||
|
||||
- We have a frontend built on React (Vite) and a backend in Python.
|
||||
|
||||
> **Required for every PR:** Please attach screenshots or a short screen
|
||||
> recording that shows the working version of your changes. This makes the
|
||||
> requirement visible to reviewers and helps them quickly verify what you are
|
||||
> submitting.
|
||||
|
||||
|
||||
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart); the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
|
||||
|
||||
@@ -130,7 +125,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
|
||||
```
|
||||
|
||||
9. **Submit a Pull Request (PR):**
|
||||
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.
|
||||
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
|
||||
|
||||
10. **Collaborate:**
|
||||
- Be responsive to comments and feedback on your PR.
|
||||
@@ -152,5 +147,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
|
||||
Thank you for considering contributing to DocsGPT! 🙏
|
||||
|
||||
## Questions/collaboration
|
||||
Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
|
||||
# Thank you so much for considering contributing to DocsGPT! 🙏
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
|
||||
|
||||
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
|
||||
|
||||
All meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩
|
||||
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
|
||||
|
||||
Please fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8) after your PR has been merged.
|
||||
|
||||
|
||||
If you are in doubt, don't hesitate to ping us on Discord, or ping me directly - Alex (dartpain).
|
||||
|
||||
## 📜 Here's How to Contribute:
|
||||
```text
|
||||
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
|
||||
|
||||
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
|
||||
They can be completely separate repos.
|
||||
For example:
|
||||
https://github.com/arc53/tg-bot-docsgpt-extenstion or
|
||||
https://github.com/arc53/DocsGPT-cli
|
||||
|
||||
Non-Code Contributions:
|
||||
|
||||
📚 Wiki: Improve our documentation, create a guide.
|
||||
|
||||
🖥️ Design: Improve the UI/UX or design a new feature.
|
||||
```
|
||||
|
||||
### 📝 Guidelines for Pull Requests:
|
||||
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
|
||||
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
|
||||
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
|
||||
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
|
||||
- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
|
||||
|
||||
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
|
||||
|
||||
We will publish a t-shirt design later in October.
|
||||
48
README.md
48
README.md
@@ -7,7 +7,7 @@
|
||||
</p>
|
||||
|
||||
<p align="left">
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
|
||||
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -16,27 +16,23 @@
|
||||
<a href="https://github.com/arc53/DocsGPT"></a>
|
||||
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE"></a>
|
||||
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
|
||||
<a href="https://discord.gg/vN7YFfdMpj"></a>
|
||||
<a href="https://x.com/docsgptai"></a>
|
||||
<a href="https://discord.gg/n5BX8dh8rU"></a>
|
||||
<a href="https://twitter.com/docsgptai"></a>
|
||||
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
|
||||
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a> • <a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a> • <a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
|
||||
<br>
|
||||
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a> • <a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a> • <a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||
</div>
|
||||
<h3 align="left">
|
||||
<strong>Key Features:</strong>
|
||||
</h3>
|
||||
<ul align="left">
|
||||
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
|
||||
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
|
||||
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
|
||||
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
|
||||
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
|
||||
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
|
||||
@@ -47,11 +43,22 @@
|
||||
</ul>
|
||||
|
||||
## Roadmap
|
||||
- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
|
||||
- [x] Deep Agents ( October 2025 )
|
||||
- [x] Prompt Templating ( October 2025 )
|
||||
- [x] Full api tooling ( Dec 2025 )
|
||||
- [ ] Agent scheduling ( Jan 2026 )
|
||||
|
||||
- [x] Full GoogleAI compatibility (Jan 2025)
|
||||
- [x] Add tools (Jan 2025)
|
||||
- [x] Manually updating chunks in the app UI (Feb 2025)
|
||||
- [x] Devcontainer for easy development (Feb 2025)
|
||||
- [x] ReACT agent (March 2025)
|
||||
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
|
||||
- [x] New input box in the conversation menu (April 2025)
|
||||
- [x] Add triggerable actions / tools (webhook) (April 2025)
|
||||
- [x] Agent optimisations (May 2025)
|
||||
- [x] Filesystem sources update (July 2025)
|
||||
- [x] Json Responses (August 2025)
|
||||
- [ ] Sharepoint integration (August 2025)
|
||||
- [ ] MCP support (August 2025)
|
||||
- [ ] Add OAuth 2.0 authentication for tools and sources (August 2025)
|
||||
- [ ] Agent scheduling
|
||||
|
||||
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
|
||||
|
||||
@@ -99,7 +106,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
|
||||
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
|
||||
```
|
||||
|
||||
Either script will guide you through setting up DocsGPT. Five options are available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or building the Docker image locally. The scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
Either script will guide you through setting up DocsGPT. Four options are available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. The scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
|
||||
|
||||
**Navigate to http://localhost:5173/**
|
||||
|
||||
@@ -146,16 +153,9 @@ We as members, contributors, and leaders, pledge to make participation in our co
|
||||
|
||||
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
|
||||
|
||||
## This project is supported by:
|
||||
|
||||
<p>This project is supported by:</p>
|
||||
<p>
|
||||
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
|
||||
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
|
||||
</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://get.neon.com/docsgpt">
|
||||
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
|
||||
</a>
|
||||
|
||||
</p>
|
||||
|
||||
11
application/.env_sample
Normal file
11
application/.env_sample
Normal file
@@ -0,0 +1,11 @@
|
||||
API_KEY=your_api_key
|
||||
EMBEDDINGS_KEY=your_api_key
|
||||
API_URL=http://localhost:7091
|
||||
FLASK_APP=application/app.py
|
||||
FLASK_DEBUG=true
|
||||
|
||||
#For OPENAI on Azure
|
||||
OPENAI_API_BASE=
|
||||
OPENAI_API_VERSION=
|
||||
AZURE_DEPLOYMENT_NAME=
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
|
||||
@@ -7,7 +7,7 @@ RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
|
||||
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Verify Python installation and setup symlink
|
||||
@@ -48,12 +48,7 @@ FROM ubuntu:24.04 as final
|
||||
RUN apt-get update && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.12 \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
poppler-utils \
|
||||
&& \
|
||||
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
|
||||
ln -s /usr/bin/python3.12 /usr/bin/python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -1,20 +1,11 @@
|
||||
import logging
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflow_agent import WorkflowAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from application.agents.react_agent import ReActAgent
|
||||
|
||||
|
||||
class AgentCreator:
|
||||
agents = {
|
||||
"classic": ClassicAgent,
|
||||
"react": ClassicAgent, # backwards compat: react falls back to classic
|
||||
"agentic": AgenticAgent,
|
||||
"research": ResearchAgent,
|
||||
"workflow": WorkflowAgent,
|
||||
"react": ReActAgent,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import logging
|
||||
from typing import Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.logging import LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgenticAgent(BaseAgent):
|
||||
"""Agent where the LLM controls retrieval via tools.
|
||||
|
||||
Unlike ClassicAgent which pre-fetches docs into the prompt,
|
||||
AgenticAgent gives the LLM an internal_search tool so it can
|
||||
decide when, what, and whether to search.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever_config: Optional[Dict] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.retriever_config = retriever_config or {}
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
# 4. Build messages (prompt has NO pre-fetched docs)
|
||||
messages = self._build_messages(self.prompt, query)
|
||||
|
||||
# 5. Call LLM — the handler manages the tool loop
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
# 6. Collect sources from internal search tool results
|
||||
self._collect_internal_sources()
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
def _collect_internal_sources(self):
|
||||
"""Collect retrieved docs from the cached InternalSearchTool instance."""
|
||||
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
|
||||
tool = self.tool_executor._loaded_tools.get(cache_key)
|
||||
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
|
||||
self.retrieved_docs = tool.retrieved_docs
|
||||
@@ -3,15 +3,16 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,249 +22,255 @@ class BaseAgent(ABC):
|
||||
self,
|
||||
endpoint: str,
|
||||
llm_name: str,
|
||||
model_id: str,
|
||||
gpt_model: str,
|
||||
api_key: str,
|
||||
agent_id: Optional[str] = None,
|
||||
user_api_key: Optional[str] = None,
|
||||
prompt: str = "",
|
||||
chat_history: Optional[List[Dict]] = None,
|
||||
retrieved_docs: Optional[List[Dict]] = None,
|
||||
decoded_token: Optional[Dict] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
json_schema: Optional[Dict] = None,
|
||||
limited_token_mode: Optional[bool] = False,
|
||||
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
limited_request_mode: Optional[bool] = False,
|
||||
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
compressed_summary: Optional[str] = None,
|
||||
llm=None,
|
||||
llm_handler=None,
|
||||
tool_executor: Optional[ToolExecutor] = None,
|
||||
backup_models: Optional[List[str]] = None,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.llm_name = llm_name
|
||||
self.model_id = model_id
|
||||
self.gpt_model = gpt_model
|
||||
self.api_key = api_key
|
||||
self.agent_id = agent_id
|
||||
self.user_api_key = user_api_key
|
||||
self.prompt = prompt
|
||||
self.decoded_token = decoded_token or {}
|
||||
self.user: str = self.decoded_token.get("sub")
|
||||
self.user: str = decoded_token.get("sub")
|
||||
self.tool_config: Dict = {}
|
||||
self.tools: List[Dict] = []
|
||||
self.tool_calls: List[Dict] = []
|
||||
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
|
||||
|
||||
# Dependency injection for LLM — fall back to creating if not provided
|
||||
if llm is not None:
|
||||
self.llm = llm
|
||||
else:
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
backup_models=backup_models,
|
||||
)
|
||||
|
||||
self.retrieved_docs = retrieved_docs or []
|
||||
|
||||
if llm_handler is not None:
|
||||
self.llm_handler = llm_handler
|
||||
else:
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
|
||||
# Tool executor — injected or created
|
||||
if tool_executor is not None:
|
||||
self.tool_executor = tool_executor
|
||||
else:
|
||||
self.tool_executor = ToolExecutor(
|
||||
user_api_key=user_api_key,
|
||||
user=self.user,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
|
||||
self.llm = LLMCreator.create_llm(
|
||||
llm_name,
|
||||
api_key=api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
)
|
||||
self.llm_handler = LLMHandlerCreator.create_handler(
|
||||
llm_name if llm_name else "default"
|
||||
)
|
||||
self.attachments = attachments or []
|
||||
self.json_schema = None
|
||||
if json_schema is not None:
|
||||
try:
|
||||
self.json_schema = normalize_json_schema_payload(json_schema)
|
||||
except JsonSchemaValidationError as exc:
|
||||
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
|
||||
self.limited_token_mode = limited_token_mode
|
||||
self.token_limit = token_limit
|
||||
self.limited_request_mode = limited_request_mode
|
||||
self.request_limit = request_limit
|
||||
self.compressed_summary = compressed_summary
|
||||
self.current_token_count = 0
|
||||
self.context_limit_reached = False
|
||||
self.json_schema = json_schema
|
||||
|
||||
@log_activity()
|
||||
def gen(
|
||||
self, query: str, log_context: LogContext = None
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
|
||||
) -> Generator[Dict, None, None]:
|
||||
yield from self._gen_inner(query, log_context)
|
||||
yield from self._gen_inner(query, retriever, log_context)
|
||||
|
||||
@abstractmethod
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||
|
||||
@property
|
||||
def tool_calls(self) -> List[Dict]:
|
||||
return self.tool_executor.tool_calls
|
||||
|
||||
@tool_calls.setter
|
||||
def tool_calls(self, value: List[Dict]):
|
||||
self.tool_executor.tool_calls = value
|
||||
|
||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
|
||||
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
tools_collection = db["user_tools"]
|
||||
|
||||
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
|
||||
tools = (
|
||||
tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
|
||||
)
|
||||
if tool_ids
|
||||
else []
|
||||
)
|
||||
tools = list(tools)
|
||||
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
|
||||
|
||||
return tools_by_id
|
||||
|
||||
def _get_user_tools(self, user="local"):
|
||||
return self.tool_executor._get_user_tools(user)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
|
||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||
|
||||
def _build_tool_parameters(self, action):
|
||||
return self.tool_executor._build_tool_parameters(action)
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
for param_type in ["query_params", "headers", "body", "parameters"]:
|
||||
if param_type in action and action[param_type].get("properties"):
|
||||
for k, v in action[param_type]["properties"].items():
|
||||
if v.get("filled_by_llm", True):
|
||||
params["properties"][k] = {
|
||||
key: value
|
||||
for key, value in v.items()
|
||||
if key != "filled_by_llm" and key != "value"
|
||||
}
|
||||
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def _prepare_tools(self, tools_dict):
|
||||
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
|
||||
self.tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": self._build_tool_parameters(action),
|
||||
},
|
||||
}
|
||||
for tool_id, tool in tools_dict.items()
|
||||
if (
|
||||
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
|
||||
or (tool["name"] != "api_tool" and "actions" in tool)
|
||||
)
|
||||
for action in (
|
||||
tool["config"]["actions"].values()
|
||||
if tool["name"] == "api_tool"
|
||||
else tool["actions"]
|
||||
)
|
||||
if action.get("active", True)
|
||||
]
|
||||
|
||||
def _execute_tool_action(self, tools_dict, call):
|
||||
return self.tool_executor.execute(
|
||||
tools_dict, call, self.llm.__class__.__name__
|
||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
|
||||
# Check if parsing failed
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": getattr(call, 'name', 'unknown'),
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return "Failed to parse tool call.", call_id
|
||||
|
||||
# Check if tool_id exists in available tools
|
||||
if tool_id not in tools_dict:
|
||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||
logger.error(error_message)
|
||||
|
||||
# Return error result
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return f"Tool with ID {tool_id} not found.", call_id
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
action_data = (
|
||||
tool_data["config"]["actions"][action_name]
|
||||
if tool_data["name"] == "api_tool"
|
||||
else next(
|
||||
action
|
||||
for action in tool_data["actions"]
|
||||
if action["name"] == action_name
|
||||
)
|
||||
)
|
||||
|
||||
query_params, headers, body, parameters = {}, {}, {}, {}
|
||||
param_types = {
|
||||
"query_params": query_params,
|
||||
"headers": headers,
|
||||
"body": body,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and action_data[param_type].get("properties"):
|
||||
for param, details in action_data[param_type]["properties"].items():
|
||||
if param not in call_args and "value" in details:
|
||||
target_dict[param] = details["value"]
|
||||
for param, value in call_args.items():
|
||||
for param_type, target_dict in param_types.items():
|
||||
if param_type in action_data and param in action_data[param_type].get(
|
||||
"properties", {}
|
||||
):
|
||||
target_dict[param] = value
|
||||
tm = ToolManager(config={})
|
||||
tool = tm.load_tool(
|
||||
tool_data["name"],
|
||||
tool_config=(
|
||||
{
|
||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
}
|
||||
if tool_data["name"] == "api_tool"
|
||||
else tool_data["config"]
|
||||
),
|
||||
)
|
||||
if tool_data["name"] == "api_tool":
|
||||
print(
|
||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||
)
|
||||
result = tool.execute_action(action_name, **body)
|
||||
else:
|
||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
||||
result = tool.execute_action(action_name, **parameters)
|
||||
tool_call_data["result"] = (
|
||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
||||
)
|
||||
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
|
||||
return result, call_id
|
||||
|
||||
def _get_truncated_tool_calls(self):
|
||||
return self.tool_executor.get_truncated_tool_calls()
|
||||
|
||||
# ---- Context / token management ----
|
||||
|
||||
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
|
||||
from application.api.answer.services.compression.token_counter import (
|
||||
TokenCounter,
|
||||
)
|
||||
return TokenCounter.count_message_tokens(messages)
|
||||
|
||||
def _check_context_limit(self, messages: List[Dict]) -> bool:
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
try:
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||
|
||||
if current_tokens >= threshold:
|
||||
logger.warning(
|
||||
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
|
||||
return False
|
||||
|
||||
def _validate_context_size(self, messages: List[Dict]) -> None:
|
||||
from application.core.model_utils import get_token_limit
|
||||
|
||||
current_tokens = self._calculate_current_context_tokens(messages)
|
||||
self.current_token_count = current_tokens
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
percentage = (current_tokens / context_limit) * 100
|
||||
|
||||
if current_tokens >= context_limit:
|
||||
logger.warning(
|
||||
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%). Model: {self.model_id}"
|
||||
)
|
||||
elif current_tokens >= int(
|
||||
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||
):
|
||||
logger.info(
|
||||
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||
f"({percentage:.1f}%)"
|
||||
)
|
||||
|
||||
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
current_tokens = num_tokens_from_string(text)
|
||||
if current_tokens <= max_tokens:
|
||||
return text
|
||||
|
||||
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
|
||||
target_chars = int(max_tokens * chars_per_token * 0.95)
|
||||
|
||||
if target_chars <= 0:
|
||||
return ""
|
||||
|
||||
start_chars = int(target_chars * 0.4)
|
||||
end_chars = int(target_chars * 0.4)
|
||||
|
||||
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
|
||||
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
|
||||
|
||||
logger.info(
|
||||
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
|
||||
f"(removed middle section)"
|
||||
)
|
||||
return truncated
|
||||
|
||||
# ---- Message building ----
|
||||
return [
|
||||
{
|
||||
**tool_call,
|
||||
"result": (
|
||||
f"{str(tool_call['result'])[:50]}..."
|
||||
if len(str(tool_call["result"])) > 50
|
||||
else tool_call["result"]
|
||||
),
|
||||
"status": "completed",
|
||||
}
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
system_prompt: str,
|
||||
query: str,
|
||||
retrieved_data: List[Dict],
|
||||
) -> List[Dict]:
|
||||
"""Build messages using pre-rendered system prompt"""
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.utils import num_tokens_from_string
|
||||
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
|
||||
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
|
||||
messages_combine = [{"role": "system", "content": p_chat_combine}]
|
||||
|
||||
if self.compressed_summary:
|
||||
compression_context = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{self.compressed_summary}"
|
||||
)
|
||||
system_prompt = system_prompt + compression_context
|
||||
|
||||
context_limit = get_token_limit(self.model_id)
|
||||
system_tokens = num_tokens_from_string(system_prompt)
|
||||
|
||||
safety_buffer = int(context_limit * 0.1)
|
||||
available_after_system = context_limit - system_tokens - safety_buffer
|
||||
|
||||
max_query_tokens = int(available_after_system * 0.8)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
if query_tokens > max_query_tokens:
|
||||
query = self._truncate_text_middle(query, max_query_tokens)
|
||||
query_tokens = num_tokens_from_string(query)
|
||||
|
||||
available_for_history = max(available_after_system - query_tokens, 0)
|
||||
|
||||
working_history = self._truncate_history_to_fit(
|
||||
self.chat_history,
|
||||
available_for_history,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
for i in working_history:
|
||||
for i in self.chat_history:
|
||||
if "prompt" in i and "response" in i:
|
||||
messages.append({"role": "user", "content": i["prompt"]})
|
||||
messages.append({"role": "assistant", "content": i["response"]})
|
||||
messages_combine.append({"role": "user", "content": i["prompt"]})
|
||||
messages_combine.append({"role": "assistant", "content": i["response"]})
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
@@ -283,67 +290,29 @@ class BaseAgent(ABC):
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
messages_combine.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
messages_combine.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
messages_combine.append({"role": "user", "content": query})
|
||||
return messages_combine
|
||||
|
||||
def _truncate_history_to_fit(
|
||||
def _retriever_search(
|
||||
self,
|
||||
history: List[Dict],
|
||||
max_tokens: int,
|
||||
retriever: BaseRetriever,
|
||||
query: str,
|
||||
log_context: Optional[LogContext] = None,
|
||||
) -> List[Dict]:
|
||||
from application.utils import num_tokens_from_string
|
||||
|
||||
if not history or max_tokens <= 0:
|
||||
return []
|
||||
|
||||
truncated = []
|
||||
current_tokens = 0
|
||||
|
||||
for message in reversed(history):
|
||||
message_tokens = 0
|
||||
|
||||
if "prompt" in message and "response" in message:
|
||||
message_tokens += num_tokens_from_string(message["prompt"])
|
||||
message_tokens += num_tokens_from_string(message["response"])
|
||||
|
||||
if "tool_calls" in message:
|
||||
for tool_call in message["tool_calls"]:
|
||||
tool_str = (
|
||||
f"Tool: {tool_call.get('tool_name')} | "
|
||||
f"Action: {tool_call.get('action_name')} | "
|
||||
f"Args: {tool_call.get('arguments')} | "
|
||||
f"Response: {tool_call.get('result')}"
|
||||
)
|
||||
message_tokens += num_tokens_from_string(tool_str)
|
||||
|
||||
if current_tokens + message_tokens <= max_tokens:
|
||||
current_tokens += message_tokens
|
||||
truncated.insert(0, message)
|
||||
else:
|
||||
break
|
||||
|
||||
if len(truncated) < len(history):
|
||||
logger.info(
|
||||
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
|
||||
f"to fit within {max_tokens:,} token budget"
|
||||
)
|
||||
|
||||
return truncated
|
||||
|
||||
# ---- LLM generation ----
|
||||
retrieved_data = retriever.search(query)
|
||||
if log_context:
|
||||
data = build_stack_data(retriever, exclude_attributes=["llm"])
|
||||
log_context.stacks.append({"component": "retriever", "data": data})
|
||||
return retrieved_data
|
||||
|
||||
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
|
||||
self._validate_context_size(messages)
|
||||
|
||||
gen_kwargs = {"model": self.model_id, "messages": messages}
|
||||
if self.attachments:
|
||||
gen_kwargs["_usage_attachments"] = self.attachments
|
||||
gen_kwargs = {"model": self.gpt_model, "messages": messages}
|
||||
|
||||
if (
|
||||
hasattr(self.llm, "_supports_tools")
|
||||
@@ -351,6 +320,7 @@ class BaseAgent(ABC):
|
||||
and self.tools
|
||||
):
|
||||
gen_kwargs["tools"] = self.tools
|
||||
|
||||
if (
|
||||
self.json_schema
|
||||
and hasattr(self.llm, "_supports_structured_output")
|
||||
@@ -364,6 +334,7 @@ class BaseAgent(ABC):
|
||||
gen_kwargs["response_format"] = structured_format
|
||||
elif self.llm_name == "google":
|
||||
gen_kwargs["response_schema"] = structured_format
|
||||
|
||||
resp = self.llm.gen_stream(**gen_kwargs)
|
||||
|
||||
if log_context:
|
||||
|
||||
@@ -1,33 +1,53 @@
|
||||
import logging
|
||||
from typing import Dict, Generator
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassicAgent(BaseAgent):
|
||||
"""A simplified agent with clear execution flow"""
|
||||
"""A simplified agent with clear execution flow.
|
||||
|
||||
Usage:
|
||||
1. Processes a query through retrieval
|
||||
2. Sets up available tools
|
||||
3. Generates responses using LLM
|
||||
4. Handles tool interactions if needed
|
||||
5. Returns standardized outputs
|
||||
|
||||
Easy to extend by overriding specific steps.
|
||||
"""
|
||||
|
||||
def _gen_inner(
|
||||
self, query: str, log_context: LogContext
|
||||
self, query: str, retriever: BaseRetriever, log_context: LogContext
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Core generator function for ClassicAgent execution flow"""
|
||||
# Step 1: Retrieve relevant data
|
||||
retrieved_data = self._retriever_search(retriever, query, log_context)
|
||||
|
||||
tools_dict = self.tool_executor.get_tools()
|
||||
# Step 2: Prepare tools
|
||||
tools_dict = (
|
||||
self._get_user_tools(self.user)
|
||||
if not self.user_api_key
|
||||
else self._get_tools(self.user_api_key)
|
||||
)
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
messages = self._build_messages(self.prompt, query)
|
||||
# Step 3: Build and process messages
|
||||
messages = self._build_messages(self.prompt, query, retrieved_data)
|
||||
llm_response = self._llm_gen(messages, log_context)
|
||||
|
||||
# Step 4: Handle the response
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, log_context
|
||||
)
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
# Step 5: Return metadata
|
||||
yield {"sources": retrieved_data}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# Log tool calls for debugging
|
||||
log_context.stacks.append(
|
||||
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||
)
|
||||
|
||||
229
application/agents/react_agent.py
Normal file
229
application/agents/react_agent.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
from typing import Dict, Generator, List, Any
|
||||
import logging
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.logging import build_stack_data, LogContext
|
||||
from application.retriever.base import BaseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
with open(
|
||||
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
|
||||
) as f:
|
||||
planning_prompt_template = f.read()
|
||||
with open(
|
||||
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
|
||||
"r",
|
||||
) as f:
|
||||
final_prompt_template = f.read()
|
||||
|
||||
MAX_ITERATIONS_REASONING = 10
|
||||
|
||||
class ReActAgent(BaseAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.plan: str = ""
|
||||
self.observations: List[str] = []
|
||||
|
||||
def _extract_content_from_llm_response(self, resp: Any) -> str:
|
||||
"""
|
||||
Helper to extract string content from various LLM response types.
|
||||
Handles strings, message objects (OpenAI-like), and streams.
|
||||
Adapt stream handling for your specific LLM client if not OpenAI.
|
||||
"""
|
||||
collected_content = []
|
||||
if isinstance(resp, str):
|
||||
collected_content.append(resp)
|
||||
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
|
||||
hasattr(resp, "message")
|
||||
and hasattr(resp.message, "content")
|
||||
and resp.message.content is not None
|
||||
):
|
||||
collected_content.append(resp.message.content)
|
||||
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
|
||||
hasattr(resp, "choices") and resp.choices and
|
||||
hasattr(resp.choices[0], "message") and
|
||||
hasattr(resp.choices[0].message, "content") and
|
||||
resp.choices[0].message.content is not None
|
||||
):
|
||||
collected_content.append(resp.choices[0].message.content) # OpenAI
|
||||
elif ( # Anthropic new SDK non-streaming content block
|
||||
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
|
||||
hasattr(resp.content[0], "text")
|
||||
):
|
||||
collected_content.append(resp.content[0].text) # Anthropic
|
||||
else:
|
||||
# Assume resp is a stream if not a recognized object
|
||||
try:
|
||||
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
|
||||
content_piece = ""
|
||||
# OpenAI-like stream
|
||||
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
|
||||
hasattr(chunk.choices[0], 'delta') and \
|
||||
hasattr(chunk.choices[0].delta, 'content') and \
|
||||
chunk.choices[0].delta.content is not None:
|
||||
content_piece = chunk.choices[0].delta.content
|
||||
# Anthropic-like stream (ContentBlockDelta)
|
||||
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
|
||||
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
|
||||
content_piece = chunk.delta.text
|
||||
elif isinstance(chunk, str): # Simplest case: stream of strings
|
||||
content_piece = chunk
|
||||
|
||||
if content_piece:
|
||||
collected_content.append(content_piece)
|
||||
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
|
||||
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
|
||||
|
||||
|
||||
return "".join(collected_content)
|
||||
|
||||
def _gen_inner(
    self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
    """Run the plan -> execute -> observe loop, then stream a final answer.

    Each iteration streams a fresh plan as ``{"thought": ...}`` events,
    asks the LLM to execute it, records what happened in
    ``self.observations`` and emits ``{"sources": ...}`` (plus
    ``{"tool_calls": ...}`` when any tool ran). The loop stops after
    ``MAX_ITERATIONS_REASONING`` rounds or as soon as the post-tool response
    contains the literal ``SATISFIED``. The final answer is then streamed as
    ``{"answer": ...}`` events.

    Args:
        query: The user's question.
        retriever: Retriever handed to ``self._retriever_search``.
        log_context: Log sink; when truthy, tool calls are appended to its
            ``stacks`` list each iteration.

    Yields:
        Dicts keyed by ``"thought"``, ``"sources"``, ``"tool_calls"`` or
        ``"answer"``.
    """
    # Reset state for this generation call
    self.plan = ""
    self.observations = []
    retrieved_data = self._retriever_search(retriever, query, log_context)

    # Tools come from the API key when one is set, otherwise from the user.
    if self.user_api_key:
        tools_dict = self._get_tools(self.user_api_key)
    else:
        tools_dict = self._get_user_tools(self.user)
    self._prepare_tools(tools_dict)

    docs_together = "\n".join([doc["text"] for doc in retrieved_data])
    iterating_reasoning = 0
    while iterating_reasoning < MAX_ITERATIONS_REASONING:
        iterating_reasoning += 1
        # 1. Create Plan
        logger.info("ReActAgent: Creating plan...")
        plan_stream = self._create_plan(query, docs_together, log_context)
        current_plan_parts = []
        yield {"thought": f"Reasoning... (iteration {iterating_reasoning})\n\n"}
        # Forward plan chunks to the client while also accumulating them.
        for line_chunk in plan_stream:
            current_plan_parts.append(line_chunk)
            yield {"thought": line_chunk}
        self.plan = "".join(current_plan_parts)
        if self.plan:
            self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")

        # 2. Execute the plan. Observations are capped so the prompt stays bounded.
        max_obs_len = 20000
        obs_str = "\n".join(self.observations)
        if len(obs_str) > max_obs_len:
            obs_str = obs_str[:max_obs_len] + "\n...[observations truncated]"
        execution_prompt_str = (
            (self.prompt or "")
            + f"\n\nFollow this plan:\n{self.plan}"
            + f"\n\nObservations:\n{obs_str}"
            + f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
        )

        messages = self._build_messages(execution_prompt_str, query, retrieved_data)

        resp_from_llm_gen = self._llm_gen(messages, log_context)

        # NOTE(review): if `_llm_gen` returns a one-shot stream, extracting its
        # content here would exhaust it before `_llm_handler` sees it -- confirm.
        initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
        if initial_llm_thought_content:
            self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
        else:
            logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
        resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)

        # NOTE(review): `self.tool_calls` is not reset in this method, so calls
        # from earlier iterations may be re-recorded here -- confirm.
        for tool_call_info in self.tool_calls:  # Iterate over self.tool_calls populated by _llm_handler
            observation_string = (
                f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
                f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
            )
            self.observations.append(observation_string)

        content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
        if content_after_handler:
            self.observations.append(f"Response after tool execution: {content_after_handler}")
        else:
            logger.info("ReActAgent: LLM response after handler had no textual content.")

        if log_context:
            log_context.stacks.append(
                {"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
            )

        yield {"sources": retrieved_data}

        # Shorten tool results to 50 characters for display; the originals are
        # left untouched because each entry is copied first.
        display_tool_calls = []
        for tc in self.tool_calls:
            cleaned_tc = tc.copy()
            if len(str(cleaned_tc.get("result", ""))) > 50:
                cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
            display_tool_calls.append(cleaned_tc)
        if display_tool_calls:
            yield {"tool_calls": display_tool_calls}

        # Plain substring test: any occurrence of the marker ends the loop.
        if "SATISFIED" in content_after_handler:
            logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
            break

    # 3. Create Final Answer based on all observations
    final_answer_stream = self._create_final_answer(query, self.observations, log_context)
    for answer_chunk in final_answer_stream:
        yield {"answer": answer_chunk}
    logger.info("ReActAgent: Finished generating final answer.")
|
||||
|
||||
def _create_plan(
    self, query: str, docs_data: str, log_context: LogContext = None
) -> Generator[str, None, None]:
    """Stream a plan for ``query`` from the LLM, one text piece at a time.

    Fills the ``{query}``, ``{summaries}``, ``{prompt}`` and
    ``{observations}`` placeholders of the planning template, sends the
    result as a single user message and yields every non-empty piece of
    text extracted from the streamed response.

    Args:
        query: The user's question.
        docs_data: Retrieved document text; a fixed notice is substituted
            when it is empty.
        log_context: Optional log sink; when truthy, LLM stack data is
            appended under the ``planning_llm`` component.

    Yields:
        Non-empty text fragments of the plan.
    """
    prompt_text = planning_prompt_template.replace("{query}", query)
    if "{summaries}" in prompt_text:
        prompt_text = prompt_text.replace(
            "{summaries}", docs_data or "No documents retrieved."
        )
    prompt_text = prompt_text.replace("{prompt}", self.prompt or "").replace(
        "{observations}", "\n".join(self.observations)
    )

    llm_stream = self.llm.gen_stream(
        model=self.gpt_model,
        messages=[{"role": "user", "content": prompt_text}],
        tools=getattr(self, 'tools', None),
    )
    if log_context:
        log_context.stacks.append(
            {"component": "planning_llm", "data": build_stack_data(self.llm)}
        )

    # Drop chunks that carry no text (e.g. tool-call-only chunks).
    for raw_chunk in llm_stream:
        piece = self._extract_content_from_llm_response(raw_chunk)
        if not piece:
            continue
        yield piece
|
||||
|
||||
def _create_final_answer(
    self, query: str, observations: List[str], log_context: LogContext = None
) -> Generator[str, None, None]:
    """Stream the final answer synthesised from the collected observations.

    The observations are joined with newlines and cut to 10,000 characters
    (with a truncation marker and a warning) before being formatted into the
    final-answer template. The LLM is called without tools.

    Args:
        query: The user's question.
        observations: Observation strings gathered during reasoning.
        log_context: Optional log sink; when truthy, LLM stack data is
            appended under the ``final_answer_llm`` component.

    Yields:
        Non-empty text fragments of the answer.
    """
    limit = 10000
    joined = "\n".join(observations)
    if len(joined) > limit:
        joined = joined[:limit] + "\n...[observations truncated]"
        logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")

    prompt_text = final_prompt_template.format(query=query, observations=joined)

    # Final answer should synthesize, not call tools.
    llm_stream = self.llm.gen_stream(
        model=self.gpt_model,
        messages=[{"role": "user", "content": prompt_text}],
        tools=None,
    )
    if log_context:
        log_context.stacks.append(
            {"component": "final_answer_llm", "data": build_stack_data(self.llm)}
        )

    for raw_chunk in llm_stream:
        piece = self._extract_content_from_llm_response(raw_chunk)
        if not piece:
            continue
        yield piece
|
||||
@@ -1,692 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.agents.tools.internal_search import (
|
||||
INTERNAL_TOOL_ID,
|
||||
add_internal_search_tool,
|
||||
)
|
||||
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
|
||||
from application.logging import LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults (can be overridden via constructor)
DEFAULT_MAX_STEPS = 6  # upper bound on plan steps (see ResearchAgent.max_steps)
DEFAULT_MAX_SUB_ITERATIONS = 5  # LLM/tool rounds allowed within one research step
DEFAULT_TIMEOUT_SECONDS = 300  # 5 minutes
DEFAULT_TOKEN_BUDGET = 100_000  # token count at which research stops
DEFAULT_PARALLEL_WORKERS = 3

# Adaptive depth caps per complexity level.
# The planning phase looks the LLM-assessed complexity up here and never lets
# the result exceed the agent's own max_steps.
COMPLEXITY_CAPS = {
    "simple": 2,
    "moderate": 4,
    "complex": 6,
}

# <directory two levels above this file>/prompts/research
_PROMPTS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "prompts",
    "research",
)
|
||||
|
||||
|
||||
def _load_prompt(name: str) -> str:
    """Return the full text of the prompt file ``name`` in ``_PROMPTS_DIR``.

    The file is decoded as UTF-8 explicitly so the prompt text does not
    depend on the platform's default locale encoding.

    Args:
        name: File name relative to ``_PROMPTS_DIR`` (e.g. ``"planning.txt"``).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(os.path.join(_PROMPTS_DIR, name), "r", encoding="utf-8") as f:
        return f.read()
|
||||
|
||||
|
||||
# Prompt templates, read once from _PROMPTS_DIR when this module is imported.
# A missing prompt file therefore makes the import itself fail.
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
PLANNING_PROMPT = _load_prompt("planning.txt")
STEP_PROMPT = _load_prompt("step.txt")
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CitationManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CitationManager:
    """Tracks and deduplicates citations across research steps.

    Citation numbers start at 1 and are assigned in registration order.
    """

    def __init__(self):
        # Citation number -> the document dict registered under it.
        self.citations: Dict[int, Dict] = {}
        # Last citation number handed out.
        self._counter = 0

    @staticmethod
    def _identity(doc: Dict) -> tuple:
        """Return the (source, title) pair used to deduplicate ``doc``.

        Missing, ``None`` and empty values all normalise to ``""`` so a
        document compares equal to an identical copy of itself even when it
        lacks one of the keys.
        """
        return (doc.get("source") or "", doc.get("title") or "")

    def add(self, doc: Dict) -> int:
        """Register a source and return its citation number.

        A document whose source and title both match an already registered
        document reuses that document's number.

        Args:
            doc: Document dict; ``source`` and ``title`` are read if present.

        Returns:
            The existing or newly assigned citation number.
        """
        identity = self._identity(doc)
        for num, existing in self.citations.items():
            if self._identity(existing) == identity:
                return num
        self._counter += 1
        self.citations[self._counter] = doc
        return self._counter

    def add_docs(self, docs: List[Dict]) -> str:
        """Register several docs and return one ``[N] title`` line per doc.

        Lines follow the order of ``docs``; a duplicate repeats the number
        of its first registration.
        """
        mapping_lines = []
        for doc in docs:
            num = self.add(doc)
            title = doc.get("title", "Untitled")
            mapping_lines.append(f"[{num}] {title}")
        return "\n".join(mapping_lines)

    def format_references(self) -> str:
        """Return the ``[N] name — source`` reference list for a report footer.

        ``name`` is the doc's ``filename`` when set, otherwise its title.
        Returns ``"No sources found."`` when nothing has been registered.
        """
        if not self.citations:
            return "No sources found."
        lines = []
        for num, doc in sorted(self.citations.items()):
            title = doc.get("title", "Untitled")
            source = doc.get("source", "Unknown")
            filename = doc.get("filename", "")
            display = filename or title
            lines.append(f"[{num}] {display} — {source}")
        return "\n".join(lines)

    def get_all_docs(self) -> List[Dict]:
        """Return every registered document, in citation-number order."""
        return list(self.citations.values())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResearchAgent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResearchAgent(BaseAgent):
|
||||
"""Multi-step research agent with parallel execution and budget controls.
|
||||
|
||||
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    retriever_config: Optional[Dict] = None,
    max_steps: int = DEFAULT_MAX_STEPS,
    max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
    timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
    token_budget: int = DEFAULT_TOKEN_BUDGET,
    parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
    *args,
    **kwargs,
):
    """Store the research limits and reset per-run bookkeeping.

    Args:
        retriever_config: Settings passed on to the internal search tool;
            a falsy value becomes an empty dict.
        max_steps: Maximum number of plan steps.
        max_sub_iterations: Maximum LLM/tool rounds per research step.
        timeout_seconds: Run time after which the agent counts as timed out.
        token_budget: Token count at which the agent counts as over budget.
        parallel_workers: Stored on the instance as ``parallel_workers``.
        *args: Forwarded to the base class initialiser.
        **kwargs: Forwarded to the base class initialiser.
    """
    super().__init__(*args, **kwargs)

    # Caller-supplied configuration.
    self.retriever_config = retriever_config or {}
    self.max_steps = max_steps
    self.max_sub_iterations = max_sub_iterations
    self.timeout_seconds = timeout_seconds
    self.token_budget = token_budget
    self.parallel_workers = parallel_workers

    # Per-run state.
    self.citations = CitationManager()
    self._start_time: float = 0
    self._tokens_used: int = 0
    self._last_token_snapshot: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Budget & timeout helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _is_timed_out(self) -> bool:
|
||||
return (time.monotonic() - self._start_time) >= self.timeout_seconds
|
||||
|
||||
def _elapsed(self) -> float:
|
||||
return round(time.monotonic() - self._start_time, 1)
|
||||
|
||||
def _track_tokens(self, count: int):
|
||||
self._tokens_used += count
|
||||
|
||||
def _budget_remaining(self) -> int:
|
||||
return max(self.token_budget - self._tokens_used, 0)
|
||||
|
||||
def _is_over_budget(self) -> bool:
|
||||
return self._tokens_used >= self.token_budget
|
||||
|
||||
def _snapshot_llm_tokens(self) -> int:
|
||||
"""Read current token usage from LLM and return delta since last snapshot."""
|
||||
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
|
||||
delta = current - self._last_token_snapshot
|
||||
self._last_token_snapshot = current
|
||||
return delta
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main orchestration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _gen_inner(
    self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
    """Run the research pipeline and stream its events.

    Phases: optional clarification, planning, one research pass per plan
    step, then a streamed synthesis. Research stops early when the timeout
    or token budget is hit; synthesis still runs on whatever reports were
    collected.

    Args:
        query: The user's research question.
        log_context: Log sink whose ``stacks`` list receives an ``agent``
            record. A falsy value is tolerated and simply records nothing,
            matching the guard used in ``_synthesis_phase``.

    Yields:
        Event dicts: ``metadata``/``answer``/``sources``/``tool_calls`` for a
        clarification round, otherwise ``research_progress`` and
        ``research_plan`` events, the synthesis events, and finally
        ``sources`` and ``tool_calls``.
    """
    self._start_time = time.monotonic()
    tools_dict = self._setup_tools()

    # Phase 0: Clarification (skip if user is responding to a prior clarification)
    if not self._is_follow_up():
        clarification = self._clarification_phase(query)
        if clarification:
            yield {"metadata": {"is_clarification": True}}
            yield {"answer": clarification}
            yield {"sources": []}
            yield {"tool_calls": []}
            if log_context:
                log_context.stacks.append(
                    {"component": "agent", "data": {"clarification": True}}
                )
            return

    # Phase 1: Planning (with adaptive depth)
    yield {"type": "research_progress", "data": {"status": "planning"}}
    plan, complexity = self._planning_phase(query)

    if not plan:
        # Without steps, research the original question directly.
        logger.warning("ResearchAgent: Planning produced no steps, falling back")
        plan = [{"query": query, "rationale": "Direct investigation"}]
        complexity = "simple"

    yield {
        "type": "research_plan",
        "data": {"steps": plan, "complexity": complexity},
    }

    # Phase 2: Research each step (yields progress events in real-time)
    total_steps = len(plan)
    intermediate_reports = []
    for step_num, step in enumerate(plan, 1):
        step_query = step.get("query", query)

        if self._is_timed_out():
            logger.warning(
                f"ResearchAgent: Timeout at step {step_num}/{total_steps} "
                f"({self._elapsed()}s)"
            )
            break
        if self._is_over_budget():
            logger.warning(
                f"ResearchAgent: Token budget exhausted at step {step_num}/{total_steps}"
            )
            break

        yield {
            "type": "research_progress",
            "data": {
                "step": step_num,
                "total": total_steps,
                "query": step_query,
                "status": "researching",
            },
        }

        report = self._research_step(step_query, tools_dict)
        intermediate_reports.append({"step": step, "content": report})

        yield {
            "type": "research_progress",
            "data": {
                "step": step_num,
                "total": total_steps,
                "query": step_query,
                "status": "complete",
            },
        }

    # Phase 3: Synthesis (streaming)
    if self._is_timed_out():
        logger.warning(
            f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
            f"synthesizing with {len(intermediate_reports)} reports"
        )
    yield {
        "type": "research_progress",
        "data": {
            "status": "synthesizing",
            "elapsed_seconds": self._elapsed(),
            "tokens_used": self._tokens_used,
        },
    }
    yield from self._synthesis_phase(
        query, plan, intermediate_reports, tools_dict, log_context
    )

    # Sources and tool calls
    self.retrieved_docs = self.citations.get_all_docs()
    yield {"sources": self.retrieved_docs}
    yield {"tool_calls": self._get_truncated_tool_calls()}

    logger.info(
        f"ResearchAgent completed: {len(intermediate_reports)}/{total_steps} steps, "
        f"{self._elapsed()}s, ~{self._tokens_used} tokens"
    )
    if log_context:
        log_context.stacks.append(
            {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
        )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool setup
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _setup_tools(self) -> Dict:
    """Assemble the tool set for a run: user tools, internal search, think.

    The think entry is a shallow copy of ``THINK_TOOL_ENTRY`` with its
    ``config`` replaced by an empty dict, so the shared constant is not
    modified. The finished dict is passed to ``self._prepare_tools`` and
    returned.
    """
    available = self.tool_executor.get_tools()
    add_internal_search_tool(available, self.retriever_config)
    available[THINK_TOOL_ID] = {**THINK_TOOL_ENTRY, "config": {}}
    self._prepare_tools(available)
    return available
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 0: Clarification
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _is_follow_up(self) -> bool:
|
||||
"""Check if the user is responding to a prior clarification.
|
||||
|
||||
Uses the metadata flag stored in the conversation DB — no string matching.
|
||||
Only skip clarification when the last query was explicitly flagged
|
||||
as a clarification by this agent.
|
||||
"""
|
||||
if not self.chat_history:
|
||||
return False
|
||||
last = self.chat_history[-1]
|
||||
meta = last.get("metadata", {})
|
||||
return bool(meta.get("is_clarification"))
|
||||
|
||||
def _clarification_phase(self, question: str) -> Optional[str]:
    """Ask the LLM whether ``question`` needs clarifying before research.

    The LLM is called with ``response_format`` set to a JSON object. If the
    parsed reply sets ``needs_clarification`` and lists questions, at most
    the first three are returned as a numbered message.

    Returns:
        The clarification message, or None to go straight to research.
        Any exception is logged and also results in None.
    """
    conversation = [
        {"role": "system", "content": CLARIFICATION_PROMPT},
        {"role": "user", "content": question},
    ]

    try:
        reply = self.llm.gen(
            model=self.model_id,
            messages=conversation,
            tools=None,
            response_format={"type": "json_object"},
        )
        text = self._extract_text(reply)
        self._track_tokens(self._snapshot_llm_tokens())
        logger.info(f"ResearchAgent clarification response: {text[:300]}")

        verdict = self._parse_clarification_json(text)
        if not verdict or not verdict.get("needs_clarification"):
            return None

        questions = verdict.get("questions", [])
        if not questions:
            return None

        # Format as a friendly response: intro, numbered questions, outro.
        numbered = [f"{i}. {q}" for i, q in enumerate(questions[:3], 1)]
        return "\n".join(
            [
                "Before I begin researching, I'd like to clarify a few things:\n",
                *numbered,
                "\nPlease provide these details and I'll start the research.",
            ]
        )

    except Exception as e:
        logger.error(f"Clarification phase failed: {e}", exc_info=True)
        return None  # proceed with research on failure
|
||||
|
||||
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
|
||||
"""Parse clarification JSON from LLM response."""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting from code fences
|
||||
for marker in ["```json", "```"]:
|
||||
if marker in text:
|
||||
start = text.index(marker) + len(marker)
|
||||
end = text.index("```", start) if "```" in text[start:] else len(text)
|
||||
try:
|
||||
return json.loads(text[start:end].strip())
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Try finding JSON object
|
||||
for i, ch in enumerate(text):
|
||||
if ch == "{":
|
||||
for j in range(len(text) - 1, i, -1):
|
||||
if text[j] == "}":
|
||||
try:
|
||||
return json.loads(text[i : j + 1])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 1: Planning (with adaptive depth)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
    """Decompose the question into research steps via LLM.

    The LLM reply is parsed as a plan, validated, and capped to the number
    of steps allowed for its complexity (never more than ``self.max_steps``).

    Args:
        question: The user's research question.

    Returns:
        ``(steps, complexity)``. ``steps`` is a list of dicts and may be
        empty when the reply held no usable steps. ``complexity`` is the
        string reported by the LLM, or ``"moderate"`` when it is absent or
        not a string. If anything raises, a single direct-investigation
        step with complexity ``"simple"`` is returned instead.
    """
    messages = [
        {"role": "system", "content": PLANNING_PROMPT},
        {"role": "user", "content": question},
    ]

    try:
        response = self.llm.gen(
            model=self.model_id,
            messages=messages,
            tools=None,
            response_format={"type": "json_object"},
        )
        text = self._extract_text(response)
        self._track_tokens(self._snapshot_llm_tokens())
        logger.info(f"ResearchAgent planning LLM response: {text[:500]}")

        plan_data = self._parse_plan_json(text)
        if isinstance(plan_data, dict):
            complexity = plan_data.get("complexity", "moderate")
            steps = plan_data.get("steps", [])
        else:
            complexity = "moderate"
            steps = plan_data

        # The plan comes from the LLM, so validate its shape. Callers use
        # `.get` on every step, and complexity is a dict key below: an
        # unhashable value would otherwise discard the whole plan.
        if not isinstance(complexity, str):
            complexity = "moderate"
        if not isinstance(steps, list):
            steps = []
        steps = [step for step in steps if isinstance(step, dict)]

        # Adaptive depth: cap steps based on assessed complexity
        cap = min(COMPLEXITY_CAPS.get(complexity, self.max_steps), self.max_steps)
        steps = steps[:cap]

        logger.info(
            f"ResearchAgent plan: complexity={complexity}, "
            f"steps={len(steps)} (cap={cap})"
        )
        return steps, complexity

    except Exception as e:
        logger.error(f"Planning phase failed: {e}", exc_info=True)
        return (
            [{"query": question, "rationale": "Direct investigation (planning failed)"}],
            "simple",
        )
|
||||
|
||||
def _parse_plan_json(self, text: str):
|
||||
"""Extract JSON plan from LLM response. Returns dict or list."""
|
||||
# Try direct parse
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting from markdown code fences
|
||||
for marker in ["```json", "```"]:
|
||||
if marker in text:
|
||||
start = text.index(marker) + len(marker)
|
||||
end = text.index("```", start) if "```" in text[start:] else len(text)
|
||||
try:
|
||||
data = json.loads(text[start:end].strip())
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Try finding JSON object in text
|
||||
for i, ch in enumerate(text):
|
||||
if ch == "{":
|
||||
for j in range(len(text) - 1, i, -1):
|
||||
if text[j] == "}":
|
||||
try:
|
||||
data = json.loads(text[i : j + 1])
|
||||
if isinstance(data, dict) and "steps" in data:
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
break
|
||||
|
||||
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 2: Research step (core loop)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
|
||||
"""Run a focused research loop for one sub-question (sequential path)."""
|
||||
report = self._research_step_with_executor(
|
||||
step_query, tools_dict, self.tool_executor
|
||||
)
|
||||
self._collect_step_sources()
|
||||
return report
|
||||
|
||||
def _research_step_with_executor(
    self, step_query: str, tools_dict: Dict, executor: ToolExecutor
) -> str:
    """Core research loop for one sub-question, using ``executor`` for tools.

    The LLM is called up to ``self.max_sub_iterations`` times. A reply that
    needs no tool call ends the step with that reply's content. Otherwise
    the requested tools run and their results are added to the conversation.
    If the loop ends for any other reason (iterations used up, timeout,
    token budget, LLM error) the LLM is asked once more, without tools, to
    summarise what was gathered.

    Args:
        step_query: The sub-question to research.
        tools_dict: Tool configurations available to the executor.
        executor: Executor that runs the tool calls.

    Returns:
        The findings text; a fixed placeholder sentence when the LLM
        produced no text or the summary call failed.
    """
    system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": step_query},
    ]

    last_search_empty = False

    for iteration in range(self.max_sub_iterations):
        # Check timeout and budget
        if self._is_timed_out():
            logger.info(
                f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
            )
            break
        if self._is_over_budget():
            logger.info(
                f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
            )
            break

        try:
            response = self.llm.gen(
                model=self.model_id,
                messages=messages,
                tools=self.tools if self.tools else None,
            )
            self._track_tokens(self._snapshot_llm_tokens())
        except Exception as e:
            logger.error(
                f"Research step LLM call failed (iteration {iteration}): {e}",
                exc_info=True,
            )
            break

        parsed = self.llm_handler.parse_response(response)

        if not parsed.requires_tool_call:
            return parsed.content or "No findings for this step."

        # Execute tool calls
        messages, last_search_empty = self._execute_step_tools_with_refinement(
            parsed.tool_calls, tools_dict, messages, executor, last_search_empty
        )

    # Max iterations / timeout / budget — ask for summary
    messages.append(
        {
            "role": "user",
            "content": "Please summarize your findings so far based on the information gathered.",
        }
    )
    try:
        response = self.llm.gen(
            model=self.model_id, messages=messages, tools=None
        )
        self._track_tokens(self._snapshot_llm_tokens())
        text = self._extract_text(response)
        return text or "Research step completed."
    except Exception as e:
        # Best effort: keep the placeholder result, but log the failure
        # like the other LLM-call handlers do.
        logger.error(
            f"Research step summary LLM call failed: {e}", exc_info=True
        )
        return "Research step completed."
|
||||
|
||||
def _execute_step_tools_with_refinement(
    self,
    tool_calls,
    tools_dict: Dict,
    messages: List[Dict],
    executor: ToolExecutor,
    last_search_empty: bool,
) -> tuple[List[Dict], bool]:
    """Execute tool calls with query refinement on empty results.

    Each call runs through ``executor.execute``. Its assistant
    ``function_call`` message and the tool-result message are appended to
    ``messages`` (the list is modified in place). When a search-like tool
    reports "No documents found" and the previous batch was also empty, a
    hint asking for a different query is appended to the result.

    Args:
        tool_calls: Parsed tool calls; each exposes ``name`` and ``arguments``.
        tools_dict: Tool configurations passed to the executor.
        messages: Conversation so far.
        executor: Executor that runs each call.
        last_search_empty: Whether the previous batch had an empty search.

    Returns (updated_messages, was_last_search_empty).
    """
    search_returned_empty = False

    for call in tool_calls:
        gen = executor.execute(
            tools_dict, call, self.llm.__class__.__name__
        )
        result = None
        call_id = None
        # Drive the generator by hand: its *return* value, carried on
        # StopIteration, is the (result, call_id) pair needed below.
        while True:
            try:
                event = next(gen)
                # Log tool_call status events instead of discarding them
                if isinstance(event, dict) and event.get("type") == "tool_call":
                    logger.debug(
                        "Tool %s status: %s",
                        event.get("data", {}).get("action_name", ""),
                        event.get("data", {}).get("status", ""),
                    )
            except StopIteration as e:
                result, call_id = e.value
                break

        # Detect empty search results for refinement.
        # A tool counts as a search when its name contains "search".
        is_search = "search" in (call.name or "").lower()
        result_str = str(result) if result else ""
        if is_search and "No documents found" in result_str:
            search_returned_empty = True
            if last_search_empty:
                # Two consecutive empty searches — inject refinement hint
                result_str += (
                    "\n\nHint: Previous search also returned no results. "
                    "Try a very different query with different keywords, "
                    "or broaden your search terms."
                )
                result = result_str

        function_call_content = {
            "function_call": {
                "name": call.name,
                "args": call.arguments,
                "call_id": call_id,
            }
        }
        messages.append(
            {"role": "assistant", "content": [function_call_content]}
        )
        tool_message = self.llm_handler.create_tool_message(call, result)
        messages.append(tool_message)

    return messages, search_returned_empty
|
||||
|
||||
def _collect_step_sources(self):
    """Register documents retrieved by the internal search tool as citations.

    Looks the tool up in the executor's loaded-tool cache under the
    ``internal_search:<tool id>:<user>`` key. Does nothing when the tool is
    not cached or has no ``retrieved_docs`` attribute.
    """
    cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
    search_tool = self.tool_executor._loaded_tools.get(cache_key)
    if not search_tool or not hasattr(search_tool, "retrieved_docs"):
        return
    for document in search_tool.retrieved_docs:
        self.citations.add(document)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Phase 3: Synthesis
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _synthesis_phase(
    self,
    question: str,
    plan: List[Dict],
    intermediate_reports: List[Dict],
    tools_dict: Dict,
    log_context: LogContext,
) -> Generator[Dict, None, None]:
    """Stream the final cited report built from all step findings.

    Fills the synthesis prompt with the question, a numbered plan summary,
    the per-step findings and the citation reference list, then streams the
    LLM's response through ``self._handle_response``.

    Args:
        question: The user's research question.
        plan: Plan steps (dicts with ``query`` and ``rationale``).
        intermediate_reports: Dicts holding a ``step`` and its ``content``.
        tools_dict: Tool configurations passed on to the response handler.
        log_context: Log sink; when truthy, LLM stack data is appended
            under the ``synthesis_llm`` component.

    Yields:
        Whatever ``self._handle_response`` yields.
    """
    plan_summary = "\n".join(
        f"{i}. {step.get('query', 'Unknown')} — {step.get('rationale', '')}"
        for i, step in enumerate(plan, 1)
    )

    labelled_reports = (
        (i, report["step"].get("query", "Unknown"), report["content"])
        for i, report in enumerate(intermediate_reports, 1)
    )
    findings = "\n\n".join(
        f"--- Step {i}: {step_query} ---\n{content}"
        for i, step_query, content in labelled_reports
    )

    references = self.citations.format_references()

    # Placeholders are substituted one after another, in this order.
    synthesis_prompt = SYNTHESIS_PROMPT
    for placeholder, value in (
        ("{question}", question),
        ("{plan_summary}", plan_summary),
        ("{findings}", findings),
        ("{references}", references),
    ):
        synthesis_prompt = synthesis_prompt.replace(placeholder, value)

    messages = [
        {"role": "system", "content": synthesis_prompt},
        {"role": "user", "content": f"Please write the research report for: {question}"},
    ]

    llm_response = self.llm.gen_stream(
        model=self.model_id, messages=messages, tools=None
    )

    if log_context:
        from application.logging import build_stack_data

        log_context.stacks.append(
            {"component": "synthesis_llm", "data": build_stack_data(self.llm)}
        )

    yield from self._handle_response(
        llm_response, tools_dict, messages, log_context
    )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _extract_text(self, response) -> str:
|
||||
"""Extract text content from a non-streaming LLM response."""
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
if hasattr(response, "message") and hasattr(response.message, "content"):
|
||||
return response.message.content or ""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||
return choice.message.content or ""
|
||||
if hasattr(response, "content") and isinstance(response.content, list):
|
||||
if response.content and hasattr(response.content[0], "text"):
|
||||
return response.content[0].text or ""
|
||||
return str(response) if response else ""
|
||||
@@ -1,313 +0,0 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutor:
    """Handles tool discovery, preparation, and execution.

    Extracted from BaseAgent to separate concerns and enable tool caching.
    """

    def __init__(
        self,
        user_api_key: Optional[str] = None,
        user: Optional[str] = None,
        decoded_token: Optional[Dict] = None,
    ):
        """Store the caller context and initialise empty per-executor state.

        Args:
            user_api_key: Agent API key; when set, tools are looked up through
                the agent that owns this key.
            user: User identifier used for tool lookup, credential decryption
                and as part of the tool cache key.
            decoded_token: Stored as-is; not read by this class.
        """
        self.user_api_key = user_api_key
        self.user = user
        self.decoded_token = decoded_token
        # History of executed (or failed) tool calls, in call order.
        self.tool_calls: List[Dict] = []
        # Cache of instantiated tools, keyed by "<name>:<tool_id>:<user>".
        self._loaded_tools: Dict[str, object] = {}
        # Set externally; forwarded to non-API tools through their config.
        self.conversation_id: Optional[str] = None

    def get_tools(self) -> Dict[str, Dict]:
        """Load tool configs from DB based on user context."""
        if self.user_api_key:
            return self._get_tools_by_api_key(self.user_api_key)
        return self._get_user_tools(self.user or "local")

    def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
        """Return the tools attached to the agent owning ``api_key``.

        The result maps ``str(tool["_id"])`` to the tool document; it is empty
        when the agent is unknown or has no tools.
        """
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        agents_collection = db["agents"]
        tools_collection = db["user_tools"]

        agent_data = agents_collection.find_one({"key": api_key})
        tool_ids = agent_data.get("tools", []) if agent_data else []
        if not tool_ids:
            return {}

        tools = tools_collection.find(
            {"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
        )
        return {str(tool["_id"]): tool for tool in tools}

    def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
        """Return the active tools of ``user``, keyed by their list position."""
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        user_tools_collection = db["user_tools"]
        user_tools = user_tools_collection.find({"user": user, "status": True})
        return {str(i): tool for i, tool in enumerate(user_tools)}

    @staticmethod
    def _tool_actions(tool: Dict) -> List[Dict]:
        """Return the action definitions of a tool config (possibly empty).

        ``api_tool`` keeps its actions as a mapping under ``config.actions``;
        every other tool keeps a list under ``actions``.
        """
        if tool["name"] == "api_tool":
            config = tool.get("config") or {}
            return list(config["actions"].values()) if "actions" in config else []
        return list(tool["actions"]) if "actions" in tool else []

    @staticmethod
    def _find_action(tool_data: Dict, action_name: str) -> Optional[Dict]:
        """Return the definition of ``action_name`` for a tool, or None."""
        if tool_data["name"] == "api_tool":
            actions = (tool_data.get("config") or {}).get("actions") or {}
            return actions.get(action_name)
        return next(
            (
                action
                for action in tool_data.get("actions") or []
                if action["name"] == action_name
            ),
            None,
        )

    def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
        """Convert tool configs to LLM function schemas.

        Each active action becomes one function named
        ``"<action name>_<tool id>"``; inactive actions are skipped.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": f"{action['name']}_{tool_id}",
                    "description": action["description"],
                    "parameters": self._build_tool_parameters(action),
                },
            }
            for tool_id, tool in tools_dict.items()
            for action in self._tool_actions(tool)
            if action.get("active", True)
        ]

    def _build_tool_parameters(self, action: Dict) -> Dict:
        """Build the JSON-schema ``parameters`` object for one action.

        Only properties the LLM is expected to fill (``filled_by_llm``, default
        True) are exposed; the bookkeeping keys ``filled_by_llm``, ``value``
        and ``required`` are stripped from each property schema.
        """
        params = {"type": "object", "properties": {}, "required": []}
        for param_type in ["query_params", "headers", "body", "parameters"]:
            if param_type in action and action[param_type].get("properties"):
                for k, v in action[param_type]["properties"].items():
                    if v.get("filled_by_llm", True):
                        params["properties"][k] = {
                            key: value
                            for key, value in v.items()
                            if key not in ("filled_by_llm", "value", "required")
                        }
                        if v.get("required", False):
                            params["required"].append(k)
        return params

    @staticmethod
    def _split_arguments(action_data: Dict, call_args: Optional[Dict]) -> Dict[str, Dict]:
        """Distribute arguments over query_params / headers / body / parameters.

        For every section declared by the action, configured non-empty
        ``value`` defaults are applied for parameters the LLM did not send,
        then the LLM-provided arguments that belong to the section are added.
        """
        call_args = call_args or {}
        buckets: Dict[str, Dict] = {
            "query_params": {},
            "headers": {},
            "body": {},
            "parameters": {},
        }
        for param_type, target in buckets.items():
            section = action_data.get(param_type) or {}
            properties = section.get("properties") or {}
            for param, details in properties.items():
                if param not in call_args and details.get("value"):
                    target[param] = details["value"]
            for param, value in call_args.items():
                if param in properties:
                    target[param] = value
        return buckets

    def _emit_error(self, tool_call_data: Dict, summary: str, call_id: str):
        """Yield an ``error`` event, record the call, and return the result.

        Meant to be used as ``return (yield from self._emit_error(...))``
        from ``execute``.
        """
        yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
        self.tool_calls.append(tool_call_data)
        return summary, call_id

    def execute(self, tools_dict: Dict, call, llm_class_name: str):
        """Execute a tool call. Yields status events, returns (result, call_id).

        This is a generator: it yields ``{"type": "tool_call", "data": ...}``
        events whose ``status`` is ``"pending"`` then ``"completed"``, or a
        single ``"error"`` event when the call cannot be resolved (unparsable
        name, unknown tool id, or unknown action). The ``(result, call_id)``
        pair is the generator's return value.
        """
        parser = ToolActionParser(llm_class_name)
        tool_id, action_name, call_args = parser.parse_args(call)

        call_id = getattr(call, "id", None) or str(uuid.uuid4())

        if tool_id is None or action_name is None:
            raw_name = getattr(call, "name", "unknown")
            logger.error(
                f"Error: Failed to parse LLM tool call. Tool name: {raw_name}"
            )
            return (
                yield from self._emit_error(
                    {
                        "tool_name": "unknown",
                        "call_id": call_id,
                        "action_name": raw_name,
                        "arguments": call_args or {},
                        "result": f"Failed to parse tool call. Invalid tool name format: {raw_name}",
                    },
                    "Failed to parse tool call.",
                    call_id,
                )
            )

        if tool_id not in tools_dict:
            available = list(tools_dict.keys())
            logger.error(
                f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {available}"
            )
            return (
                yield from self._emit_error(
                    {
                        "tool_name": "unknown",
                        "call_id": call_id,
                        "action_name": f"{action_name}_{tool_id}",
                        "arguments": call_args,
                        "result": f"Tool with ID {tool_id} not found. Available tools: {available}",
                    },
                    f"Tool with ID {tool_id} not found.",
                    call_id,
                )
            )

        tool_data = tools_dict[tool_id]

        # Resolve the action up front. A bare next() here would raise
        # StopIteration inside this generator, which Python turns into an
        # opaque RuntimeError (PEP 479); report it like any other bad call.
        action_data = self._find_action(tool_data, action_name)
        if action_data is None:
            logger.error(
                f"Error: Action '{action_name}' not found for tool '{tool_data['name']}' (ID {tool_id})."
            )
            return (
                yield from self._emit_error(
                    {
                        "tool_name": tool_data["name"],
                        "call_id": call_id,
                        "action_name": f"{action_name}_{tool_id}",
                        "arguments": call_args,
                        "result": f"Action {action_name} not found for tool with ID {tool_id}.",
                    },
                    f"Action {action_name} not found.",
                    call_id,
                )
            )

        tool_call_data = {
            "tool_name": tool_data["name"],
            "call_id": call_id,
            "action_name": f"{action_name}_{tool_id}",
            "arguments": call_args,
        }
        yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}

        buckets = self._split_arguments(action_data, call_args)
        query_params = buckets["query_params"]
        headers = buckets["headers"]
        body = buckets["body"]
        parameters = buckets["parameters"]

        # Load tool (with caching)
        tool = self._get_or_load_tool(
            tool_data,
            tool_id,
            action_name,
            headers=headers,
            query_params=query_params,
        )

        is_api_tool = tool_data["name"] == "api_tool"
        if is_api_tool:
            resolved_arguments = {
                "query_params": query_params,
                "headers": headers,
                "body": body,
            }
            logger.debug(
                f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
            )
            result = tool.execute_action(action_name, **body)
        else:
            resolved_arguments = parameters
            logger.debug(f"Executing tool: {action_name} with args: {call_args}")
            result = tool.execute_action(action_name, **parameters)

        # Non-API tools may optionally expose an artifact id for the call.
        get_artifact_id = (
            None if is_api_tool else getattr(tool, "get_artifact_id", None)
        )
        artifact_id = None
        if callable(get_artifact_id):
            try:
                artifact_id = get_artifact_id(action_name, **parameters)
            except Exception:
                # Best effort: a failing artifact lookup must not fail the call.
                logger.exception(
                    "Failed to extract artifact_id from tool %s for action %s",
                    tool_data["name"],
                    action_name,
                )

        artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
        if artifact_id:
            tool_call_data["artifact_id"] = artifact_id

        result_full = str(result)
        tool_call_data["resolved_arguments"] = resolved_arguments
        tool_call_data["result_full"] = result_full
        tool_call_data["result"] = (
            f"{result_full[:50]}..." if len(result_full) > 50 else result_full
        )

        # The streamed event omits the bulky / internal fields.
        stream_tool_call_data = {
            key: value
            for key, value in tool_call_data.items()
            if key not in {"result_full", "resolved_arguments"}
        }
        yield {
            "type": "tool_call",
            "data": {**stream_tool_call_data, "status": "completed"},
        }
        self.tool_calls.append(tool_call_data)

        return result, call_id

    def _get_or_load_tool(
        self,
        tool_data: Dict,
        tool_id: str,
        action_name: str,
        headers: Optional[Dict] = None,
        query_params: Optional[Dict] = None,
    ):
        """Load a tool, using cache when possible.

        ``api_tool`` instances are built per action (URL, method, headers and
        query parameters come from the action) and are never cached. Other
        tools are built from their stored config and cached per
        tool name / tool id / user.
        """
        cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
        if cache_key in self._loaded_tools:
            return self._loaded_tools[cache_key]

        tm = ToolManager(config={})

        if tool_data["name"] == "api_tool":
            action_config = tool_data["config"]["actions"][action_name]
            tool_config = {
                "url": action_config["url"],
                "method": action_config["method"],
                "headers": headers or {},
                "query_params": query_params or {},
            }
            if "body_content_type" in action_config:
                tool_config["body_content_type"] = action_config.get(
                    "body_content_type", "application/json"
                )
                tool_config["body_encoding_rules"] = action_config.get(
                    "body_encoding_rules", {}
                )
        else:
            stored_config = tool_data.get("config")
            tool_config = stored_config.copy() if stored_config else {}
            if tool_config.get("encrypted_credentials") and self.user:
                decrypted = decrypt_credentials(
                    tool_config["encrypted_credentials"], self.user
                )
                # Expose decrypted values both flattened and grouped.
                tool_config.update(decrypted)
                tool_config["auth_credentials"] = decrypted
                tool_config.pop("encrypted_credentials", None)
            tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
            if self.conversation_id:
                tool_config["conversation_id"] = self.conversation_id
            if tool_data["name"] == "mcp_tool":
                tool_config["query_mode"] = True

        tool = tm.load_tool(
            tool_data["name"],
            tool_config=tool_config,
            user_id=self.user,
        )

        # Don't cache api_tool since config varies by action
        if tool_data["name"] != "api_tool":
            self._loaded_tools[cache_key] = tool

        return tool

    def get_truncated_tool_calls(self) -> List[Dict]:
        """Return the recorded tool calls with results cut to 50 characters.

        Every entry is reported with ``status`` ``"completed"``; a result
        longer than 50 characters (as a string) is truncated and suffixed
        with ``"..."``.
        """
        return [
            {
                "tool_name": tool_call.get("tool_name"),
                "call_id": tool_call.get("call_id"),
                "action_name": tool_call.get("action_name"),
                "arguments": tool_call.get("arguments"),
                "artifact_id": tool_call.get("artifact_id"),
                "result": (
                    f"{str(tool_call['result'])[:50]}..."
                    if len(str(tool_call["result"])) > 50
                    else tool_call["result"]
                ),
                "status": "completed",
            }
            for tool_call in self.tool_calls
        ]
|
||||
@@ -1,323 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import quote, urlencode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentType(str, Enum):
    """Supported content types for request bodies."""

    JSON = "application/json"
    FORM_URLENCODED = "application/x-www-form-urlencoded"
    MULTIPART_FORM_DATA = "multipart/form-data"
    TEXT_PLAIN = "text/plain"
    XML = "application/xml"
    OCTET_STREAM = "application/octet-stream"


class RequestBodySerializer:
    """Serializes request bodies according to content-type and OpenAPI 3.1 spec."""

    @staticmethod
    def serialize(
        body_data: Dict[str, Any],
        content_type: str = ContentType.JSON,
        encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> tuple[Optional[Union[str, bytes]], Dict[str, str]]:
        """
        Serialize body data to appropriate format.

        Args:
            body_data: Dictionary of body parameters
            content_type: Content-Type header value; parameters after ``;``
                are ignored and the match is case-insensitive. Unknown types
                are serialized as JSON (with a warning).
            encoding_rules: OpenAPI Encoding Object rules per field

        Returns:
            Tuple of (serialized_body, updated_headers_dict). For empty
            ``body_data`` this is ``(None, {})``.

        Raises:
            ValueError: If serialization fails
        """
        if not body_data:
            return None, {}

        try:
            media_type = content_type.lower().split(";")[0].strip()

            if media_type == ContentType.JSON:
                return RequestBodySerializer._serialize_json(body_data)
            if media_type == ContentType.FORM_URLENCODED:
                return RequestBodySerializer._serialize_form_urlencoded(
                    body_data, encoding_rules
                )
            if media_type == ContentType.MULTIPART_FORM_DATA:
                return RequestBodySerializer._serialize_multipart_form_data(
                    body_data, encoding_rules
                )
            if media_type == ContentType.TEXT_PLAIN:
                return RequestBodySerializer._serialize_text_plain(body_data)
            if media_type == ContentType.XML:
                return RequestBodySerializer._serialize_xml(body_data)
            if media_type == ContentType.OCTET_STREAM:
                return RequestBodySerializer._serialize_octet_stream(body_data)

            logger.warning(f"Unknown content type: {content_type}, treating as JSON")
            return RequestBodySerializer._serialize_json(body_data)

        except Exception as e:
            logger.error(f"Error serializing body: {str(e)}", exc_info=True)
            raise ValueError(f"Failed to serialize request body: {str(e)}") from e

    @staticmethod
    def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
        """Serialize body as compact JSON (non-ASCII characters kept as-is)."""
        try:
            serialized = json.dumps(
                body_data, separators=(",", ":"), ensure_ascii=False
            )
        except (TypeError, ValueError) as e:
            raise ValueError(f"Failed to serialize JSON body: {str(e)}") from e
        return serialized, {"Content-Type": ContentType.JSON.value}

    @staticmethod
    def _serialize_form_urlencoded(
        body_data: Dict[str, Any],
        encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> tuple[str, Dict[str, str]]:
        """Serialize body as application/x-www-form-urlencoded.

        Fields whose value is ``None`` are omitted. Names and values are
        percent-encoded exactly once (a space becomes ``%20``). Exploded
        object values follow the OpenAPI styles: ``form`` emits one
        ``prop=value`` pair per property, ``deepObject`` emits
        ``field[prop]=value``.
        """
        encoding_rules = encoding_rules or {}
        encode = RequestBodySerializer._percent_encode
        pairs = []

        for key, value in body_data.items():
            if value is None:
                continue

            rule = encoding_rules.get(key, {})
            style = rule.get("style", "form")
            explode = rule.get("explode", style == "form")
            content_type = rule.get("contentType", "text/plain")

            if (
                isinstance(value, dict)
                and explode
                and content_type not in ("application/json", "application/xml")
            ):
                # Keep the property names; they were previously dropped.
                for prop, prop_value in value.items():
                    name = f"{key}[{prop}]" if style == "deepObject" else str(prop)
                    pairs.append((encode(name), encode(str(prop_value))))
                continue

            serialized_value = RequestBodySerializer._serialize_form_value(
                value, style, explode, content_type, key
            )
            encoded_key = encode(str(key))
            if isinstance(serialized_value, list):
                pairs.extend((encoded_key, item) for item in serialized_value)
            else:
                pairs.append((encoded_key, serialized_value))

        # Values are already percent-encoded; running them through urlencode()
        # again would double-encode them ("%40" -> "%2540").
        serialized = "&".join(f"{name}={encoded}" for name, encoded in pairs)
        headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
        return serialized, headers

    @staticmethod
    def _serialize_form_value(
        value: Any, style: str, explode: bool, content_type: str, key: str
    ) -> Union[str, list]:
        """Serialize individual form value with encoding rules.

        Returns a percent-encoded string, or a list of percent-encoded
        strings when the value is exploded into repeated fields. For an
        exploded dict only the property values are returned; the caller is
        responsible for the field names. ``style`` and ``key`` are accepted
        for signature compatibility and not used here.
        """
        encode = RequestBodySerializer._percent_encode

        if isinstance(value, dict):
            if content_type == "application/json":
                return encode(json.dumps(value, separators=(",", ":")))
            if content_type == "application/xml":
                return encode(RequestBodySerializer._dict_to_xml(value))
            if explode:
                return [encode(str(v)) for v in value.values()]
            pairs = [f"{k},{v}" for k, v in value.items()]
            return encode(",".join(pairs))

        if isinstance(value, (list, tuple)):
            if explode:
                return [encode(str(item)) for item in value]
            return encode(",".join(str(v) for v in value))

        return encode(str(value))

    @staticmethod
    def _serialize_multipart_form_data(
        body_data: Dict[str, Any],
        encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> tuple[bytes, Dict[str, str]]:
        """
        Serialize body as multipart/form-data per RFC7578.

        Fields whose value is ``None`` are omitted. A fresh random boundary
        is generated per call and returned in the Content-Type header.
        """
        import secrets

        encoding_rules = encoding_rules or {}
        boundary = f"----DocsGPT{secrets.token_hex(16)}"
        parts = []

        for key, value in body_data.items():
            if value is None:
                continue

            rule = encoding_rules.get(key, {})
            content_type = rule.get("contentType", "text/plain")
            headers_rule = rule.get("headers", {})

            parts.append(
                RequestBodySerializer._create_multipart_part(
                    key, value, content_type, headers_rule
                )
            )

        # Every part already ends with CRLF, which doubles as the CRLF that
        # precedes the next delimiter. The close delimiter must therefore not
        # add another CRLF, or the last field gains a trailing line break.
        opening = f"--{boundary}\r\n"
        payload = "".join(opening + part for part in parts) + f"--{boundary}--\r\n"

        headers = {
            "Content-Type": f"multipart/form-data; boundary={boundary}",
        }
        return payload.encode("utf-8"), headers

    @staticmethod
    def _create_multipart_part(
        name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
    ) -> str:
        """Create a single multipart/form-data part.

        The part is ``headers CRLF CRLF content CRLF``. ``headers_rule`` is
        accepted but not applied.
        """
        headers = [
            f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
        ]

        if isinstance(value, bytes):
            headers.append(f"Content-Type: {content_type}")
            if content_type == "application/octet-stream":
                value_encoded = base64.b64encode(value).decode("utf-8")
                # Only advertise base64 when the payload really is base64.
                headers.append("Content-Transfer-Encoding: base64")
            else:
                value_encoded = value.decode("utf-8", errors="replace")
        elif isinstance(value, dict):
            if content_type == "application/json":
                value_encoded = json.dumps(value, separators=(",", ":"))
            elif content_type == "application/xml":
                value_encoded = RequestBodySerializer._dict_to_xml(value)
            else:
                value_encoded = str(value)
            headers.append(f"Content-Type: {content_type}")
        else:
            # Strings are sent verbatim, other scalars via str(); the part
            # only declares a Content-Type when it is not the text/plain
            # default.
            value_encoded = str(value)
            if content_type != "text/plain":
                headers.append(f"Content-Type: {content_type}")

        return "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"

    @staticmethod
    def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
        """Serialize body as plain text.

        A single-field body is sent as that field's value; otherwise one
        ``key: value`` line is written per field.
        """
        headers = {"Content-Type": ContentType.TEXT_PLAIN.value}
        if len(body_data) == 1:
            (only_value,) = body_data.values()
            return str(only_value), headers
        text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
        return text, headers

    @staticmethod
    def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
        """Serialize body as XML under a ``<root>`` element."""
        xml_str = RequestBodySerializer._dict_to_xml(body_data)
        return xml_str, {"Content-Type": ContentType.XML.value}

    @staticmethod
    def _serialize_octet_stream(
        body_data: Dict[str, Any],
    ) -> tuple[bytes, Dict[str, str]]:
        """Serialize body as binary octet stream.

        ``bytes`` pass through, ``str`` is UTF-8 encoded, anything else is
        JSON-encoded first.
        """
        headers = {"Content-Type": ContentType.OCTET_STREAM.value}
        if isinstance(body_data, bytes):
            return body_data, headers
        if isinstance(body_data, str):
            return body_data.encode("utf-8"), headers
        return json.dumps(body_data).encode("utf-8"), headers

    @staticmethod
    def _percent_encode(value: str, safe_chars: str = "") -> str:
        """
        Percent-encode per RFC3986.

        Args:
            value: String to encode
            safe_chars: Additional characters to not encode
        """
        return quote(value, safe=safe_chars)

    @staticmethod
    def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
        """
        Convert dict to simple XML format.

        Dict keys become element names, scalar values become escaped text.
        List items are emitted as sibling elements named after the list's
        key with one trailing ``s`` removed. Element names are used as given
        (not validated or escaped).
        """

        def build_xml(obj: Any, name: str) -> str:
            if isinstance(obj, dict):
                inner = "".join(build_xml(v, k) for k, v in obj.items())
                return f"<{name}>{inner}</{name}>"
            if isinstance(obj, (list, tuple)):
                item_name = name[:-1] if name.endswith("s") else name
                return "".join(build_xml(item, item_name) for item in obj)
            return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"

        root = build_xml(data, root_name)
        return f'<?xml version="1.0" encoding="UTF-8"?>{root}'

    @staticmethod
    def _escape_xml(value: str) -> str:
        """Escape XML special characters."""
        # "&" must be replaced first so the entities added below are not
        # escaped a second time.
        return (
            value.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&apos;")
        )
|
||||
@@ -1,280 +1,72 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.api_body_serializer import (
|
||||
ContentType,
|
||||
RequestBodySerializer,
|
||||
)
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TIMEOUT = 90 # seconds
|
||||
|
||||
|
||||
class APITool(Tool):
|
||||
"""
|
||||
API Tool
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
|
||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.url = config.get("url", "")
|
||||
self.method = config.get("method", "GET")
|
||||
self.headers = config.get("headers", {})
|
||||
self.headers = config.get("headers", {"Content-Type": "application/json"})
|
||||
self.query_params = config.get("query_params", {})
|
||||
self.body_content_type = config.get("body_content_type", ContentType.JSON)
|
||||
self.body_encoding_rules = config.get("body_encoding_rules", {})
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
"""Execute an API action with the given arguments."""
|
||||
return self._make_api_call(
|
||||
self.url,
|
||||
self.method,
|
||||
self.headers,
|
||||
self.query_params,
|
||||
kwargs,
|
||||
self.body_content_type,
|
||||
self.body_encoding_rules,
|
||||
self.url, self.method, self.headers, self.query_params, kwargs
|
||||
)
|
||||
|
||||
def _make_api_call(
|
||||
self,
|
||||
url: str,
|
||||
method: str,
|
||||
headers: Dict[str, str],
|
||||
query_params: Dict[str, Any],
|
||||
body: Dict[str, Any],
|
||||
content_type: str = ContentType.JSON,
|
||||
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make an API call with proper body serialization and error handling.
|
||||
|
||||
Args:
|
||||
url: API endpoint URL
|
||||
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
|
||||
headers: Request headers dict
|
||||
query_params: URL query parameters
|
||||
body: Request body as dict
|
||||
content_type: Content-Type for serialization
|
||||
encoding_rules: OpenAPI encoding rules
|
||||
|
||||
Returns:
|
||||
Dict with status_code, data, and message
|
||||
"""
|
||||
request_url = url
|
||||
request_headers = headers.copy() if headers else {}
|
||||
response = None
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
def _make_api_call(self, url, method, headers, query_params, body):
|
||||
if query_params:
|
||||
url = f"{url}?{requests.compat.urlencode(query_params)}"
|
||||
# if isinstance(body, dict):
|
||||
# body = json.dumps(body)
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
try:
|
||||
path_params_used = set()
|
||||
if query_params:
|
||||
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||
param_name = match.group(1)
|
||||
if param_name in query_params:
|
||||
request_url = request_url.replace(
|
||||
f"{{{param_name}}}", str(query_params[param_name])
|
||||
)
|
||||
path_params_used.add(param_name)
|
||||
remaining_params = {
|
||||
k: v for k, v in query_params.items() if k not in path_params_used
|
||||
}
|
||||
if remaining_params:
|
||||
query_string = urlencode(remaining_params)
|
||||
separator = "&" if "?" in request_url else "?"
|
||||
request_url = f"{request_url}{separator}{query_string}"
|
||||
|
||||
# Re-validate URL after parameter substitution to prevent SSRF via path params
|
||||
try:
|
||||
validate_url(request_url)
|
||||
except SSRFError as e:
|
||||
logger.error(f"URL validation failed after parameter substitution: {e}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"URL validation error: {e}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
# Serialize body based on content type
|
||||
|
||||
if body and body != {}:
|
||||
try:
|
||||
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||
body, content_type, encoding_rules
|
||||
)
|
||||
request_headers.update(body_headers)
|
||||
except ValueError as e:
|
||||
logger.error(f"Body serialization failed: {str(e)}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Body serialization error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
else:
|
||||
serialized_body = None
|
||||
if "Content-Type" not in request_headers and method not in [
|
||||
"GET",
|
||||
"HEAD",
|
||||
"DELETE",
|
||||
]:
|
||||
request_headers["Content-Type"] = ContentType.JSON
|
||||
logger.debug(
|
||||
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||
)
|
||||
|
||||
if method.upper() == "GET":
|
||||
response = requests.get(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "POST":
|
||||
response = requests.post(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "PUT":
|
||||
response = requests.put(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "DELETE":
|
||||
response = requests.delete(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "PATCH":
|
||||
response = requests.patch(
|
||||
request_url,
|
||||
data=serialized_body,
|
||||
headers=request_headers,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
elif method.upper() == "HEAD":
|
||||
response = requests.head(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
elif method.upper() == "OPTIONS":
|
||||
response = requests.options(
|
||||
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unsupported HTTP method: {method}",
|
||||
"data": None,
|
||||
}
|
||||
print(f"Making API call: {method} {url} with body: {body}")
|
||||
if body == "{}":
|
||||
body = None
|
||||
response = requests.request(method, url, headers=headers, data=body)
|
||||
response.raise_for_status()
|
||||
|
||||
data = self._parse_response(response)
|
||||
content_type = response.headers.get(
|
||||
"Content-Type", "application/json"
|
||||
).lower()
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"API call returned invalid JSON. Error: {e}",
|
||||
"data": response.text,
|
||||
}
|
||||
elif "text/" in content_type or "application/xml" in content_type:
|
||||
data = response.text
|
||||
elif not response.content:
|
||||
data = None
|
||||
else:
|
||||
print(f"Unsupported content type: {content_type}")
|
||||
data = response.content
|
||||
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"data": data,
|
||||
"message": "API call successful.",
|
||||
}
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Request timeout for {request_url}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
logger.error(f"Connection error: {str(e)}")
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Connection error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.error(f"HTTP error {response.status_code}: {str(e)}")
|
||||
try:
|
||||
error_data = response.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
error_data = response.text
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"message": f"HTTP Error {response.status_code}",
|
||||
"data": error_data,
|
||||
}
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Request failed: {str(e)}")
|
||||
return {
|
||||
"status_code": response.status_code if response else None,
|
||||
"message": f"API call failed: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
|
||||
return {
|
||||
"status_code": None,
|
||||
"message": f"Unexpected error: {str(e)}",
|
||||
"data": None,
|
||||
}
|
||||
|
||||
def _parse_response(self, response: requests.Response) -> Any:
|
||||
"""
|
||||
Parse response based on Content-Type header.
|
||||
|
||||
Supports: JSON, XML, plain text, binary data.
|
||||
"""
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
|
||||
if not response.content:
|
||||
return None
|
||||
# JSON response
|
||||
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
return response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON response: {str(e)}")
|
||||
return response.text
|
||||
# XML response
|
||||
|
||||
elif "application/xml" in content_type or "text/xml" in content_type:
|
||||
return response.text
|
||||
# Plain text response
|
||||
|
||||
elif "text/plain" in content_type or "text/html" in content_type:
|
||||
return response.text
|
||||
# Binary/unknown response
|
||||
|
||||
else:
|
||||
# Try to decode as text first, fall back to base64
|
||||
|
||||
try:
|
||||
return response.text
|
||||
except (UnicodeDecodeError, AttributeError):
|
||||
import base64
|
||||
|
||||
return base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
def get_actions_metadata(self):
    """Return the predefined actions of the API Tool.

    The list is always empty because this tool's actions are user-defined.
    """
    no_actions = []
    return no_actions
|
||||
|
||||
def get_config_requirements(self):
    """Return the configuration fields this tool needs (an empty mapping)."""
    requirements = {}
    return requirements
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BraveSearchTool(Tool):
|
||||
"""
|
||||
@@ -46,7 +41,7 @@ class BraveSearchTool(Tool):
|
||||
"""
|
||||
Performs a web search using the Brave Search API.
|
||||
"""
|
||||
logger.debug("Performing Brave web search for: %s", query)
|
||||
print(f"Performing Brave web search for: {query}")
|
||||
|
||||
url = f"{self.base_url}/web/search"
|
||||
|
||||
@@ -99,7 +94,7 @@ class BraveSearchTool(Tool):
|
||||
"""
|
||||
Performs an image search using the Brave Search API.
|
||||
"""
|
||||
logger.debug("Performing Brave image search for: %s", query)
|
||||
print(f"Performing Brave image search for: {query}")
|
||||
|
||||
url = f"{self.base_url}/images/search"
|
||||
|
||||
@@ -182,10 +177,6 @@ class BraveSearchTool(Tool):
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "API Key",
|
||||
"description": "Brave Search API key for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,14 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 2.0
|
||||
DEFAULT_TIMEOUT = 15
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
|
||||
class DuckDuckGoSearchTool(Tool):
|
||||
@@ -19,123 +10,71 @@ class DuckDuckGoSearchTool(Tool):
|
||||
|
||||
def __init__(self, config):
    """Keep the tool configuration and resolve the request timeout.

    Args:
        config: Tool settings; an optional ``timeout`` entry overrides
            the module-level ``DEFAULT_TIMEOUT``.
    """
    self.config = config
    self.timeout = config.get(
        "timeout",
        DEFAULT_TIMEOUT,
    )
|
||||
|
||||
def _get_ddgs_client(self):
    """Build a ``DDGS`` client that uses this tool's timeout.

    The ``ddgs`` package is imported on each call rather than at module
    import time.
    """
    from ddgs import DDGS

    client = DDGS(timeout=self.timeout)
    return client
|
||||
|
||||
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
    """Run ``operation`` and retry it while it reports rate limiting.

    Args:
        operation: Zero-argument callable returning an iterable of results
            (or a falsy value for "no results").
        operation_name: Human-readable label used in messages and logs.

    Returns:
        A dict with ``status_code`` 200, the materialised ``results`` and a
        success ``message``; or ``status_code`` 500, empty ``results`` and
        the last error's text when the operation ultimately fails.
    """
    failure = None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            found = operation()
            # Materialise inside the try so errors raised while iterating
            # a lazy result are handled like any other failure.
            return {
                "status_code": 200,
                "results": list(found) if found else [],
                "message": f"{operation_name} completed successfully.",
            }
        except Exception as e:
            failure = e
            lowered = str(e).lower()
            rate_limited = "ratelimit" in lowered or "429" in lowered
            if rate_limited and attempt < MAX_RETRIES:
                # Back off a little longer on every further attempt.
                delay = RETRY_DELAY * attempt
                logger.warning(
                    f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
                )
                time.sleep(delay)
                continue
            logger.error(f"{operation_name} failed: {e}")
            break

    return {
        "status_code": 500,
        "results": [],
        "message": f"{operation_name} failed: {str(failure)}",
    }
|
||||
|
||||
def execute_action(self, action_name, **kwargs):
|
||||
actions = {
|
||||
"ddg_web_search": self._web_search,
|
||||
"ddg_image_search": self._image_search,
|
||||
"ddg_news_search": self._news_search,
|
||||
}
|
||||
if action_name not in actions:
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _web_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo web search: {query}")
|
||||
query,
|
||||
max_results=5,
|
||||
):
|
||||
print(f"Performing DuckDuckGo web search for: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.text(
|
||||
try:
|
||||
results = DDGS().text(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 20),
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "Web search")
|
||||
return {
|
||||
"status_code": 200,
|
||||
"results": results,
|
||||
"message": "Web search completed successfully.",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": f"Web search failed: {str(e)}",
|
||||
}
|
||||
|
||||
def _image_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo image search: {query}")
|
||||
query,
|
||||
max_results=5,
|
||||
):
|
||||
print(f"Performing DuckDuckGo image search for: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.images(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 50),
|
||||
try:
|
||||
results = DDGS().images(
|
||||
keywords=query,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "Image search")
|
||||
|
||||
def _news_search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
region: str = "wt-wt",
|
||||
safesearch: str = "moderate",
|
||||
timelimit: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
logger.info(f"DuckDuckGo news search: {query}")
|
||||
|
||||
def operation():
|
||||
client = self._get_ddgs_client()
|
||||
return client.news(
|
||||
query,
|
||||
region=region,
|
||||
safesearch=safesearch,
|
||||
timelimit=timelimit,
|
||||
max_results=min(max_results, 20),
|
||||
)
|
||||
|
||||
return self._execute_with_retry(operation, "News search")
|
||||
return {
|
||||
"status_code": 200,
|
||||
"results": results,
|
||||
"message": "Image search completed successfully.",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": f"Image search failed: {str(e)}",
|
||||
}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
return [
|
||||
{
|
||||
"name": "ddg_web_search",
|
||||
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
|
||||
"description": "Perform a web search using DuckDuckGo.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -145,15 +84,7 @@ class DuckDuckGoSearchTool(Tool):
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 20)",
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
|
||||
},
|
||||
"timelimit": {
|
||||
"type": "string",
|
||||
"description": "Time filter: d (day), w (week), m (month), y (year)",
|
||||
"description": "Number of results to return (default: 5)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
@@ -161,43 +92,17 @@ class DuckDuckGoSearchTool(Tool):
|
||||
},
|
||||
{
|
||||
"name": "ddg_image_search",
|
||||
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
|
||||
"description": "Perform an image search using DuckDuckGo.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Image search query",
|
||||
"description": "Search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 50)",
|
||||
},
|
||||
"region": {
|
||||
"type": "string",
|
||||
"description": "Region code (default: wt-wt for worldwide)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "ddg_news_search",
|
||||
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "News search query",
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Number of results (default: 5, max: 20)",
|
||||
},
|
||||
"timelimit": {
|
||||
"type": "string",
|
||||
"description": "Time filter: d (day), w (week), m (month)",
|
||||
"description": "Number of results to return (default: 5, max: 50)",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
|
||||
@@ -1,436 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InternalSearchTool(Tool):
|
||||
"""Wraps the ClassicRAG retriever as an LLM-callable tool.
|
||||
|
||||
Instead of pre-fetching docs into the prompt, the LLM decides
|
||||
when and what to search. Supports multiple searches per session.
|
||||
|
||||
Optional capabilities (enabled when sources have directory_structure):
|
||||
- path_filter on search: restrict results to a specific file/folder
|
||||
- list_files action: browse the file/folder structure
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict):
    """Store the tool configuration and reset all lazily filled state.

    Args:
        config: Settings consumed later by the retriever factory and the
            directory-structure lookup.
    """
    self.config = config
    # Documents accumulated across searches (filled by _execute_search).
    self.retrieved_docs: List[Dict] = []
    # Retriever is created on first use by _get_retriever.
    self._retriever = None
    # Directory structure cache plus the flag marking it as already loaded.
    self._directory_structure: Optional[Dict] = None
    self._dir_structure_loaded = False
|
||||
|
||||
def _get_retriever(self):
    """Return the retriever, building it from ``self.config`` on first use.

    Missing config entries fall back to the defaults spelled out below;
    the created retriever is cached on the instance for later calls.
    """
    if self._retriever is not None:
        return self._retriever

    cfg = self.config
    self._retriever = RetrieverCreator.create_retriever(
        cfg.get("retriever_name", "classic"),
        source=cfg.get("source", {}),
        chat_history=[],
        prompt="",
        chunks=int(cfg.get("chunks", 2)),
        doc_token_limit=int(cfg.get("doc_token_limit", 50000)),
        model_id=cfg.get("model_id", "docsgpt-local"),
        user_api_key=cfg.get("user_api_key"),
        agent_id=cfg.get("agent_id"),
        llm_name=cfg.get("llm_name", settings.LLM_PROVIDER),
        api_key=cfg.get("api_key", settings.API_KEY),
        decoded_token=cfg.get("decoded_token"),
    )
    return self._retriever
|
||||
|
||||
def _get_directory_structure(self) -> Optional[Dict]:
    """Load directory structure from MongoDB for the configured sources.

    The lookup runs at most once per instance: the loaded flag is set
    before any work, so later calls return the cached value (including
    ``None`` after a failure or when no source has a structure).

    Returns:
        The structure of the single active source, a mapping of source
        name to structure when several sources are active, or ``None``.
    """
    if self._dir_structure_loaded:
        return self._directory_structure
    self._dir_structure_loaded = True

    doc_ids = self.config.get("source", {}).get("active_docs", [])
    if not doc_ids:
        return None

    try:
        from bson.objectid import ObjectId
        from application.core.mongo_db import MongoDB

        client = MongoDB.get_client()
        collection = client[settings.MONGO_DB_NAME]["sources"]

        # A single id may arrive as a bare string.
        if isinstance(doc_ids, str):
            doc_ids = [doc_ids]

        combined = {}
        for doc_id in doc_ids:
            try:
                record = collection.find_one({"_id": ObjectId(doc_id)})
                if not record:
                    continue
                structure = record.get("directory_structure")
                if not structure:
                    continue
                # Structures may be stored as JSON text.
                if isinstance(structure, str):
                    structure = json.loads(structure)
                label = record.get("name", doc_id)
                if len(doc_ids) > 1:
                    combined[label] = structure
                else:
                    combined = structure
            except Exception as e:
                logger.debug(f"Could not load dir structure for {doc_id}: {e}")

        self._directory_structure = combined or None
    except Exception as e:
        logger.debug(f"Failed to load directory structures: {e}")

    return self._directory_structure
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs):
    """Dispatch ``action_name`` to the matching handler.

    Args:
        action_name: ``"search"`` or ``"list_files"``.
        **kwargs: Forwarded unchanged to the selected handler.

    Returns:
        The handler's result, or an "Unknown action" message for any
        other action name.
    """
    if action_name == "search":
        return self._execute_search(**kwargs)
    if action_name == "list_files":
        return self._execute_list_files(**kwargs)
    return f"Unknown action: {action_name}"
|
||||
|
||||
def _execute_search(self, **kwargs) -> str:
    """Search the configured sources and format the hits for the LLM.

    Args:
        **kwargs: ``query`` (required) and an optional ``path_filter``
            substring matched case-insensitively against each document's
            ``source``, ``filename`` and ``title``.

    Returns:
        A numbered, ``---``-separated listing of the matching documents,
        or a short message when the query is missing, the search fails,
        or nothing matches.
    """
    query = kwargs.get("query", "")
    path_filter = kwargs.get("path_filter", "")

    if not query:
        return "Error: 'query' parameter is required."

    try:
        retriever = self._get_retriever()
        docs = retriever.search(query)
    except Exception as e:
        logger.error(f"Internal search failed: {e}", exc_info=True)
        return "Search failed: an internal error occurred."

    if not docs:
        return "No documents found matching your query."

    # Apply path filter if specified
    if path_filter:
        needle = path_filter.lower()
        # A field may be present with a None (or non-string) value, so
        # normalise before lower-casing instead of relying on .get's default.
        docs = [
            d
            for d in docs
            if any(
                needle in str(d.get(field) or "").lower()
                for field in ("source", "filename", "title")
            )
        ]
        if not docs:
            return f"No documents found matching query '{query}' in path '{path_filter}'."

    # Accumulate for source tracking (each distinct document kept once)
    for doc in docs:
        if doc not in self.retrieved_docs:
            self.retrieved_docs.append(doc)

    # Format results for the LLM
    formatted = []
    for i, doc in enumerate(docs, 1):
        title = doc.get("title", "Untitled")
        text = doc.get("text", "")
        source = doc.get("source", "Unknown")
        filename = doc.get("filename", "")
        header = filename or title
        formatted.append(f"[{i}] {header} (source: {source})\n{text}")

    return "\n\n---\n\n".join(formatted)
|
||||
|
||||
def _execute_list_files(self, **kwargs) -> str:
    """List the files and folders found at ``path`` in the sources.

    Args:
        **kwargs: Optional ``path``; empty or missing means the root.

    Returns:
        The formatted listing, or a message when no structure is
        available or the path does not exist.
    """
    path = kwargs.get("path", "")
    tree = self._get_directory_structure()
    if not tree:
        return "No file structure available for the current sources."

    # Walk down the tree one non-empty path segment at a time.
    node = tree
    if path:
        for segment in path.strip("/").split("/"):
            if not segment:
                continue
            if not (isinstance(node, dict) and segment in node):
                return f"Path '{path}' not found in the file structure."
            node = node[segment]

    return self._format_structure(node, path or "/")
|
||||
|
||||
def _format_structure(self, node: Dict, current_path: str) -> str:
    """Render one level of the directory structure as text.

    A dict child carrying a ``type``, ``size_bytes`` or ``token_count``
    key is treated as a file with metadata; any other dict is a folder;
    every non-dict child is a plain file.

    Args:
        node: The structure node to render.
        current_path: Path shown in the heading.

    Returns:
        The listing with folders first, then files, each sorted by name.
    """
    if not isinstance(node, dict):
        return f"'{current_path}' is a file, not a directory."

    metadata_keys = ("type", "size_bytes", "token_count")
    folder_lines = []
    file_lines = []

    for name, value in sorted(node.items()):
        if not isinstance(value, dict):
            file_lines.append(f" {name}")
            continue

        if not any(key in value for key in metadata_keys):
            # It's a folder: show how many files it holds.
            count = self._count_files(value)
            folder_lines.append(f" {name}/ ({count} items)")
            continue

        # It's a file with metadata.
        size = value.get("token_count", "")
        ftype = value.get("type", "")
        details = []
        if ftype:
            details.append(ftype)
        if size:
            details.append(f"{size} tokens")
        info = f" ({', '.join(details)})" if details else ""
        file_lines.append(f" {name}{info}")

    lines = [f"File structure at '{current_path}':\n"]
    if folder_lines:
        lines.append("Folders:")
        lines.extend(folder_lines)
    if file_lines:
        lines.append("Files:")
        lines.extend(file_lines)
    if not folder_lines and not file_lines:
        lines.append(" (empty)")

    return "\n".join(lines)
|
||||
|
||||
def _count_files(self, node: Dict) -> int:
    """Count the files below ``node``, descending into sub-folders.

    A dict child with a ``type``, ``size_bytes`` or ``token_count`` key
    counts as one file, as does any non-dict child; other dicts are
    folders and contribute the number of files they contain.
    """
    total = 0
    for child in node.values():
        is_folder = isinstance(child, dict) and not (
            "type" in child or "size_bytes" in child or "token_count" in child
        )
        total += self._count_files(child) if is_folder else 1
    return total
|
||||
|
||||
def get_actions_metadata(self):
    """Describe the actions this tool exposes to the LLM.

    ``search`` is always present. When the config flag
    ``has_directory_structure`` is truthy, ``search`` also accepts a
    ``path_filter`` parameter and a ``list_files`` action is added.
    """
    search_properties = {
        "query": {
            "type": "string",
            "description": "The search query. Be specific and focused.",
            "filled_by_llm": True,
            "required": True,
        },
    }
    search_action = {
        "name": "search",
        "description": (
            "Search the user's uploaded documents and knowledge base. "
            "Use this to find relevant information before answering questions. "
            "You can call this multiple times with different queries."
        ),
        "parameters": {"properties": search_properties},
    }

    if not self.config.get("has_directory_structure", False):
        return [search_action]

    # Add path_filter and list_files only if directory structure exists
    search_properties["path_filter"] = {
        "type": "string",
        "description": (
            "Optional: filter results to a specific file or folder path. "
            "Use list_files first to see available paths."
        ),
        "filled_by_llm": True,
        "required": False,
    }
    list_files_action = {
        "name": "list_files",
        "description": (
            "Browse the file and folder structure of the knowledge base. "
            "Use this to see what files are available before searching. "
            "Optionally provide a path to browse a specific folder."
        ),
        "parameters": {
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Optional: folder path to browse. Leave empty for root.",
                    "filled_by_llm": True,
                    "required": False,
                }
            }
        },
    }
    return [search_action, list_files_action]
|
||||
|
||||
def get_config_requirements(self):
    """Return the configuration fields this tool needs (an empty mapping)."""
    requirements = {}
    return requirements
|
||||
|
||||
|
||||
# Constants for building synthetic tools_dict entries
|
||||
INTERNAL_TOOL_ID = "internal"
|
||||
|
||||
|
||||
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
    """Build the tools_dict entry for InternalSearchTool.

    Args:
        has_directory_structure: When true, the ``search`` action gains a
            ``path_filter`` parameter and a ``list_files`` action is added.

    Returns:
        A dict with the tool ``name`` (``"internal_search"``) and its list
        of active ``actions``.
    """
    properties = {
        "query": {
            "type": "string",
            "description": "The search query. Be specific and focused.",
            "filled_by_llm": True,
            "required": True,
        }
    }
    if has_directory_structure:
        properties["path_filter"] = {
            "type": "string",
            "description": (
                "Optional: filter results to a specific file or folder path. "
                "Use list_files first to see available paths."
            ),
            "filled_by_llm": True,
            "required": False,
        }

    actions = [
        {
            "name": "search",
            "description": (
                "Search the user's uploaded documents and knowledge base. "
                "Use this to find relevant information before answering questions. "
                "You can call this multiple times with different queries."
            ),
            "active": True,
            "parameters": {"properties": properties},
        }
    ]

    if has_directory_structure:
        actions.append(
            {
                "name": "list_files",
                "description": (
                    "Browse the file and folder structure of the knowledge base. "
                    "Use this to see what files are available before searching. "
                    "Optionally provide a path to browse a specific folder."
                ),
                "active": True,
                "parameters": {
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Optional: folder path to browse. Leave empty for root.",
                            "filled_by_llm": True,
                            "required": False,
                        }
                    }
                },
            }
        )

    return {"name": "internal_search", "actions": actions}
|
||||
|
||||
|
||||
# Keep backward compat
|
||||
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
|
||||
|
||||
|
||||
def sources_have_directory_structure(source: Dict) -> bool:
    """Check if any of the active sources have directory_structure in MongoDB.

    Args:
        source: Source configuration; its ``active_docs`` entry holds one
            document id or a list of ids.

    Returns:
        ``True`` as soon as one source document has a non-empty
        ``directory_structure``; ``False`` otherwise, including when the
        lookup fails.
    """
    doc_ids = source.get("active_docs", [])
    if not doc_ids:
        return False

    try:
        from bson.objectid import ObjectId
        from application.core.mongo_db import MongoDB

        client = MongoDB.get_client()
        collection = client[settings.MONGO_DB_NAME]["sources"]

        # A single id may arrive as a bare string.
        if isinstance(doc_ids, str):
            doc_ids = [doc_ids]

        for doc_id in doc_ids:
            try:
                # Fetch only the field that is being tested.
                record = collection.find_one(
                    {"_id": ObjectId(doc_id)},
                    {"directory_structure": 1},
                )
                if record and record.get("directory_structure"):
                    return True
            except Exception:
                # Best effort: a failing id just moves on to the next one.
                continue
    except Exception as e:
        logger.debug(f"Could not check directory structure: {e}")

    return False
|
||||
|
||||
|
||||
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
    """Add the internal search tool to tools_dict if sources are configured.

    Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
    Mutates tools_dict in place.

    Args:
        tools_dict: Mapping of tool id to tool entry; receives the
            internal search entry under ``INTERNAL_TOOL_ID``.
        retriever_config: Retriever settings forwarded as keyword
            arguments to ``build_internal_tool_config``. A missing or
            empty config, or one without active documents, leaves
            ``tools_dict`` untouched.
    """
    # Check the config itself before reading from it, so a None/empty
    # config is a no-op rather than an AttributeError.
    if not retriever_config:
        return

    source = retriever_config.get("source") or {}
    if not source.get("active_docs"):
        return

    has_dir = sources_have_directory_structure(source)
    internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
    internal_entry["config"] = build_internal_tool_config(
        **retriever_config,
        has_directory_structure=has_dir,
    )
    tools_dict[INTERNAL_TOOL_ID] = internal_entry
|
||||
|
||||
|
||||
def build_internal_tool_config(
    source: Dict,
    retriever_name: str = "classic",
    chunks: int = 2,
    doc_token_limit: int = 50000,
    model_id: str = "docsgpt-local",
    user_api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    llm_name: str = None,
    api_key: str = None,
    decoded_token: Optional[Dict] = None,
    has_directory_structure: bool = False,
) -> Dict:
    """Build the config dict for InternalSearchTool.

    Every argument is copied into the result under its own name. A falsy
    ``llm_name`` or ``api_key`` is replaced by ``settings.LLM_PROVIDER``
    or ``settings.API_KEY`` respectively.
    """
    resolved_llm = llm_name or settings.LLM_PROVIDER
    resolved_key = api_key or settings.API_KEY

    config = {
        "source": source,
        "retriever_name": retriever_name,
        "chunks": chunks,
        "doc_token_limit": doc_token_limit,
        "model_id": model_id,
        "user_api_key": user_api_key,
        "agent_id": agent_id,
        "llm_name": resolved_llm,
        "api_key": resolved_key,
        "decoded_token": decoded_token,
        "has_directory_structure": has_directory_structure,
    }
    return config
|
||||
@@ -1,981 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.auth import BearerAuth
|
||||
from fastmcp.client.transports import (
|
||||
SSETransport,
|
||||
StdioTransport,
|
||||
StreamableHttpTransport,
|
||||
)
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
from pydantic import AnyHttpUrl, ValidationError
|
||||
from redis import Redis
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
_mcp_clients_cache = {}
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""
|
||||
MCP Tool
|
||||
Connect to remote Model Context Protocol (MCP) servers to access dynamic tools and resources.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], user_id: Optional[str] = None):
    """
    Initialize the MCP Tool with configuration.

    Args:
        config: Dictionary containing MCP server configuration:
            - server_url: URL of the remote MCP server
            - transport_type: Transport type (auto, sse, http, stdio)
            - auth_type: Type of authentication (bearer, oauth, api_key, basic, none)
            - encrypted_credentials: Encrypted credentials (if available)
            - auth_credentials: Plain credentials, used when nothing is decrypted
            - timeout: Request timeout in seconds (default: 30)
            - headers: Custom headers for requests
            - command: Command for STDIO transport
            - args: Arguments for STDIO transport
            - oauth_scopes: OAuth scopes for oauth auth type
            - oauth_task_id: Task id passed to the interactive OAuth handler
            - oauth_client_name: OAuth client name for oauth auth type
            - redirect_uri: OAuth redirect URI override
            - query_mode: If True, use non-interactive OAuth (fail-fast on 401)
        user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
    """
    self.config = config
    self.user_id = user_id

    # Connection settings
    self.server_url = config.get("server_url", "")
    self.transport_type = config.get("transport_type", "auto")
    self.auth_type = config.get("auth_type", "none")
    self.timeout = config.get("timeout", 30)
    self.custom_headers = config.get("headers", {})

    # Credentials: decrypt when both the encrypted blob and a user id are
    # available, otherwise fall back to plain credentials from the config.
    self.auth_credentials = {}
    encrypted = config.get("encrypted_credentials")
    if encrypted and user_id:
        self.auth_credentials = decrypt_credentials(encrypted, user_id)
    else:
        self.auth_credentials = config.get("auth_credentials", {})

    # OAuth settings
    self.oauth_scopes = config.get("oauth_scopes", [])
    self.oauth_task_id = config.get("oauth_task_id", None)
    self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
    self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))

    self.available_tools = []
    # The cache key is derived from the attributes assigned above.
    self._cache_key = self._generate_cache_key()
    self._client = None
    self.query_mode = config.get("query_mode", False)

    # OAuth clients are not created here; every other auth type gets its
    # client immediately when a server URL is configured.
    if self.server_url and self.auth_type != "oauth":
        self._setup_client()
|
||||
|
||||
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
    """Choose the OAuth redirect URI for this tool.

    The first usable candidate wins, in this order: the configured value,
    ``settings.MCP_OAUTH_REDIRECT_URI``, the scheme and host of
    ``settings.CONNECTOR_REDIRECT_BASE_URI`` combined with the MCP
    callback path, and finally ``settings.API_URL`` plus that same path.
    Trailing slashes are stripped from the values used as-is.
    """
    callback_path = "/api/mcp_server/callback"

    if configured_redirect_uri:
        return configured_redirect_uri.rstrip("/")

    from_settings = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
    if from_settings:
        return from_settings.rstrip("/")

    base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
    if base:
        parts = urlparse(base)
        # Only a base carrying both scheme and host can be reused.
        if parts.scheme and parts.netloc:
            return f"{parts.scheme}://{parts.netloc}{callback_path}"

    return f"{settings.API_URL.rstrip('/')}{callback_path}"
|
||||
|
||||
def _generate_cache_key(self) -> str:
    """Generate a unique cache key for this MCP server configuration.

    The key is ``"<server_url>#<transport_type>#<auth component>"``.
    Secrets (bearer token, API key, basic-auth password) enter the auth
    component as a SHA-256 digest of their full value, so credentials
    that merely share a prefix get different keys and no raw secret
    material appears in the key.

    Returns:
        The cache key string.
    """
    import hashlib

    def _digest(*parts) -> str:
        """Return a hex SHA-256 digest over the given parts."""
        joined = "\x00".join(str(part) for part in parts)
        return hashlib.sha256(joined.encode("utf-8")).hexdigest()

    if self.auth_type == "oauth":
        scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
        auth_key = (
            f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
        )
    elif self.auth_type == "bearer":
        token = self.auth_credentials.get(
            "bearer_token", ""
        ) or self.auth_credentials.get("access_token", "")
        auth_key = f"bearer:{_digest(token)}" if token else "bearer:none"
    elif self.auth_type == "api_key":
        api_key = self.auth_credentials.get("api_key", "")
        # The header name decides how the key is sent, so it is part of
        # the identity as well.
        header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
        auth_key = (
            f"apikey:{_digest(header_name, api_key)}" if api_key else "apikey:none"
        )
    elif self.auth_type == "basic":
        username = self.auth_credentials.get("username", "")
        password = self.auth_credentials.get("password", "")
        auth_key = (
            f"basic:{username}:{_digest(password)}"
            if password
            else f"basic:{username}"
        )
    else:
        auth_key = "none"
    return f"{self.server_url}#{self.transport_type}#{auth_key}"
|
||||
|
||||
def _setup_client(self):
    """Attach a client to this tool, reusing a cached one when possible.

    A cached client younger than 300 seconds is reused; an older entry is
    dropped. Otherwise a new client is built from the transport and the
    auth handler matching ``self.auth_type``, then stored in the
    module-level cache with its creation time.
    """
    global _mcp_clients_cache

    cached = _mcp_clients_cache.get(self._cache_key)
    if cached is not None:
        age = time.time() - cached["created_at"]
        if age < 300:
            self._client = cached["client"]
            return
        del _mcp_clients_cache[self._cache_key]

    transport = self._create_transport()
    auth = None

    if self.auth_type == "oauth":
        redis_client = get_redis_instance()
        shared = {
            "mcp_url": self.server_url,
            "scopes": self.oauth_scopes,
            "redis_client": redis_client,
            "redirect_uri": self.redirect_uri,
            "db": db,
            "user_id": self.user_id,
        }
        if self.query_mode:
            auth = NonInteractiveOAuth(**shared)
        else:
            # Only the interactive handler receives the OAuth task id.
            auth = DocsGPTOAuth(task_id=self.oauth_task_id, **shared)
    elif self.auth_type == "bearer":
        token = self.auth_credentials.get(
            "bearer_token", ""
        ) or self.auth_credentials.get("access_token", "")
        if token:
            auth = BearerAuth(token)

    self._client = Client(transport, auth=auth)
    _mcp_clients_cache[self._cache_key] = {
        "client": self._client,
        "created_at": time.time(),
    }
|
||||
|
||||
def _create_transport(self):
    """Create appropriate transport based on configuration.

    Headers start from a JSON content type and the DocsGPT user agent,
    then take the custom headers and, for ``api_key`` / ``basic`` auth,
    the matching credential header.

    With ``transport_type == "auto"`` the SSE transport is chosen only
    when the URL *path* contains an ``sse`` segment (for example
    ``/sse`` or ``/api/sse/``); a mere "sse" substring elsewhere in the
    URL, such as inside a host name, no longer selects SSE. Any other
    URL, and any unrecognised transport type, uses streamable HTTP.

    Returns:
        An ``SSETransport`` or ``StreamableHttpTransport`` for the server.

    Raises:
        ValueError: If the resolved transport type is ``stdio``, which is
            disabled.
    """
    headers = {"Content-Type": "application/json", "User-Agent": "DocsGPT-MCP/1.0"}
    headers.update(self.custom_headers)

    if self.auth_type == "api_key":
        api_key = self.auth_credentials.get("api_key", "")
        header_name = self.auth_credentials.get("api_key_header", "X-API-Key")
        if api_key:
            headers[header_name] = api_key
    elif self.auth_type == "basic":
        username = self.auth_credentials.get("username", "")
        password = self.auth_credentials.get("password", "")
        if username and password:
            credentials = base64.b64encode(
                f"{username}:{password}".encode()
            ).decode()
            headers["Authorization"] = f"Basic {credentials}"

    if self.transport_type == "auto":
        # Match a whole path segment instead of any "sse" substring.
        path_segments = urlparse(self.server_url).path.lower().split("/")
        transport_type = "sse" if "sse" in path_segments else "http"
    else:
        transport_type = self.transport_type

    if transport_type == "stdio":
        raise ValueError("STDIO transport is disabled")

    if transport_type == "sse":
        headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
        return SSETransport(url=self.server_url, headers=headers)

    # "http" and every unrecognised type use streamable HTTP.
    return StreamableHttpTransport(url=self.server_url, headers=headers)
|
||||
|
||||
def _format_tools(self, tools_response) -> List[Dict]:
    """Normalize a tool listing into a list of dictionaries.

    Args:
        tools_response: Either an object exposing a ``tools`` attribute or a
            plain list of tools. Anything else is treated as no tools.

    Returns:
        One entry per tool: objects with a ``name`` attribute become
        ``{"name", "description"[, "inputSchema"]}``; dicts are passed
        through unchanged; other objects use ``model_dump()`` when they have
        it, else ``{"name": str(tool), "description": ""}``.
    """
    if hasattr(tools_response, "tools"):
        raw_tools = tools_response.tools
    elif isinstance(tools_response, list):
        raw_tools = tools_response
    else:
        raw_tools = []

    def as_dict(entry):
        # Attribute-style tool objects take precedence over every other shape.
        if hasattr(entry, "name"):
            converted = {"name": entry.name, "description": entry.description}
            if hasattr(entry, "inputSchema"):
                converted["inputSchema"] = entry.inputSchema
            return converted
        if isinstance(entry, dict):
            return entry
        if hasattr(entry, "model_dump"):
            return entry.model_dump()
        return {"name": str(entry), "description": ""}

    return [as_dict(entry) for entry in raw_tools]
|
||||
|
||||
async def _execute_with_client(self, operation: str, *args, **kwargs):
|
||||
"""Execute operation with FastMCP client."""
|
||||
if not self._client:
|
||||
raise Exception("FastMCP client not initialized")
|
||||
async with self._client:
|
||||
if operation == "ping":
|
||||
return await self._client.ping()
|
||||
elif operation == "list_tools":
|
||||
tools_response = await self._client.list_tools()
|
||||
self.available_tools = self._format_tools(tools_response)
|
||||
return self.available_tools
|
||||
elif operation == "call_tool":
|
||||
tool_name = args[0]
|
||||
tool_args = kwargs
|
||||
return await self._client.call_tool(tool_name, tool_args)
|
||||
elif operation == "list_resources":
|
||||
return await self._client.list_resources()
|
||||
elif operation == "list_prompts":
|
||||
return await self._client.list_prompts()
|
||||
else:
|
||||
raise Exception(f"Unknown operation: {operation}")
|
||||
|
||||
# Exception types that get a dedicated message. Each formatter is invoked as
# formatter(operation, timeout, exception) by _map_error.
_ERROR_MAP = [
    (
        concurrent.futures.TimeoutError,
        lambda _operation, timeout, _exc: f"Timed out after {timeout}s",
    ),
    (ConnectionRefusedError, lambda *_unused: "Connection refused"),
]

# Substrings searched for (case-insensitively, by _map_error) in an error
# message, keyed to the friendly message used when any of them is present.
_ERROR_PATTERNS = {
    ("403", "Forbidden"): "Access denied (403 Forbidden)",
    ("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
    ("ECONNREFUSED",): "Connection refused",
    ("SSL", "certificate"): "SSL/TLS error",
}
|
||||
|
||||
def _run_async_operation(self, operation: str, *args, **kwargs):
    """Run an MCP client operation synchronously and return its result.

    If the calling thread already has a running event loop, the operation is
    executed on a worker thread via ``self._run_in_new_loop`` and waited on
    for at most ``self.timeout`` seconds. Otherwise it is run directly in the
    calling thread.

    Args:
        operation: Operation name forwarded to ``self._run_in_new_loop``.
        *args: Positional arguments forwarded with the operation.
        **kwargs: Keyword arguments forwarded with the operation.

    Raises:
        Exception: Whatever ``self._map_error`` returns for the failure,
            chained to the original exception.
    """
    try:
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so running inline is safe.
            return self._run_in_new_loop(operation, *args, **kwargs)

        # Only the loop probe above is guarded by ``except RuntimeError``: a
        # RuntimeError raised by the operation itself must not trigger a
        # second, inline attempt inside the already-running loop.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(
                self._run_in_new_loop, operation, *args, **kwargs
            )
            return future.result(timeout=self.timeout)
        finally:
            # Do not wait for the worker here: blocking on shutdown would keep
            # the caller stuck past the timeout enforced by future.result().
            executor.shutdown(wait=False)
    except Exception as e:
        raise self._map_error(operation, e) from e
|
||||
|
||||
def _run_in_new_loop(self, operation, *args, **kwargs):
    """Run ``self._execute_with_client`` to completion on a fresh event loop.

    A new loop is created, installed as the current thread's event loop, used
    for this single call, and closed afterwards whether or not the call
    succeeded.

    Returns:
        The value produced by ``self._execute_with_client``.
    """
    private_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(private_loop)
    try:
        pending = self._execute_with_client(operation, *args, **kwargs)
        outcome = private_loop.run_until_complete(pending)
    finally:
        private_loop.close()
    return outcome
|
||||
|
||||
def _map_error(self, operation: str, exc: Exception) -> Exception:
    """Translate a low-level failure into a friendlier exception.

    Args:
        operation: Name of the operation that failed (used for messages and
            logging).
        exc: The exception that was raised.

    Returns:
        A new ``Exception`` carrying a friendly message when ``exc`` matches
        an entry of ``_ERROR_MAP`` (by type) or ``_ERROR_PATTERNS`` (by
        case-insensitive substring of ``str(exc)``); otherwise ``exc`` itself,
        after logging it.
    """
    formatter = next(
        (fmt for kind, fmt in self._ERROR_MAP if isinstance(exc, kind)),
        None,
    )
    if formatter is not None:
        return Exception(formatter(operation, self.timeout, exc))

    lowered_message = str(exc).lower()
    for needles, friendly in self._ERROR_PATTERNS.items():
        for needle in needles:
            if needle.lower() in lowered_message:
                return Exception(friendly)

    logger.error("MCP %s failed: %s", operation, exc)
    return exc
|
||||
|
||||
def discover_tools(self) -> List[Dict]:
    """
    Discover available tools from the MCP server using FastMCP.

    The client is set up on demand, and the discovered tools are stored on
    ``self.available_tools``.

    Returns:
        List of tool definitions from the server, or an empty list when no
        server URL is configured.

    Raises:
        Exception: If listing the tools fails.
    """
    if not self.server_url:
        return []
    if not self._client:
        self._setup_client()

    try:
        self.available_tools = self._run_async_operation("list_tools")
    except Exception as e:
        raise Exception(f"Failed to discover tools from MCP server: {str(e)}")
    return self.available_tools
|
||||
|
||||
def execute_action(self, action_name: str, **kwargs) -> Any:
    """Call a tool on the MCP server and return the formatted result.

    Arguments whose value is ``""`` or ``None`` are dropped before the call.
    When the failure looks like an authentication error and the auth type is
    not OAuth, the cached client is discarded, rebuilt, and the call is
    retried once.

    Args:
        action_name: Name of the tool to call.
        **kwargs: Arguments for the tool.

    Returns:
        The result of ``self._format_result`` for the tool response.

    Raises:
        Exception: If no server is configured, or if the call (including the
            single retry) fails.
    """
    if not self.server_url:
        raise Exception("No MCP server configured")
    if not self._client:
        self._setup_client()

    cleaned_kwargs = {
        name: value
        for name, value in kwargs.items()
        if not (value == "" or value is None)
    }

    def invoke():
        raw = self._run_async_operation("call_tool", action_name, **cleaned_kwargs)
        return self._format_result(raw)

    try:
        return invoke()
    except Exception as e:
        error_msg = str(e)
        lower_msg = error_msg.lower()
        auth_markers = ("unauthorized", "session expired", "re-authorize")
        is_auth_error = "401" in error_msg or any(
            marker in lower_msg for marker in auth_markers
        )
        if not is_auth_error:
            raise Exception(
                f"Failed to execute action '{action_name}': {error_msg}"
            ) from e

        # OAuth sessions cannot be refreshed from here; ask the user instead.
        if self.auth_type == "oauth":
            raise Exception(
                f"Action '{action_name}' failed: OAuth session expired. "
                "Please re-authorize this MCP server in tool settings."
            ) from e

        # Throw away the cached client and retry once with a fresh one.
        _mcp_clients_cache.pop(self._cache_key, None)
        self._client = None
        self._setup_client()
        try:
            return invoke()
        except Exception as retry_e:
            raise Exception(
                f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
                "Your credentials may have expired — please re-authorize in tool settings."
            ) from retry_e
|
||||
|
||||
def _format_result(self, result) -> Dict:
    """Format FastMCP result to match expected format.

    Args:
        result: The tool-call response. Objects without a ``content``
            attribute are returned unchanged.

    Returns:
        ``{"content": [...], "isError": bool}`` where each content item is
        tagged ``text``, ``data`` or ``unknown`` depending on which attribute
        it exposes; ``isError`` defaults to False when absent.
    """
    if not hasattr(result, "content"):
        return result

    def convert(item):
        # ``text`` wins over ``data`` when an item exposes both.
        if hasattr(item, "text"):
            return {"type": "text", "text": item.text}
        if hasattr(item, "data"):
            return {"type": "data", "data": item.data}
        return {"type": "unknown", "content": str(item)}

    return {
        "content": [convert(item) for item in result.content],
        "isError": getattr(result, "isError", False),
    }
|
||||
|
||||
def test_connection(self) -> Dict:
|
||||
if not self.server_url:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "No server URL configured",
|
||||
"tools_count": 0,
|
||||
}
|
||||
try:
|
||||
parsed = urlparse(self.server_url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
|
||||
"tools_count": 0,
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Invalid URL format",
|
||||
"tools_count": 0,
|
||||
}
|
||||
if not self._client:
|
||||
try:
|
||||
self._setup_client()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Client init failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
}
|
||||
try:
|
||||
if self.auth_type == "oauth":
|
||||
return self._test_oauth_connection()
|
||||
else:
|
||||
return self._test_regular_connection()
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"Connection failed: {str(e)}",
|
||||
"tools_count": 0,
|
||||
}
|
||||
|
||||
def _test_regular_connection(self) -> Dict:
    """Test a non-OAuth connection by pinging the server and listing its tools.

    A failed ping alone is not fatal: the test still succeeds when tool
    discovery returns at least one tool, or when the ping succeeded and
    discovery returned nothing.

    Returns:
        A success dict with ``tools_count`` and a name/description summary of
        each tool, or a failure dict whose message prefers the ping error
        over the discovery error.
    """
    try:
        self._run_async_operation("ping")
    except Exception as e:
        ping_ok, ping_error = False, str(e)
    else:
        ping_ok, ping_error = True, None

    try:
        tools = self.discover_tools()
    except Exception as e:
        return {
            "success": False,
            "message": f"Connection failed: {ping_error or str(e)}",
            "tools_count": 0,
        }

    if not (tools or ping_ok):
        return {
            "success": False,
            "message": f"Connection failed: {ping_error or 'No tools found'}",
            "tools_count": 0,
        }

    summaries = [
        {
            "name": tool.get("name", "unknown"),
            "description": tool.get("description", ""),
        }
        for tool in tools
    ]
    return {
        "success": True,
        "message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
        "tools_count": len(tools),
        "tools": summaries,
    }
|
||||
|
||||
def _test_oauth_connection(self) -> Dict:
    """Test an OAuth connection, starting the OAuth flow when needed.

    Stored tokens are loaded through ``DBTokenStorage``. When an access token
    exists, ``self.query_mode`` is set, the client is rebuilt, and the token
    is validated by discovering tools. If there is no access token, or the
    validation fails, the OAuth task is started instead.

    Returns:
        A success dict with the discovered tools, or the result of
        ``self._start_oauth_task()``.
    """
    token_store = DBTokenStorage(
        server_url=self.server_url, user_id=self.user_id, db_client=db
    )
    token_loop = asyncio.new_event_loop()
    try:
        tokens = token_loop.run_until_complete(token_store.get_tokens())
    finally:
        token_loop.close()

    if not (tokens and tokens.access_token):
        return self._start_oauth_task()

    # Rebuild the client from scratch so it picks up the stored tokens.
    self.query_mode = True
    _mcp_clients_cache.pop(self._cache_key, None)
    self._client = None
    self._setup_client()
    try:
        tools = self.discover_tools()
        found = len(tools)
        return {
            "success": True,
            "message": f"Connected — found {found} tool{'s' if found != 1 else ''}.",
            "tools_count": found,
            "tools": [
                {
                    "name": entry.get("name", "unknown"),
                    "description": entry.get("description", ""),
                }
                for entry in tools
            ],
        }
    except Exception as e:
        logger.warning("OAuth token validation failed: %s", e)
        # Drop the client that failed validation before re-authorizing.
        _mcp_clients_cache.pop(self._cache_key, None)
        self._client = None

    return self._start_oauth_task()
|
||||
|
||||
def _start_oauth_task(self) -> Dict:
    """Queue the OAuth task and report that authorization is required.

    The task receives a copy of ``self.config`` without its ``query_mode``
    entry, together with ``self.user_id``.

    Returns:
        A failure dict flagged with ``requires_oauth`` and carrying the
        queued task's id.
    """
    task_config = {
        key: value for key, value in self.config.copy().items() if key != "query_mode"
    }
    queued = mcp_oauth_task.delay(task_config, self.user_id)
    response = {"success": False, "requires_oauth": True}
    response["task_id"] = queued.id
    response["message"] = "OAuth authorization required."
    response["tools_count"] = 0
    return response
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict]:
    """
    Get metadata for all available actions.

    Each tool in ``self.available_tools`` is described by its name,
    description and a JSON-schema style ``parameters`` object derived from
    the first non-empty of its ``inputSchema``, ``input_schema``, ``schema``
    or ``parameters`` entries.

    Returns:
        List of action metadata dictionaries
    """

    def build_parameters(tool):
        declared = (
            tool.get("inputSchema")
            or tool.get("input_schema")
            or tool.get("schema")
            or tool.get("parameters")
        )
        # Missing or non-dict schemas fall back to an empty object schema.
        if not (declared and isinstance(declared, dict)):
            return {"type": "object", "properties": {}, "required": []}
        # A dict without "properties" is taken to be the properties map itself.
        if "properties" not in declared:
            return {"type": "object", "properties": declared, "required": []}

        schema = {
            "type": declared.get("type", "object"),
            "properties": declared.get("properties", {}),
            "required": declared.get("required", []),
        }
        for optional_key in ("additionalProperties", "description"):
            if optional_key in declared:
                schema[optional_key] = declared[optional_key]
        return schema

    actions = []
    for tool in self.available_tools:
        parameters = build_parameters(tool)
        actions.append(
            {
                "name": tool.get("name", ""),
                "description": tool.get("description", ""),
                "parameters": parameters,
            }
        )
    return actions
|
||||
|
||||
def get_config_requirements(self) -> Dict:
    """Describe the configuration fields of an MCP tool.

    Returns:
        A mapping from field name to its specification: ``type``, ``label``,
        ``description``, ``required``, ``secret`` and ``order``, plus
        ``enum``/``default`` where defined and ``depends_on`` for fields tied
        to a specific ``auth_type``.
    """
    unset = object()

    def field(
        label,
        description,
        order,
        *,
        kind="string",
        required=False,
        secret=False,
        enum=unset,
        default=unset,
        auth=unset,
    ):
        # Keys are inserted in a fixed order so the emitted layout is stable.
        spec = {"type": kind, "label": label, "description": description}
        if enum is not unset:
            spec["enum"] = enum
        if default is not unset:
            spec["default"] = default
        spec.update(required=required, secret=secret, order=order)
        if auth is not unset:
            spec["depends_on"] = {"auth_type": auth}
        return spec

    return {
        "server_url": field(
            "Server URL", "URL of the remote MCP server", 1, required=True
        ),
        "auth_type": field(
            "Authentication Type",
            "Authentication method for the MCP server",
            2,
            required=True,
            enum=["none", "bearer", "oauth", "api_key", "basic"],
            default="none",
        ),
        "api_key": field(
            "API Key", "API key for authentication", 3, secret=True, auth="api_key"
        ),
        "api_key_header": field(
            "API Key Header",
            "Header name for API key (default: X-API-Key)",
            4,
            default="X-API-Key",
            auth="api_key",
        ),
        "bearer_token": field(
            "Bearer Token",
            "Bearer token for authentication",
            3,
            secret=True,
            auth="bearer",
        ),
        "username": field(
            "Username", "Username for basic authentication", 3, auth="basic"
        ),
        "password": field(
            "Password",
            "Password for basic authentication",
            4,
            secret=True,
            auth="basic",
        ),
        "oauth_scopes": field(
            "OAuth Scopes",
            "Comma-separated OAuth scopes to request",
            3,
            auth="oauth",
        ),
        "timeout": field(
            "Timeout (seconds)",
            "Request timeout in seconds (1-300)",
            10,
            kind="number",
            default=30,
        ),
    }
|
||||
|
||||
|
||||
class DocsGPTOAuth(OAuthClientProvider):
    """
    Custom OAuth handler for DocsGPT that uses frontend redirect instead of browser.

    The authorization URL and the flow status are written to Redis, and the
    authorization code is polled back from Redis by ``callback_handler``.
    """

    def __init__(
        self,
        mcp_url: str,
        redirect_uri: str,
        redis_client: Redis | None = None,
        redis_prefix: str = "mcp_oauth:",
        task_id: str | None = None,
        scopes: str | list[str] | None = None,
        client_name: str = "DocsGPT-MCP",
        user_id=None,
        db=None,
        additional_client_metadata: dict[str, Any] | None = None,
        skip_redirect_validation: bool = False,
    ):
        """Configure the provider for one MCP server and user.

        Args:
            mcp_url: URL of the MCP server; only its scheme and host are kept.
            redirect_uri: Redirect URI registered for the OAuth client.
            redis_client: Redis client used to exchange flow data, if any.
            redis_prefix: Prefix for the Redis keys of this flow.
            task_id: Id under which flow status is published, if any.
            scopes: Scopes to request; a list is joined with spaces.
            client_name: Client name used in the client metadata.
            user_id: Owner of the stored tokens.
            db: Database handle passed to ``DBTokenStorage``.
            additional_client_metadata: Extra ``OAuthClientMetadata`` fields.
            skip_redirect_validation: When True, stored client info is not
                checked against ``redirect_uri``.
        """
        self.redirect_uri = redirect_uri
        self.redis_client = redis_client
        self.redis_prefix = redis_prefix
        self.task_id = task_id
        self.user_id = user_id
        self.db = db

        parsed_url = urlparse(mcp_url)
        self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        if isinstance(scopes, list):
            scopes = " ".join(scopes)
        client_metadata = OAuthClientMetadata(
            client_name=client_name,
            redirect_uris=[AnyHttpUrl(redirect_uri)],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            scope=scopes,
            **(additional_client_metadata or {}),
        )

        storage = DBTokenStorage(
            server_url=self.server_base_url,
            user_id=self.user_id,
            db_client=self.db,
            expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
        )

        super().__init__(
            server_url=self.server_base_url,
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=self.redirect_handler,
            callback_handler=self.callback_handler,
        )

        self.auth_url = None
        self.extracted_state = None

    def _flow_key(self, kind: str) -> str:
        """Return the Redis key of the given kind for the current state value."""
        return f"{self.redis_prefix}{kind}:{self.extracted_state}"

    def _set_task_status(self, status_data: dict[str, Any]) -> None:
        """Publish ``status_data`` for ``self.task_id`` with a 600-second TTL.

        Does nothing when there is no task id or no Redis client.
        """
        if self.task_id and self.redis_client:
            self.redis_client.setex(
                f"mcp_oauth_status:{self.task_id}", 600, json.dumps(status_data)
            )

    def _clear_flow_keys(self) -> None:
        """Delete the stored auth URL and state marker of the current flow."""
        self.redis_client.delete(self._flow_key("auth_url"))
        self.redis_client.delete(self._flow_key("state"))

    def _process_auth_url(self, authorization_url: str) -> tuple[str, str]:
        """Process authorization URL to extract state.

        Returns:
            The unchanged URL and the first ``state`` query parameter.

        Raises:
            Exception: If the URL cannot be parsed or carries no ``state``.
        """
        try:
            query_params = parse_qs(urlparse(authorization_url).query)
            state_params = query_params.get("state", [])
            if not state_params:
                raise ValueError("No state in auth URL")
            return authorization_url, state_params[0]
        except Exception as e:
            raise Exception(f"Failed to process auth URL: {e}") from e

    async def redirect_handler(self, authorization_url: str) -> None:
        """Store auth URL and state in Redis for frontend to use."""
        auth_url, state = self._process_auth_url(authorization_url)
        logger.info("Processed auth_url: %s, state: %s", auth_url, state)
        self.auth_url = auth_url
        self.extracted_state = state

        if self.redis_client and self.extracted_state:
            key = self._flow_key("auth_url")
            self.redis_client.setex(key, 600, auth_url)
            logger.info("Stored auth_url in Redis: %s", key)

        # redis_client is optional, so the status write is guarded in the
        # helper rather than assuming a client exists whenever task_id is set.
        self._set_task_status(
            {
                "status": "requires_redirect",
                "message": "Authorization required",
                "authorization_url": self.auth_url,
                "state": self.extracted_state,
                "requires_oauth": True,
                "task_id": self.task_id,
            }
        )

    async def callback_handler(self) -> tuple[str, str | None]:
        """Wait for auth code from Redis using the state value.

        Polls once per second for up to five minutes for either the code or
        an error stored under the current state.

        Returns:
            The authorization code and the state it belongs to.

        Raises:
            Exception: If Redis or the state is missing, if an OAuth error
                was stored, or if no code arrives in time.
        """
        if not self.redis_client or not self.extracted_state:
            raise Exception("Redis client or state not configured for OAuth")
        poll_interval = 1
        max_wait_time = 300
        code_key = self._flow_key("code")
        error_key = self._flow_key("error")

        self._set_task_status(
            {
                "status": "awaiting_callback",
                "message": "Waiting for authorization...",
                "authorization_url": self.auth_url,
                "state": self.extracted_state,
                "requires_oauth": True,
                "task_id": self.task_id,
            }
        )

        # Monotonic clock: the deadline must not move with wall-clock changes.
        deadline = time.monotonic() + max_wait_time
        while time.monotonic() < deadline:
            code_data = self.redis_client.get(code_key)
            if code_data:
                # Redis returns bytes unless the client decodes responses.
                code = code_data.decode() if isinstance(code_data, bytes) else code_data
                returned_state = self.extracted_state

                self.redis_client.delete(code_key)
                self._clear_flow_keys()
                self._set_task_status(
                    {
                        "status": "callback_received",
                        "message": "Completing authentication...",
                        "task_id": self.task_id,
                    }
                )
                return code, returned_state

            error_data = self.redis_client.get(error_key)
            if error_data:
                error_msg = (
                    error_data.decode() if isinstance(error_data, bytes) else error_data
                )
                self.redis_client.delete(error_key)
                self._clear_flow_keys()
                raise Exception(f"OAuth error: {error_msg}")

            await asyncio.sleep(poll_interval)

        self._clear_flow_keys()
        raise Exception("OAuth timeout: no code received within 5 minutes")
|
||||
|
||||
|
||||
class NonInteractiveOAuth(DocsGPTOAuth):
    """OAuth provider that raises instead of starting an interactive flow.

    Used during query execution to prevent the streaming response from blocking
    while waiting for user authorization that will never come. Both handlers
    fail immediately with a re-authorization message.
    """

    def __init__(self, **kwargs):
        """Build the provider with redirect validation always skipped.

        ``task_id`` defaults to None when the caller does not pass one.
        """
        options = {"task_id": None, **kwargs, "skip_redirect_validation": True}
        super().__init__(**options)

    async def redirect_handler(self, authorization_url: str) -> None:
        """Refuse to redirect: the user must re-authorize in tool settings."""
        raise Exception(
            "OAuth session expired — please re-authorize this MCP server in tool settings."
        )

    async def callback_handler(self) -> tuple[str, str | None]:
        """Refuse to wait for a callback: no interactive flow is running."""
        raise Exception(
            "OAuth session expired — please re-authorize this MCP server in tool settings."
        )
|
||||
|
||||
|
||||
class DBTokenStorage(TokenStorage):
    """Token storage backed by the ``connector_sessions`` collection.

    One document per (base server URL, user id) pair holds the OAuth
    ``tokens`` and the registered ``client_info``. Collection calls are run
    through ``asyncio.to_thread``.
    """

    def __init__(
        self,
        server_url: str,
        user_id: str,
        db_client,
        expected_redirect_uri: Optional[str] = None,
    ):
        """Bind the storage to one server/user pair.

        Args:
            server_url: MCP server URL; reduced to scheme and host for lookups.
            user_id: Owner of the stored session.
            db_client: Database handle exposing ``connector_sessions``.
            expected_redirect_uri: When set, stored client info registered
                for a different redirect URI is discarded on load.
        """
        self.server_url = server_url
        self.user_id = user_id
        self.db_client = db_client
        self.expected_redirect_uri = expected_redirect_uri
        self.collection = db_client["connector_sessions"]

    @staticmethod
    def get_base_url(url: str) -> str:
        """Return ``scheme://netloc`` of ``url``."""
        parts = urlparse(url)
        return f"{parts.scheme}://{parts.netloc}"

    def get_db_key(self) -> dict:
        """Return the filter identifying this server/user document."""
        return dict(
            server_url=self.get_base_url(self.server_url),
            user_id=self.user_id,
        )

    async def get_tokens(self) -> OAuthToken | None:
        """Load the stored tokens, or None when absent or invalid."""
        record = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not record or "tokens" not in record:
            return None
        try:
            return OAuthToken.model_validate(record["tokens"])
        except ValidationError as e:
            logger.error("Could not load tokens: %s", e)
            return None

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Store ``tokens``, creating the document when it does not exist."""
        change = {"$set": {"tokens": tokens.model_dump()}}
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            change,
            True,
        )
        logger.info("Saved tokens for %s", self.get_base_url(self.server_url))

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Load the stored client registration.

        Returns:
            The client info, or None when it is absent, invalid, or was
            registered for a redirect URI other than the expected one. In
            the last case the stored client info and tokens are removed.
        """
        record = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
        if not record or "client_info" not in record:
            logger.debug(
                "No client_info in DB for %s", self.get_base_url(self.server_url)
            )
            return None
        try:
            client_info = OAuthClientInformationFull.model_validate(record["client_info"])
            if not self.expected_redirect_uri:
                return client_info

            # Compare without trailing slashes on either side.
            expected_uri = self.expected_redirect_uri.rstrip("/")
            stored_uris = [str(uri).rstrip("/") for uri in client_info.redirect_uris]
            if expected_uri in stored_uris:
                return client_info

            logger.warning(
                "Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
                self.get_base_url(self.server_url),
                expected_uri,
                stored_uris,
            )
            await asyncio.to_thread(
                self.collection.update_one,
                self.get_db_key(),
                {"$unset": {"client_info": "", "tokens": ""}},
            )
            return None
        except ValidationError as e:
            logger.error("Could not load client info: %s", e)
            return None

    def _serialize_client_info(self, info: dict) -> dict:
        """Convert list-valued ``redirect_uris`` entries to strings, in place."""
        uris = info.get("redirect_uris") if "redirect_uris" in info else None
        if isinstance(uris, list):
            info["redirect_uris"] = [str(entry) for entry in uris]
        return info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Store ``client_info``, creating the document when it does not exist."""
        serialized_info = self._serialize_client_info(client_info.model_dump())
        change = {"$set": {"client_info": serialized_info}}
        await asyncio.to_thread(
            self.collection.update_one,
            self.get_db_key(),
            change,
            True,
        )
        logger.info("Saved client info for %s", self.get_base_url(self.server_url))

    async def clear(self) -> None:
        """Delete this server/user document."""
        await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
        logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))

    @classmethod
    async def clear_all(cls, db_client) -> None:
        """Delete every document in ``connector_sessions``."""
        sessions = db_client["connector_sessions"]
        await asyncio.to_thread(sessions.delete_many, {})
        logger.info("Cleared all OAuth client cache data.")
|
||||
|
||||
|
||||
class MCPOAuthManager:
    """Manager for handling MCP OAuth callbacks."""

    def __init__(self, redis_client: Redis | None, redis_prefix: str = "mcp_oauth:"):
        """Keep the Redis client and the key prefix used for callback data."""
        self.redis_client = redis_client
        self.redis_prefix = redis_prefix

    def handle_oauth_callback(
        self, state: str, code: str, error: Optional[str] = None
    ) -> bool:
        """
        Handle OAuth callback from provider.

        On success the code and a ``completed`` state marker are stored in
        Redis for 300 seconds; a reported error is stored the same way under
        the error key.

        Args:
            state: The state parameter from OAuth callback
            code: The authorization code from OAuth callback
            error: Error message if OAuth failed

        Returns:
            True if successful, False otherwise
        """
        try:
            if not (self.redis_client and state):
                raise Exception("Redis client or state not provided")

            if error:
                self.redis_client.setex(f"{self.redis_prefix}error:{state}", 300, error)
                raise Exception(f"OAuth error received: {error}")

            stored_values = (
                (f"{self.redis_prefix}code:{state}", code),
                (f"{self.redis_prefix}state:{state}", "completed"),
            )
            for redis_key, payload in stored_values:
                self.redis_client.setex(redis_key, 300, payload)
        except Exception as e:
            logger.error("Error handling OAuth callback: %s", e)
            return False
        return True

    def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
        """Get current status of OAuth flow using provided task_id."""
        if task_id:
            return mcp_oauth_status_task(task_id)
        return {"status": "not_started", "message": "OAuth flow not started"}
|
||||
@@ -1,546 +0,0 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
"""Memory
|
||||
|
||||
Stores and retrieves information across conversations through a memory file directory.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
    """Initialize the tool.

    Args:
        tool_config: Optional tool configuration. Should include:
            - tool_id: Unique identifier for this memory tool instance (from user_tools._id)
              This ensures each user's tool configuration has isolated memories
        user_id: The authenticated user's id (should come from decoded_token["sub"]).
    """
    self.user_id: Optional[str] = user_id

    # Resolve the tool id: explicit configuration first, then a per-user
    # default, and finally a random id when neither is available.
    if tool_config and "tool_id" in tool_config:
        resolved_tool_id = tool_config["tool_id"]
    elif user_id:
        resolved_tool_id = f"default_{user_id}"
    else:
        resolved_tool_id = str(uuid.uuid4())
    self.tool_id = resolved_tool_id

    database = MongoDB.get_client()[settings.MONGO_DB_NAME]
    self.collection = database["memories"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
|
||||
"""Execute an action by name.
|
||||
|
||||
Args:
|
||||
action_name: One of view, create, str_replace, insert, delete, rename.
|
||||
**kwargs: Parameters for the action.
|
||||
|
||||
Returns:
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
kwargs.get("view_range")
|
||||
)
|
||||
|
||||
if action_name == "create":
|
||||
return self._create(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("file_text", "")
|
||||
)
|
||||
|
||||
if action_name == "str_replace":
|
||||
return self._str_replace(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("old_str", ""),
|
||||
kwargs.get("new_str", "")
|
||||
)
|
||||
|
||||
if action_name == "insert":
|
||||
return self._insert(
|
||||
kwargs.get("path", ""),
|
||||
kwargs.get("insert_line", 1),
|
||||
kwargs.get("insert_text", "")
|
||||
)
|
||||
|
||||
if action_name == "delete":
|
||||
return self._delete(kwargs.get("path", ""))
|
||||
|
||||
if action_name == "rename":
|
||||
return self._rename(
|
||||
kwargs.get("old_path", ""),
|
||||
kwargs.get("new_path", "")
|
||||
)
|
||||
|
||||
return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
    """Return JSON metadata describing supported actions for tool schemas.

    Returns:
        One entry per action (view, create, str_replace, insert, delete,
        rename), each with a name, a description and an object schema of its
        parameters.
    """

    def text(description):
        return {"type": "string", "description": description}

    def action(name, description, properties, required):
        return {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }

    return [
        action(
            "view",
            "Shows directory contents or file contents with optional line ranges.",
            {
                "path": text(
                    "Path to file or directory (e.g., /notes.txt or /project/ or /)."
                ),
                "view_range": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "Optional [start_line, end_line] to view specific lines (1-indexed).",
                },
            },
            ["path"],
        ),
        action(
            "create",
            "Create or overwrite a file.",
            {
                "path": text(
                    "File path to create (e.g., /notes.txt or /project/task.txt)."
                ),
                "file_text": text("Content to write to the file."),
            },
            ["path", "file_text"],
        ),
        action(
            "str_replace",
            "Replace text in a file.",
            {
                "path": text("File path (e.g., /notes.txt)."),
                "old_str": text("String to find."),
                "new_str": text("String to replace with."),
            },
            ["path", "old_str", "new_str"],
        ),
        action(
            "insert",
            "Insert text at a specific line in a file.",
            {
                "path": text("File path (e.g., /notes.txt)."),
                "insert_line": {
                    "type": "integer",
                    "description": "Line number to insert at (1-indexed).",
                },
                "insert_text": text("Text to insert."),
            },
            ["path", "insert_line", "insert_text"],
        ),
        action(
            "delete",
            "Delete a file or directory.",
            {"path": text("Path to delete (e.g., /notes.txt or /project/).")},
            ["path"],
        ),
        action(
            "rename",
            "Rename or move a file/directory.",
            {
                "old_path": text("Current path (e.g., /old.txt)."),
                "new_path": text("New path (e.g., /new.txt)."),
            },
            ["old_path", "new_path"],
        ),
    ]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
    """Describe the configuration this tool needs (it needs none)."""
    requirements: Dict[str, Any] = {}
    return requirements
|
||||
|
||||
# -----------------------------
|
||||
# Path validation
|
||||
# -----------------------------
|
||||
def _validate_path(self, path: str) -> Optional[str]:
|
||||
"""Validate and normalize path.
|
||||
|
||||
Args:
|
||||
path: User-provided path.
|
||||
|
||||
Returns:
|
||||
Normalized path or None if invalid.
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
path = path.strip()
|
||||
|
||||
# Preserve whether path ends with / (indicates directory)
|
||||
is_directory = path.endswith("/")
|
||||
|
||||
# Ensure path starts with / for consistency
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
# Check for directory traversal patterns
|
||||
if ".." in path or path.count("//") > 0:
|
||||
return None
|
||||
|
||||
# Normalize the path
|
||||
try:
|
||||
# Convert to Path object and resolve to canonical form
|
||||
normalized = str(Path(path).as_posix())
|
||||
|
||||
# Ensure it still starts with /
|
||||
if not normalized.startswith("/"):
|
||||
return None
|
||||
|
||||
# Preserve trailing slash for directories
|
||||
if is_directory and not normalized.endswith("/") and normalized != "/":
|
||||
normalized = normalized + "/"
|
||||
|
||||
return normalized
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
    """Show a directory listing or a file's contents, depending on the path.

    Args:
        path: User-provided path; a trailing "/" (or the root) means directory.
        view_range: Optional line range forwarded to the file view.

    Returns:
        The listing or file text, or an error message for an invalid path.
    """
    target = self._validate_path(path)
    if not target:
        return "Error: Invalid path."

    # The root "/" also ends with "/", so one test covers both directory cases.
    if target.endswith("/"):
        return self._view_directory(target)

    return self._view_file(target, view_range)
|
||||
|
||||
def _view_directory(self, path: str) -> str:
    """List the stored files whose path lies under a directory.

    Args:
        path: Directory path; a trailing "/" is added for matching if absent.

    Returns:
        A "Directory: ..." header followed by one "- name" line per file
        (paths relative to the directory, sorted), or "(empty)".
    """
    # The trailing slash keeps "/foo" from matching "/foobar.txt".
    prefix = path if path.endswith("/") else path + "/"

    matches = list(
        self.collection.find(
            {
                "user_id": self.user_id,
                "tool_id": self.tool_id,
                "path": {"$regex": f"^{re.escape(prefix)}"},
            },
            {"path": 1},
        )
    )
    if not matches:
        return f"Directory: {path}\n(empty)"

    # Strip the directory prefix and drop entries that reduce to nothing.
    stored_paths = [entry["path"] for entry in matches]
    names = sorted(
        stored[len(prefix):]
        for stored in stored_paths
        if stored.startswith(prefix) and stored[len(prefix):]
    )

    listing = "\n".join(f"- {name}" for name in names)
    return f"Directory: {path}\n{listing}"
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = str(doc["content"])
|
||||
|
||||
# Apply view_range if specified
|
||||
if view_range and len(view_range) == 2:
|
||||
lines = content.split("\n")
|
||||
start, end = view_range
|
||||
# Convert to 0-indexed
|
||||
start_idx = max(0, start - 1)
|
||||
end_idx = min(len(lines), end)
|
||||
|
||||
if start_idx >= len(lines):
|
||||
return f"Error: Line range out of bounds. File has {len(lines)} lines."
|
||||
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
# Add line numbers (enumerate with 1-based start)
|
||||
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
|
||||
return "\n".join(numbered_lines)
|
||||
|
||||
return content
|
||||
|
||||
def _create(self, path: str, file_text: str) -> str:
    """Create a file, or overwrite it if one already exists at the path.

    Args:
        path: User-provided file path; directory paths are rejected.
        file_text: Content to store.

    Returns:
        A confirmation message, or an error message for a bad path.
    """
    target = self._validate_path(path)
    if not target:
        return "Error: Invalid path."

    # Root and anything ending in "/" denote directories, not files.
    if target.endswith("/"):
        return "Error: Cannot create a file at directory path."

    selector = {"user_id": self.user_id, "tool_id": self.tool_id, "path": target}
    changes = {"$set": {"content": file_text, "updated_at": datetime.now()}}
    # upsert=True inserts the document when no file matches the selector.
    self.collection.update_one(selector, changes, upsert=True)

    return f"File created: {target}"
|
||||
|
||||
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
|
||||
"""Replace text in a file."""
|
||||
validated_path = self._validate_path(path)
|
||||
if not validated_path:
|
||||
return "Error: Invalid path."
|
||||
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Replace the string (case-insensitive)
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
    """Insert text as a new line at a 1-indexed position in a file.

    Args:
        path: User-provided file path.
        insert_line: 1-indexed line number the new text will occupy; one past
            the last line appends.
        insert_text: Text to insert.

    Returns:
        A confirmation message, or an error message when the path is
        invalid, the text is empty, the file is missing/empty, or the line
        number is out of range.
    """
    target = self._validate_path(path)
    if not target:
        return "Error: Invalid path."

    if not insert_text:
        return "Error: insert_text is required."

    selector = {"user_id": self.user_id, "tool_id": self.tool_id, "path": target}
    stored = self.collection.find_one(selector)
    if not stored or not stored.get("content"):
        return f"Error: File not found: {target}"

    rows = str(stored["content"]).split("\n")

    # Shift to 0-indexed; position == len(rows) is allowed and appends.
    position = insert_line - 1
    if not 0 <= position <= len(rows):
        return f"Error: Invalid line number. File has {len(rows)} lines."

    rows.insert(position, insert_text)

    self.collection.update_one(
        selector,
        {"$set": {"content": "\n".join(rows), "updated_at": datetime.now()}},
    )

    return f"Text inserted at line {insert_line} in {target}"
|
||||
|
||||
def _delete(self, path: str) -> str:
    """Delete a file, a directory's files, or everything (for the root).

    Args:
        path: User-provided path. "/" deletes all files of this user/tool;
            a trailing "/" deletes every file under that directory; any
            other path is tried as a directory first, then as a single file.

    Returns:
        A message with what was deleted, or an error message when the path
        is invalid or nothing matched.
    """
    target = self._validate_path(path)
    if not target:
        return "Error: Invalid path."

    owner = {"user_id": self.user_id, "tool_id": self.tool_id}

    if target == "/":
        # Root: remove every file stored for this user and tool.
        wiped = self.collection.delete_many(dict(owner))
        return f"Deleted {wiped.deleted_count} file(s) from memory."

    if target.endswith("/"):
        # Explicit directory: remove everything under the prefix.
        removed = self.collection.delete_many(
            {**owner, "path": {"$regex": f"^{re.escape(target)}"}}
        )
        return f"Deleted directory and {removed.deleted_count} file(s)."

    # No trailing slash: treat it as a directory first by matching "<path>/".
    child_prefix = target + "/"
    children = self.collection.delete_many(
        {**owner, "path": {"$regex": f"^{re.escape(child_prefix)}"}}
    )
    if children.deleted_count > 0:
        return f"Deleted directory and {children.deleted_count} file(s)."

    # Nothing lived under it as a directory, so delete the exact file.
    single = self.collection.delete_one({**owner, "path": target})
    if single.deleted_count:
        return f"Deleted: {target}"
    return f"Error: File not found: {target}"
|
||||
|
||||
def _rename(self, old_path: str, new_path: str) -> str:
    """Rename or move a single file, or every file under a directory.

    Args:
        old_path: Current path; a trailing "/" selects directory mode.
        new_path: Destination path.

    Returns:
        A confirmation message, or an error message when a path is invalid,
        the root is involved, the source does not exist, or (single-file
        mode) the destination already exists.
    """
    source = self._validate_path(old_path)
    destination = self._validate_path(new_path)

    if not (source and destination):
        return "Error: Invalid path."

    if "/" in (source, destination):
        return "Error: Cannot rename root directory."

    owner = {"user_id": self.user_id, "tool_id": self.tool_id}

    if source.endswith("/"):
        # Directory mode: the destination needs a trailing slash too so the
        # prefix swap below keeps the separator.
        if not destination.endswith("/"):
            destination += "/"

        moved = list(
            self.collection.find(
                {**owner, "path": {"$regex": f"^{re.escape(source)}"}}
            )
        )
        if not moved:
            return f"Error: Directory not found: {source}"

        for entry in moved:
            # Swap only the leading directory prefix of each stored path.
            relocated = entry["path"].replace(source, destination, 1)
            self.collection.update_one(
                {"_id": entry["_id"]},
                {"$set": {"path": relocated, "updated_at": datetime.now()}},
            )

        return f"Renamed directory: {source} -> {destination} ({len(moved)} files)"

    # Single-file mode.
    current = self.collection.find_one({**owner, "path": source})
    if not current:
        return f"Error: File not found: {source}"

    # Refuse to clobber a file already stored at the destination.
    if self.collection.find_one({**owner, "path": destination}):
        return f"Error: File already exists at {destination}"

    # The old document is removed and a new one is written at the new path.
    self.collection.delete_one({**owner, "path": source})
    self.collection.insert_one(
        {
            **owner,
            "path": destination,
            "content": current.get("content", ""),
            "updated_at": datetime.now(),
        }
    )

    return f"Renamed: {source} -> {destination}"
|
||||
@@ -1,223 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class NotesTool(Tool):
    """Notepad tool holding a single note per user and tool instance.

    The note is stored in the ``notes`` MongoDB collection, keyed by
    ``user_id`` and ``tool_id``. Supports viewing, overwriting, string
    replacement, line insertion and deletion.
    """

    def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
        """Initialize the tool.

        Args:
            tool_config: Optional tool configuration. Should include:
                - tool_id: Unique identifier for this notes tool instance (from user_tools._id)
                  This ensures each user's tool configuration has isolated notes
            user_id: The authenticated user's id (should come from decoded_token["sub"]).
        """
        self.user_id: Optional[str] = user_id

        # Get tool_id from configuration (passed from user_tools._id in production)
        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
        if tool_config and "tool_id" in tool_config:
            self.tool_id = tool_config["tool_id"]
        elif user_id:
            # Fallback for backward compatibility or testing
            self.tool_id = f"default_{user_id}"
        else:
            # Last resort fallback (shouldn't happen in normal use)
            self.tool_id = str(uuid.uuid4())

        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
        self.collection = db["notes"]

        # Id of the note document touched by the most recent action, if any.
        self._last_artifact_id: Optional[str] = None

    # -----------------------------
    # Action implementations
    # -----------------------------
    def execute_action(self, action_name: str, **kwargs: Any) -> str:
        """Execute an action by name.

        Args:
            action_name: One of view, overwrite, str_replace, insert, delete.
            **kwargs: Parameters for the action.

        Returns:
            A human-readable string result.
        """
        if not self.user_id:
            return "Error: NotesTool requires a valid user_id."

        # Reset so a failed action never reports a stale artifact id.
        self._last_artifact_id = None

        if action_name == "view":
            return self._get_note()

        if action_name == "overwrite":
            return self._overwrite_note(kwargs.get("text", ""))

        if action_name == "str_replace":
            return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))

        if action_name == "insert":
            return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))

        if action_name == "delete":
            return self._delete_note()

        return f"Unknown action: {action_name}"

    def get_actions_metadata(self) -> List[Dict[str, Any]]:
        """Return JSON metadata describing supported actions for tool schemas."""
        return [
            {
                "name": "view",
                "description": "Retrieve the user's note.",
                "parameters": {"type": "object", "properties": {}},
            },
            {
                "name": "overwrite",
                "description": "Replace the entire note content (creates if doesn't exist).",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "text": {"type": "string", "description": "New note content."}
                    },
                    "required": ["text"],
                },
            },
            {
                "name": "str_replace",
                "description": "Replace occurrences of old_str with new_str in the note.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "old_str": {"type": "string", "description": "String to find."},
                        "new_str": {"type": "string", "description": "String to replace with."}
                    },
                    "required": ["old_str", "new_str"],
                },
            },
            {
                "name": "insert",
                "description": "Insert text at the specified line number (1-indexed).",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
                        "text": {"type": "string", "description": "Text to insert."}
                    },
                    "required": ["line_number", "text"],
                },
            },
            {
                "name": "delete",
                "description": "Delete the user's note.",
                "parameters": {"type": "object", "properties": {}},
            },
        ]

    def get_config_requirements(self) -> Dict[str, Any]:
        """Return configuration requirements (none for now)."""
        return {}

    def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
        """Return the id of the note document touched by the last action, if any."""
        return self._last_artifact_id

    # -----------------------------
    # Internal helpers (single-note)
    # -----------------------------
    def _get_note(self) -> str:
        """Return the note text, or a message when no note is stored."""
        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
        if not doc or not doc.get("note"):
            return "No note found."
        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        return str(doc["note"])

    def _overwrite_note(self, content: str) -> str:
        """Store ``content`` (stripped) as the note, creating it if absent."""
        content = (content or "").strip()
        if not content:
            return "Note content required."
        result = self.collection.find_one_and_update(
            {"user_id": self.user_id, "tool_id": self.tool_id},
            {"$set": {"note": content, "updated_at": datetime.utcnow()}},
            upsert=True,
            return_document=True,
        )
        if result and result.get("_id") is not None:
            self._last_artifact_id = str(result.get("_id"))
        return "Note saved."

    def _str_replace(self, old_str: str, new_str: str) -> str:
        """Replace every case-insensitive occurrence of ``old_str`` in the note.

        Both strings are treated literally: ``old_str`` is escaped before
        matching and ``new_str`` is inserted verbatim.
        """
        import re

        if not old_str:
            return "old_str is required."

        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
        if not doc or not doc.get("note"):
            return "No note found."

        current_note = str(doc["note"])

        # Case-insensitive search
        if old_str.lower() not in current_note.lower():
            return f"String '{old_str}' not found in note."

        # Use a callable replacement so new_str is inserted verbatim. Passing
        # it as a string would make re.sub treat it as a template, expanding
        # backslash escapes and group references (e.g. "\n", "\1").
        updated_note = re.sub(
            re.escape(old_str),
            lambda _match: new_str,
            current_note,
            flags=re.IGNORECASE,
        )

        result = self.collection.find_one_and_update(
            {"user_id": self.user_id, "tool_id": self.tool_id},
            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
            return_document=True,
        )
        if result and result.get("_id") is not None:
            self._last_artifact_id = str(result.get("_id"))
        return "Note updated."

    def _insert(self, line_number: int, text: str) -> str:
        """Insert ``text`` as a new line at the 1-indexed ``line_number``."""
        if not text:
            return "Text is required."

        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
        if not doc or not doc.get("note"):
            return "No note found."

        current_note = str(doc["note"])
        lines = current_note.split("\n")

        # Convert to 0-indexed and validate; index == len(lines) appends.
        index = line_number - 1
        if index < 0 or index > len(lines):
            return f"Invalid line number. Note has {len(lines)} lines."

        lines.insert(index, text)
        updated_note = "\n".join(lines)

        result = self.collection.find_one_and_update(
            {"user_id": self.user_id, "tool_id": self.tool_id},
            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
            return_document=True,
        )
        if result and result.get("_id") is not None:
            self._last_artifact_id = str(result.get("_id"))
        return "Text inserted."

    def _delete_note(self) -> str:
        """Delete the note document, reporting whether one existed."""
        doc = self.collection.find_one_and_delete(
            {"user_id": self.user_id, "tool_id": self.tool_id}
        )
        if not doc:
            return "No note found to delete."
        if doc.get("_id") is not None:
            self._last_artifact_id = str(doc.get("_id"))
        return "Note deleted."
|
||||
@@ -116,13 +116,12 @@ class NtfyTool(Tool):
|
||||
]
|
||||
|
||||
def get_config_requirements(self):
|
||||
"""
|
||||
Specify the configuration requirements.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary describing required config parameters.
|
||||
"""
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "Access Token",
|
||||
"description": "Ntfy access token for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
"token": {"type": "string", "description": "Access token for authentication"},
|
||||
}
|
||||
@@ -1,12 +1,6 @@
|
||||
import logging
|
||||
|
||||
import psycopg2
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgresTool(Tool):
|
||||
"""
|
||||
PostgreSQL Database Tool
|
||||
@@ -23,15 +17,17 @@ class PostgresTool(Tool):
|
||||
"postgres_execute_sql": self._execute_sql,
|
||||
"postgres_get_schema": self._get_schema,
|
||||
}
|
||||
if action_name not in actions:
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _execute_sql(self, sql_query):
|
||||
"""
|
||||
Executes an SQL query against the PostgreSQL database using a connection string.
|
||||
"""
|
||||
conn = None
|
||||
conn = None # Initialize conn to None for error handling
|
||||
try:
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
@@ -39,9 +35,7 @@ class PostgresTool(Tool):
|
||||
conn.commit()
|
||||
|
||||
if sql_query.strip().lower().startswith("select"):
|
||||
column_names = (
|
||||
[desc[0] for desc in cur.description] if cur.description else []
|
||||
)
|
||||
column_names = [desc[0] for desc in cur.description] if cur.description else []
|
||||
results = []
|
||||
rows = cur.fetchall()
|
||||
for row in rows:
|
||||
@@ -49,9 +43,7 @@ class PostgresTool(Tool):
|
||||
response_data = {"data": results, "column_names": column_names}
|
||||
else:
|
||||
row_count = cur.rowcount
|
||||
response_data = {
|
||||
"message": f"Query executed successfully, {row_count} rows affected."
|
||||
}
|
||||
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
|
||||
|
||||
cur.close()
|
||||
return {
|
||||
@@ -62,27 +54,26 @@ class PostgresTool(Tool):
|
||||
|
||||
except psycopg2.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
logger.error("PostgreSQL execute_sql error: %s", e)
|
||||
print(f"Database error: {e}")
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": "Failed to execute SQL query.",
|
||||
"error": error_message,
|
||||
}
|
||||
finally:
|
||||
if conn:
|
||||
if conn: # Ensure connection is closed even if errors occur
|
||||
conn.close()
|
||||
|
||||
def _get_schema(self, db_name):
|
||||
"""
|
||||
Retrieves the schema of the PostgreSQL database using a connection string.
|
||||
"""
|
||||
conn = None
|
||||
conn = None # Initialize conn to None for error handling
|
||||
try:
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
cur.execute("""
|
||||
SELECT
|
||||
table_name,
|
||||
column_name,
|
||||
@@ -96,22 +87,19 @@ class PostgresTool(Tool):
|
||||
ORDER BY
|
||||
table_name,
|
||||
ordinal_position;
|
||||
"""
|
||||
)
|
||||
""")
|
||||
|
||||
schema_data = {}
|
||||
for row in cur.fetchall():
|
||||
table_name, column_name, data_type, column_default, is_nullable = row
|
||||
if table_name not in schema_data:
|
||||
schema_data[table_name] = []
|
||||
schema_data[table_name].append(
|
||||
{
|
||||
"column_name": column_name,
|
||||
"data_type": data_type,
|
||||
"column_default": column_default,
|
||||
"is_nullable": is_nullable,
|
||||
}
|
||||
)
|
||||
schema_data[table_name].append({
|
||||
"column_name": column_name,
|
||||
"data_type": data_type,
|
||||
"column_default": column_default,
|
||||
"is_nullable": is_nullable
|
||||
})
|
||||
|
||||
cur.close()
|
||||
return {
|
||||
@@ -122,14 +110,14 @@ class PostgresTool(Tool):
|
||||
|
||||
except psycopg2.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
logger.error("PostgreSQL get_schema error: %s", e)
|
||||
print(f"Database error: {e}")
|
||||
return {
|
||||
"status_code": 500,
|
||||
"message": "Failed to retrieve database schema.",
|
||||
"error": error_message,
|
||||
}
|
||||
finally:
|
||||
if conn:
|
||||
if conn: # Ensure connection is closed even if errors occur
|
||||
conn.close()
|
||||
|
||||
def get_actions_metadata(self):
|
||||
@@ -170,10 +158,6 @@ class PostgresTool(Tool):
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "Connection String",
|
||||
"description": "PostgreSQL database connection string",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
import requests
|
||||
from markdownify import markdownify
|
||||
from application.agents.tools.base import Tool
|
||||
from application.core.url_validation import validate_url, SSRFError
|
||||
from urllib.parse import urlparse
|
||||
|
||||
class ReadWebpageTool(Tool):
|
||||
"""
|
||||
@@ -31,12 +31,11 @@ class ReadWebpageTool(Tool):
|
||||
if not url:
|
||||
return "Error: URL parameter is missing."
|
||||
|
||||
# Validate URL to prevent SSRF attacks
|
||||
try:
|
||||
url = validate_url(url)
|
||||
except SSRFError as e:
|
||||
return f"Error: URL validation failed - {e}"
|
||||
|
||||
# Ensure the URL has a scheme (if not, default to http)
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme:
|
||||
url = "http://" + url
|
||||
|
||||
try:
|
||||
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
|
||||
|
||||
@@ -1,342 +0,0 @@
|
||||
"""
|
||||
API Specification Parser
|
||||
|
||||
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
|
||||
to API Tool action definitions for use in DocsGPT.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_METHODS = frozenset(
|
||||
{"get", "post", "put", "delete", "patch", "head", "options"}
|
||||
)
|
||||
|
||||
|
||||
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Turn a raw API specification into metadata plus action definitions.

    Accepts OpenAPI 3.x and Swagger 2.0 documents, as JSON or YAML text.

    Args:
        spec_content: Raw specification content as string

    Returns:
        Tuple of (metadata dict, list of action dicts)

    Raises:
        ValueError: If the spec is invalid or uses an unsupported format
    """
    document = _load_spec(spec_content)
    _validate_spec(document)

    # The presence of a top-level "swagger" key selects Swagger 2.0 handling.
    swagger_style = "swagger" in document

    return (
        _extract_metadata(document, swagger_style),
        _extract_actions(document, swagger_style),
    )
|
||||
|
||||
|
||||
def _load_spec(content: str) -> Dict[str, Any]:
|
||||
"""Parse spec content from JSON or YAML string."""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
raise ValueError("Empty specification content")
|
||||
try:
|
||||
if content.startswith("{"):
|
||||
return json.loads(content)
|
||||
return yaml.safe_load(content)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {e.msg}")
|
||||
except yaml.YAMLError as e:
|
||||
raise ValueError(f"Invalid YAML format: {e}")
|
||||
|
||||
|
||||
def _validate_spec(spec: Dict[str, Any]) -> None:
|
||||
"""Validate spec version and required fields."""
|
||||
if not isinstance(spec, dict):
|
||||
raise ValueError("Specification must be a valid object")
|
||||
openapi_version = spec.get("openapi", "")
|
||||
swagger_version = spec.get("swagger", "")
|
||||
|
||||
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
|
||||
raise ValueError(
|
||||
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
|
||||
)
|
||||
if "paths" not in spec or not spec["paths"]:
|
||||
raise ValueError("No API paths defined in the specification")
|
||||
|
||||
|
||||
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
    """Collect title, description, version and base URL from the spec.

    The description is cut to at most 500 characters.
    """
    info_section = spec.get("info", {})
    server_root = _get_base_url(spec, is_swagger)

    # A present-but-null description falls back to the empty string.
    summary = info_section.get("description", "") or ""

    metadata = {
        "title": info_section.get("title", "Untitled API"),
        "description": summary[:500],
        "version": info_section.get("version", ""),
        "base_url": server_root,
    }
    return metadata
|
||||
|
||||
|
||||
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
|
||||
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
|
||||
if is_swagger:
|
||||
schemes = spec.get("schemes", ["https"])
|
||||
host = spec.get("host", "")
|
||||
base_path = spec.get("basePath", "")
|
||||
if host:
|
||||
scheme = schemes[0] if schemes else "https"
|
||||
return f"{scheme}://{host}{base_path}".rstrip("/")
|
||||
return ""
|
||||
servers = spec.get("servers", [])
|
||||
if servers and isinstance(servers, list) and servers[0].get("url"):
|
||||
return servers[0]["url"].rstrip("/")
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
    """Extract all API operations as action definitions.

    Operations are emitted in the order they appear in the spec. An
    operation that fails to convert is logged and skipped; it does not
    abort the rest of the extraction.

    Args:
        spec: Parsed specification document.
        is_swagger: Whether the document is Swagger 2.0.

    Returns:
        One action dict per supported HTTP operation found under ``paths``.
    """
    actions = []
    paths = spec.get("paths", {})
    base_url = _get_base_url(spec, is_swagger)

    components = spec.get("components", {})
    definitions = spec.get("definitions", {})

    for path, path_item in paths.items():
        if not isinstance(path_item, dict):
            continue
        # `or []` also covers a present-but-null "parameters" key, which
        # would otherwise make every operation of this path fail to build.
        path_params = path_item.get("parameters") or []

        # Walk the path item's own keys rather than the SUPPORTED_METHODS
        # set: set iteration order is arbitrary, so the resulting action
        # order would differ from run to run.
        for method, operation in path_item.items():
            if method not in SUPPORTED_METHODS or not isinstance(operation, dict):
                continue
            try:
                action = _build_action(
                    path=path,
                    method=method,
                    operation=operation,
                    path_params=path_params,
                    base_url=base_url,
                    components=components,
                    definitions=definitions,
                    is_swagger=is_swagger,
                )
                actions.append(action)
            except Exception as e:
                # Best effort: one bad operation must not drop the others.
                logger.warning(
                    "Failed to parse operation %s %s: %s", method.upper(), path, e
                )
    return actions
|
||||
|
||||
|
||||
def _build_action(
    path: str,
    method: str,
    operation: Dict[str, Any],
    path_params: List[Dict],
    base_url: str,
    components: Dict[str, Any],
    definitions: Dict[str, Any],
    is_swagger: bool,
) -> Dict[str, Any]:
    """Assemble the action definition for one API operation.

    Path-level parameters are combined with the operation's own, then
    split into query parameters and headers; the request body is
    extracted separately. The description prefers ``summary`` over
    ``description`` and is cut to at most 500 characters.
    """
    name = _generate_action_name(operation, method, path)

    if base_url:
        endpoint = f"{base_url}{path}"
    else:
        endpoint = path

    merged_params = path_params + operation.get("parameters", [])
    query_props, header_props = _categorize_parameters(
        merged_params, components, definitions
    )

    body_props, content_type = _extract_request_body(
        operation, components, definitions, is_swagger
    )

    blurb = operation.get("summary", "") or operation.get("description", "")
    blurb = (blurb or "")[:500]

    return {
        "name": name,
        "url": endpoint,
        "method": method.upper(),
        "description": blurb,
        "query_params": {"type": "object", "properties": query_props},
        "headers": {"type": "object", "properties": header_props},
        "body": {"type": "object", "properties": body_props},
        "body_content_type": content_type,
        "active": True,
    }
|
||||
|
||||
|
||||
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
|
||||
"""Generate a valid action name from operationId or method+path."""
|
||||
if operation.get("operationId"):
|
||||
name = operation["operationId"]
|
||||
else:
|
||||
path_slug = re.sub(r"[{}]", "", path)
|
||||
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
|
||||
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
|
||||
name = f"{method}_{path_slug}"
|
||||
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
|
||||
return name[:64]
|
||||
|
||||
|
||||
def _categorize_parameters(
    parameters: List[Dict],
    components: Dict[str, Any],
    definitions: Dict[str, Any],
) -> Tuple[Dict, Dict]:
    """Split parameters into query-parameter and header property maps.

    Each entry is resolved through ``_resolve_ref`` first. Entries that do not
    resolve, or that have no ``name``, are skipped. A missing ``in`` field is
    treated as ``"query"``. Both ``"query"`` and ``"path"`` parameters land in
    the first map, ``"header"`` parameters in the second, and any other
    location is ignored. A later parameter with the same name overwrites an
    earlier one.

    Returns:
        ``(query_params, headers)``, each keyed by parameter name.
    """
    query_props: Dict = {}
    header_props: Dict = {}

    for raw_param in parameters:
        spec = _resolve_ref(raw_param, components, definitions)
        if spec and "name" in spec:
            where = spec.get("in", "query")
            prop = _param_to_property(spec)

            if where == "header":
                header_props[spec["name"]] = prop
            elif where in ("query", "path"):
                query_props[spec["name"]] = prop

    return query_props, header_props
|
||||
|
||||
|
||||
def _param_to_property(param: Dict) -> Dict[str, Any]:
|
||||
"""Convert an API parameter to an action property definition."""
|
||||
schema = param.get("schema", {})
|
||||
param_type = schema.get("type", param.get("type", "string"))
|
||||
|
||||
mapped_type = "integer" if param_type in ("integer", "number") else "string"
|
||||
|
||||
return {
|
||||
"type": mapped_type,
|
||||
"description": (param.get("description", "") or "")[:200],
|
||||
"value": "",
|
||||
"filled_by_llm": param.get("required", False),
|
||||
"required": param.get("required", False),
|
||||
}
|
||||
|
||||
|
||||
def _extract_request_body(
|
||||
operation: Dict[str, Any],
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
is_swagger: bool,
|
||||
) -> Tuple[Dict, str]:
|
||||
"""Extract request body schema and content type."""
|
||||
content_types = [
|
||||
"application/json",
|
||||
"application/x-www-form-urlencoded",
|
||||
"multipart/form-data",
|
||||
"text/plain",
|
||||
"application/xml",
|
||||
]
|
||||
|
||||
if is_swagger:
|
||||
consumes = operation.get("consumes", [])
|
||||
body_param = next(
|
||||
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
|
||||
)
|
||||
if not body_param:
|
||||
return {}, "application/json"
|
||||
selected_type = consumes[0] if consumes else "application/json"
|
||||
schema = body_param.get("schema", {})
|
||||
else:
|
||||
request_body = operation.get("requestBody", {})
|
||||
if not request_body:
|
||||
return {}, "application/json"
|
||||
request_body = _resolve_ref(request_body, components, definitions)
|
||||
content = request_body.get("content", {})
|
||||
|
||||
selected_type = "application/json"
|
||||
schema = {}
|
||||
|
||||
for ct in content_types:
|
||||
if ct in content:
|
||||
selected_type = ct
|
||||
schema = content[ct].get("schema", {})
|
||||
break
|
||||
if not schema and content:
|
||||
first_type = next(iter(content))
|
||||
selected_type = first_type
|
||||
schema = content[first_type].get("schema", {})
|
||||
properties = _schema_to_properties(schema, components, definitions)
|
||||
return properties, selected_type
|
||||
|
||||
|
||||
def _schema_to_properties(
    schema: Dict,
    components: Dict[str, Any],
    definitions: Dict[str, Any],
    depth: int = 0,
) -> Dict[str, Any]:
    """Flatten an object schema into action body properties.

    The schema is resolved through ``_resolve_ref``. Only schemas whose
    ``type`` is ``"object"`` (the default when ``type`` is absent) contribute
    properties. Each property is itself resolved, and is skipped when the
    result is not a dict. ``"integer"`` and ``"number"`` map to
    ``"integer"``; every other type, nested objects and arrays included, maps
    to ``"string"``.

    Args:
        schema: Schema object, possibly a ``$ref``.
        components: Lookup root for ``#/components/`` references.
        definitions: Lookup root for ``#/definitions/`` references.
        depth: Guard value; anything above 3 yields an empty result. This
            function does not call itself, so the guard only matters to
            callers that pass a depth explicitly.

    Returns:
        Mapping of property name to its definition dict, or ``{}``.
    """
    if depth > 3:
        return {}

    target = _resolve_ref(schema, components, definitions)
    if not (target and isinstance(target, dict)):
        return {}

    if target.get("type", "object") != "object":
        return {}

    required_names = set(target.get("required", []))
    result: Dict[str, Any] = {}

    for name, raw_prop in target.get("properties", {}).items():
        prop = _resolve_ref(raw_prop, components, definitions)
        if isinstance(prop, dict):
            declared = prop.get("type", "string")
            is_required = name in required_names
            result[name] = {
                "type": "integer" if declared in ("integer", "number") else "string",
                "description": (prop.get("description", "") or "")[:200],
                "value": "",
                "filled_by_llm": is_required,
                "required": is_required,
            }

    return result
|
||||
|
||||
|
||||
def _resolve_ref(
|
||||
obj: Any,
|
||||
components: Dict[str, Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Optional[Dict]:
|
||||
"""Resolve $ref references in the specification."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj if isinstance(obj, dict) else None
|
||||
if "$ref" not in obj:
|
||||
return obj
|
||||
ref_path = obj["$ref"]
|
||||
|
||||
if ref_path.startswith("#/components/"):
|
||||
parts = ref_path.replace("#/components/", "").split("/")
|
||||
return _traverse_path(components, parts)
|
||||
elif ref_path.startswith("#/definitions/"):
|
||||
parts = ref_path.replace("#/definitions/", "").split("/")
|
||||
return _traverse_path(definitions, parts)
|
||||
logger.debug(f"Unsupported ref path: {ref_path}")
|
||||
return None
|
||||
|
||||
|
||||
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
|
||||
"""Traverse a nested dictionary using path parts."""
|
||||
try:
|
||||
for part in parts:
|
||||
obj = obj[part]
|
||||
return obj if isinstance(obj, dict) else None
|
||||
except (KeyError, TypeError):
|
||||
return None
|
||||
@@ -1,11 +1,6 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramTool(Tool):
|
||||
"""
|
||||
@@ -23,19 +18,21 @@ class TelegramTool(Tool):
|
||||
"telegram_send_message": self._send_message,
|
||||
"telegram_send_image": self._send_image,
|
||||
}
|
||||
if action_name not in actions:
|
||||
|
||||
if action_name in actions:
|
||||
return actions[action_name](**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown action: {action_name}")
|
||||
return actions[action_name](**kwargs)
|
||||
|
||||
def _send_message(self, text, chat_id):
|
||||
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
||||
print(f"Sending message: {text}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": chat_id, "text": text}
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def _send_image(self, image_url, chat_id):
|
||||
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
||||
print(f"Sending image: {image_url}")
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload)
|
||||
@@ -85,12 +82,5 @@ class TelegramTool(Tool):
|
||||
|
||||
def get_config_requirements(self):
|
||||
return {
|
||||
"token": {
|
||||
"type": "string",
|
||||
"label": "Bot Token",
|
||||
"description": "Telegram bot token for authentication",
|
||||
"required": True,
|
||||
"secret": True,
|
||||
"order": 1,
|
||||
},
|
||||
"token": {"type": "string", "description": "Bot token for authentication"},
|
||||
}
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
|
||||
# Identifier under which the think pseudo-tool is registered.
THINK_TOOL_ID = "think"

# Static tool entry describing the think tool and its single "reason" action.
THINK_TOOL_ENTRY = {
    "name": "think",
    "actions": [
        {
            "name": "reason",
            "description": (
                "Use this tool to think through your reasoning step by step "
                "before deciding on your next action. "
                "Always reason before searching or answering."
            ),
            "active": True,
            "parameters": {
                "properties": {
                    # The only argument: free-form reasoning text.
                    "reasoning": {
                        "type": "string",
                        "description": "Your step-by-step reasoning and analysis",
                        "filled_by_llm": True,
                        "required": True,
                    },
                },
            },
        },
    ],
}
|
||||
|
||||
|
||||
class ThinkTool(Tool):
    """Pseudo-tool that captures chain-of-thought reasoning.

    Every action returns the same short acknowledgment so the LLM can
    continue. The reasoning content is captured in tool_call data for
    transparency.
    """

    def __init__(self, config=None):
        """Accept and ignore ``config``; the tool keeps no state."""

    def execute_action(self, action_name: str, **kwargs):
        """Return the fixed acknowledgment regardless of action or arguments."""
        return "Continue."

    def get_actions_metadata(self):
        """Describe the single ``reason`` action and its ``reasoning`` argument."""
        reasoning_param = {
            "type": "string",
            "description": "Your step-by-step reasoning and analysis",
            "filled_by_llm": True,
            "required": True,
        }
        reason_action = {
            "name": "reason",
            "description": (
                "Use this tool to think through your reasoning step by step "
                "before deciding on your next action. "
                "Always reason before searching or answering."
            ),
            "parameters": {"properties": {"reasoning": reasoning_param}},
        }
        return [reason_action]

    def get_config_requirements(self):
        """Return an empty mapping: the tool needs no configuration."""
        return {}
|
||||
@@ -1,333 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class TodoListTool(Tool):
|
||||
"""Todo List
|
||||
|
||||
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
|
||||
"""
|
||||
|
||||
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
    """Set up identifiers and the backing collection.

    Args:
        tool_config: Optional tool configuration. When it contains
            ``tool_id`` that value scopes the todos of this instance.
        user_id: Id of the user the todos belong to.

    The ``tool_id`` is chosen in this order: ``tool_config["tool_id"]``,
    then ``"default_<user_id>"``, then a random UUID string.
    """
    self.user_id: Optional[str] = user_id

    has_configured_id = bool(tool_config) and "tool_id" in tool_config
    if has_configured_id:
        self.tool_id = tool_config["tool_id"]
    elif user_id:
        # Fallback for backward compatibility or testing.
        self.tool_id = f"default_{user_id}"
    else:
        # Last resort fallback (shouldn't happen in normal use).
        self.tool_id = str(uuid.uuid4())

    database = MongoDB.get_client()[settings.MONGO_DB_NAME]
    self.collection = database["todos"]

    # Set by the actions so that get_artifact_id can report the touched doc.
    self._last_artifact_id: Optional[str] = None
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
def execute_action(self, action_name: str, **kwargs: Any) -> str:
    """Run the named action and return its human-readable result.

    Args:
        action_name: One of list, create, get, update, complete, delete.
        **kwargs: Action parameters; ``title`` and ``todo_id`` are the only
            keys read.

    Returns:
        The action's result string, an error string when no ``user_id`` is
        set, or ``"Unknown action: <name>"`` for an unrecognised action.
    """
    if not self.user_id:
        return "Error: TodoListTool requires a valid user_id."

    # Each call starts with no artifact; the action may set one.
    self._last_artifact_id = None

    # Handlers are thunks so that nothing is looked up or read from kwargs
    # unless its action is the one requested.
    handlers = (
        ("list", lambda: self._list()),
        ("create", lambda: self._create(kwargs.get("title", ""))),
        ("get", lambda: self._get(kwargs.get("todo_id"))),
        (
            "update",
            lambda: self._update(kwargs.get("todo_id"), kwargs.get("title", "")),
        ),
        ("complete", lambda: self._complete(kwargs.get("todo_id"))),
        ("delete", lambda: self._delete(kwargs.get("todo_id"))),
    )
    for name, run in handlers:
        if action_name == name:
            return run()

    return f"Unknown action: {action_name}"
|
||||
|
||||
def get_actions_metadata(self) -> List[Dict[str, Any]]:
    """Return JSON-schema style metadata for the six supported actions.

    The actions are, in order: list, create, get, update, complete, delete.
    A fresh structure is built on every call.
    """

    def todo_id_field(description: str) -> Dict[str, Any]:
        return {"type": "integer", "description": description}

    def action(name: str, description: str, properties: Dict[str, Any], required=None) -> Dict[str, Any]:
        parameters: Dict[str, Any] = {"type": "object", "properties": properties}
        if required is not None:
            parameters["required"] = required
        return {"name": name, "description": description, "parameters": parameters}

    return [
        action("list", "List all todos for the user.", {}),
        action(
            "create",
            "Create a new todo item.",
            {"title": {"type": "string", "description": "Title of the todo item."}},
            ["title"],
        ),
        action(
            "get",
            "Get a specific todo by ID.",
            {"todo_id": todo_id_field("The ID of the todo to retrieve.")},
            ["todo_id"],
        ),
        action(
            "update",
            "Update a todo's title by ID.",
            {
                "todo_id": todo_id_field("The ID of the todo to update."),
                "title": {"type": "string", "description": "The new title for the todo."},
            },
            ["todo_id", "title"],
        ),
        action(
            "complete",
            "Mark a todo as completed.",
            {"todo_id": todo_id_field("The ID of the todo to mark as completed.")},
            ["todo_id"],
        ),
        action(
            "delete",
            "Delete a specific todo by ID.",
            {"todo_id": todo_id_field("The ID of the todo to delete.")},
            ["todo_id"],
        ),
    ]
|
||||
|
||||
def get_config_requirements(self) -> Dict[str, Any]:
    """Return the configuration requirements, which are empty for this tool."""
    requirements: Dict[str, Any] = {}
    return requirements
|
||||
|
||||
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
    """Return the artifact id recorded by the most recent action, if any.

    ``action_name`` and ``kwargs`` are accepted but not consulted.
    """
    artifact_id = self._last_artifact_id
    return artifact_id
|
||||
|
||||
# -----------------------------
|
||||
# Internal helpers
|
||||
# -----------------------------
|
||||
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
|
||||
"""Convert todo identifiers to sequential integers."""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, int):
|
||||
return value if value > 0 else None
|
||||
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if stripped.isdigit():
|
||||
numeric_value = int(stripped)
|
||||
return numeric_value if numeric_value > 0 else None
|
||||
|
||||
return None
|
||||
|
||||
def _get_next_todo_id(self) -> int:
|
||||
"""Get the next sequential todo_id for this user and tool.
|
||||
|
||||
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
||||
With 5-10 todos max, scanning is negligible.
|
||||
"""
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query, {"todo_id": 1}))
|
||||
|
||||
# Find the maximum todo_id
|
||||
max_id = 0
|
||||
for todo in todos:
|
||||
todo_id = self._coerce_todo_id(todo.get("todo_id"))
|
||||
if todo_id is not None:
|
||||
max_id = max(max_id, todo_id)
|
||||
|
||||
return max_id + 1
|
||||
|
||||
def _list(self) -> str:
|
||||
"""List all todos for the user."""
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query))
|
||||
|
||||
if not todos:
|
||||
return "No todos found."
|
||||
|
||||
result_lines = ["Todos:"]
|
||||
for doc in todos:
|
||||
todo_id = doc.get("todo_id")
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
line = f"[{todo_id}] {title} ({status})"
|
||||
result_lines.append(line)
|
||||
|
||||
return "\n".join(result_lines)
|
||||
|
||||
def _create(self, title: str) -> str:
|
||||
"""Create a new todo item."""
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
now = datetime.now()
|
||||
todo_id = self._get_next_todo_id()
|
||||
|
||||
doc = {
|
||||
"todo_id": todo_id,
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"title": title,
|
||||
"status": "open",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
insert_result = self.collection.insert_one(doc)
|
||||
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
|
||||
if inserted_id is not None:
|
||||
self._last_artifact_id = str(inserted_id)
|
||||
return f"Todo created with ID {todo_id}: {title}"
|
||||
|
||||
def _get(self, todo_id: Optional[Any]) -> str:
|
||||
"""Get a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one(query)
|
||||
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
title = doc.get("title", "Untitled")
|
||||
status = doc.get("status", "open")
|
||||
|
||||
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||
|
||||
return result
|
||||
|
||||
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||
"""Update a todo's title by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"title": title, "updated_at": datetime.now()}},
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||
|
||||
def _complete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Mark a todo as completed."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"status": "completed", "updated_at": datetime.now()}},
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} marked as completed."
|
||||
|
||||
def _delete(self, todo_id: Optional[Any]) -> str:
|
||||
"""Delete a specific todo by ID."""
|
||||
parsed_todo_id = self._coerce_todo_id(todo_id)
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_delete(query)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} deleted."
|
||||
@@ -20,24 +20,20 @@ class ToolActionParser:
|
||||
try:
|
||||
call_args = json.loads(call.arguments)
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
return None, None, None
|
||||
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
except (AttributeError, TypeError, json.JSONDecodeError) as e:
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing OpenAI LLM call: {e}")
|
||||
return None, None, None
|
||||
return tool_id, action_name, call_args
|
||||
@@ -46,23 +42,19 @@ class ToolActionParser:
|
||||
try:
|
||||
call_args = call.arguments
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
logger.warning(f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id")
|
||||
return None, None, None
|
||||
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
)
|
||||
|
||||
logger.warning(f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call.")
|
||||
|
||||
except (AttributeError, TypeError) as e:
|
||||
logger.error(f"Error parsing Google LLM call: {e}")
|
||||
return None, None, None
|
||||
|
||||
@@ -23,23 +23,16 @@ class ToolManager:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
def load_tool(self, tool_name, tool_config, user_id=None):
|
||||
def load_tool(self, tool_name, tool_config):
|
||||
self.config[tool_name] = tool_config
|
||||
module = importlib.import_module(f"application.agents.tools.{tool_name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
|
||||
return obj(tool_config, user_id)
|
||||
else:
|
||||
return obj(tool_config)
|
||||
return obj(tool_config)
|
||||
|
||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||
def execute_action(self, tool_name, action_name, **kwargs):
|
||||
if tool_name not in self.tools:
|
||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
|
||||
tool_config = self.config.get(tool_name, {})
|
||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||
return tool.execute_action(action_name, **kwargs)
|
||||
return self.tools[tool_name].execute_action(action_name, **kwargs)
|
||||
|
||||
def get_all_actions_metadata(self):
|
||||
|
||||
@@ -1,231 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.workflows.schemas import (
|
||||
ExecutionStatus,
|
||||
Workflow,
|
||||
WorkflowEdge,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
WorkflowRun,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.logging import log_activity, LogContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAgent(BaseAgent):
|
||||
"""A specialized agent that executes predefined workflows."""
|
||||
|
||||
def __init__(
    self,
    *args,
    workflow_id: Optional[str] = None,
    workflow: Optional[Dict[str, Any]] = None,
    workflow_owner: Optional[str] = None,
    **kwargs,
):
    """Initialise the base agent and remember which workflow to run.

    Args:
        *args: Forwarded unchanged to the base class initialiser.
        workflow_id: Id of a stored workflow to load on demand.
        workflow: Embedded workflow definition; when truthy it is used in
            preference to loading ``workflow_id``.
        workflow_owner: Owner id used when loading a stored workflow.
        **kwargs: Forwarded unchanged to the base class initialiser.
    """
    super().__init__(*args, **kwargs)

    # Source of the workflow definition.
    self._workflow_data = workflow
    self.workflow_id = workflow_id
    self.workflow_owner = workflow_owner

    # Created when a run starts.
    self._engine: Optional[WorkflowEngine] = None
|
||||
|
||||
@log_activity()
def gen(
    self, query: str, log_context: LogContext = None
) -> Generator[Dict[str, str], None, None]:
    """Run the workflow for ``query`` and stream its events.

    This is a thin decorated entry point: all events come from
    ``_gen_inner``, to which iteration is delegated.
    """
    events = self._gen_inner(query, log_context)
    yield from events
|
||||
|
||||
def _gen_inner(
    self, query: str, log_context: LogContext
) -> Generator[Dict[str, str], None, None]:
    """Load the workflow graph, execute it and persist the run.

    When no graph can be loaded a single error event is yielded and nothing
    is executed or saved. Otherwise the engine's events are streamed, and the
    run is saved once the engine's generator is exhausted.
    """
    graph = self._load_workflow_graph()
    if graph:
        self._engine = WorkflowEngine(graph, self)
        yield from self._engine.execute({}, query)
        self._save_workflow_run(query)
    else:
        yield {"type": "error", "error": "Failed to load workflow configuration."}
|
||||
|
||||
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
    """Pick the workflow source and build the graph from it.

    An embedded workflow definition takes precedence over a stored one.
    ``None`` is returned when neither is available.
    """
    if self._workflow_data:
        loader = self._parse_embedded_workflow
    elif self.workflow_id:
        loader = self._load_from_database
    else:
        return None
    return loader()
|
||||
|
||||
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
    """Build a graph from the embedded workflow definition.

    Nodes are read from ``nodes`` and edges from ``edges``. Edge endpoints
    and handles accept both the camelCase/short keys and their ``*_id`` or
    snake_case alternatives. Any exception while building is logged and
    turned into ``None``.
    """
    try:
        data = self._workflow_data
        raw_nodes = data.get("nodes", [])
        raw_edges = data.get("edges", [])

        workflow = Workflow(
            name=data.get("name", "Embedded Workflow"),
            description=data.get("description"),
        )

        def to_node(raw):
            config = raw.get("data", {})
            return WorkflowNode(
                id=raw["id"],
                workflow_id=self.workflow_id or "embedded",
                type=raw["type"],
                title=raw.get("title", "Node"),
                description=raw.get("description"),
                position=raw.get("position", {"x": 0, "y": 0}),
                config=config,
            )

        def to_edge(raw):
            return WorkflowEdge(
                id=raw["id"],
                workflow_id=self.workflow_id or "embedded",
                source=raw.get("source") or raw.get("source_id"),
                target=raw.get("target") or raw.get("target_id"),
                sourceHandle=raw.get("sourceHandle") or raw.get("source_handle"),
                targetHandle=raw.get("targetHandle") or raw.get("target_handle"),
            )

        # Nodes are fully built before any edge is touched.
        nodes = [to_node(raw) for raw in raw_nodes]
        edges = [to_edge(raw) for raw in raw_edges]
        return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
    except Exception as e:
        logger.error(f"Invalid embedded workflow: {e}")
        return None
|
||||
|
||||
def _load_from_database(self) -> Optional[WorkflowGraph]:
    """Load the stored workflow, its nodes and its edges.

    Steps:
        1. Reject a missing or invalid ``workflow_id``.
        2. Determine the owner from ``workflow_owner`` or, failing that, the
           ``sub`` entry of ``decoded_token``; give up when there is none.
        3. Fetch the workflow document matching both id and owner.
        4. Read ``current_graph_version``; values that are missing,
           unparsable or not positive count as version 1.
        5. Fetch nodes and edges for that version. For version 1 only,
           documents without any ``graph_version`` field are used when the
           versioned query returns nothing.

    Every failure is logged and reported as ``None``.
    """
    try:
        from bson.objectid import ObjectId

        if not (self.workflow_id and ObjectId.is_valid(self.workflow_id)):
            logger.error(f"Invalid workflow ID: {self.workflow_id}")
            return None

        owner_id = self.workflow_owner
        if not owner_id and isinstance(self.decoded_token, dict):
            owner_id = self.decoded_token.get("sub")
        if not owner_id:
            logger.error(
                f"Workflow owner not available for workflow load: {self.workflow_id}"
            )
            return None

        db = MongoDB.get_client()[settings.MONGO_DB_NAME]

        workflow_doc = db["workflows"].find_one(
            {"_id": ObjectId(self.workflow_id), "user": owner_id}
        )
        if not workflow_doc:
            logger.error(
                f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
            )
            return None
        workflow = Workflow(**workflow_doc)

        raw_version = workflow_doc.get("current_graph_version", 1)
        try:
            graph_version = int(raw_version)
        except (ValueError, TypeError):
            graph_version = 1
        else:
            if graph_version <= 0:
                graph_version = 1

        def fetch_graph_docs(collection):
            # Versioned lookup first; unversioned documents are only
            # considered for version 1.
            docs = list(
                collection.find(
                    {"workflow_id": self.workflow_id, "graph_version": graph_version}
                )
            )
            if not docs and graph_version == 1:
                docs = list(
                    collection.find(
                        {
                            "workflow_id": self.workflow_id,
                            "graph_version": {"$exists": False},
                        }
                    )
                )
            return docs

        nodes = [WorkflowNode(**doc) for doc in fetch_graph_docs(db["workflow_nodes"])]
        edges = [WorkflowEdge(**doc) for doc in fetch_graph_docs(db["workflow_edges"])]

        return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
    except Exception as e:
        logger.error(f"Failed to load workflow from database: {e}")
        return None
|
||||
|
||||
def _save_workflow_run(self, query: str) -> None:
    """Persist a record of the finished run; failures are only logged.

    Nothing happens when no engine was created. The record holds the query
    as input, the serialised engine state as output, the engine's execution
    summary as steps, and the current UTC time for both timestamps.
    """
    engine = self._engine
    if not engine:
        return

    try:
        runs = MongoDB.get_client()[settings.MONGO_DB_NAME]["workflow_runs"]

        record = WorkflowRun(
            workflow_id=self.workflow_id or "unknown",
            status=self._determine_run_status(),
            inputs={"query": query},
            outputs=self._serialize_state(engine.state),
            steps=engine.get_execution_summary(),
            # Both timestamps are taken at save time.
            created_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
        )
        runs.insert_one(record.to_mongo_doc())
    except Exception as e:
        logger.error(f"Failed to save workflow run: {e}")
|
||||
|
||||
def _determine_run_status(self) -> ExecutionStatus:
    """Return FAILED when any logged step failed, otherwise COMPLETED.

    A missing engine or an empty execution log counts as COMPLETED.
    """
    engine = self._engine
    if engine and engine.execution_log:
        step_failed = any(
            entry.get("status") == ExecutionStatus.FAILED.value
            for entry in engine.execution_log
        )
        if step_failed:
            return ExecutionStatus.FAILED
    return ExecutionStatus.COMPLETED
|
||||
|
||||
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
serialized: Dict[str, Any] = {}
|
||||
for key, value in state.items():
|
||||
serialized[key] = self._serialize_state_value(value)
|
||||
return serialized
|
||||
|
||||
def _serialize_state_value(self, value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {
|
||||
str(dict_key): self._serialize_state_value(dict_value)
|
||||
for dict_key, dict_value in value.items()
|
||||
}
|
||||
if isinstance(value, list):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [self._serialize_state_value(item) for item in value]
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
return str(value)
|
||||
@@ -1,64 +0,0 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import celpy
|
||||
import celpy.celtypes
|
||||
|
||||
|
||||
class CelEvaluationError(Exception):
    """Raised when a CEL expression is empty or cannot be compiled or evaluated."""
|
||||
|
||||
|
||||
def _convert_value(value: Any) -> Any:
    """Wrap a plain Python value in the corresponding ``celpy.celtypes`` type.

    Lists and dicts are converted recursively (dict keys become CEL strings).
    ``None`` is mapped to a CEL ``False``; any other unrecognised object is
    converted to a CEL string of ``str(value)``.
    """
    cel = celpy.celtypes
    # bool is tested first because bool is a subclass of int.
    scalar_wrappers = (
        (bool, cel.BoolType),
        (int, cel.IntType),
        (float, cel.DoubleType),
        (str, cel.StringType),
    )
    for python_type, cel_type in scalar_wrappers:
        if isinstance(value, python_type):
            return cel_type(value)
    if isinstance(value, list):
        return cel.ListType([_convert_value(element) for element in value])
    if isinstance(value, dict):
        converted = {
            cel.StringType(key): _convert_value(item) for key, item in value.items()
        }
        return cel.MapType(converted)
    if value is None:
        return cel.BoolType(False)
    return cel.StringType(str(value))
|
||||
|
||||
|
||||
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
    """Build the CEL activation: each state value converted via ``_convert_value``."""
    activation: Dict[str, Any] = {}
    for name, raw_value in state.items():
        activation[name] = _convert_value(raw_value)
    return activation
|
||||
|
||||
|
||||
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
    """Evaluate a CEL expression against ``state`` and return a Python value.

    Args:
        expression: CEL source text.
        state: Variables made available to the expression.

    Returns:
        The evaluation result converted with ``cel_to_python``.

    Raises:
        CelEvaluationError: If the expression is empty/blank, or if compiling
            or evaluating it raises any exception.
    """
    if not (expression and expression.strip()):
        raise CelEvaluationError("Empty expression")
    try:
        environment = celpy.Environment()
        program = environment.program(environment.compile(expression))
        result = program.evaluate(build_activation(state))
    except celpy.CELEvalError as exc:
        raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
    except Exception as exc:
        raise CelEvaluationError(f"CEL error: {exc}") from exc
    return cel_to_python(result)
|
||||
|
||||
|
||||
def cel_to_python(value: Any) -> Any:
    """Convert a ``celpy.celtypes`` value back into plain Python data.

    CEL lists and maps are converted recursively (map keys become ``str``).
    Values of any other type are returned unchanged.
    """
    cel = celpy.celtypes
    # Order is kept from the original chain: BoolType is checked before IntType.
    scalar_unwrappers = (
        (cel.BoolType, bool),
        (cel.IntType, int),
        (cel.DoubleType, float),
        (cel.StringType, str),
    )
    for cel_type, python_type in scalar_unwrappers:
        if isinstance(value, cel_type):
            return python_type(value)
    if isinstance(value, cel.ListType):
        return [cel_to_python(element) for element in value]
    if isinstance(value, cel.MapType):
        return {str(key): cel_to_python(item) for key, item in value.items()}
    return value
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from application.agents.agentic_agent import AgenticAgent
|
||||
from application.agents.base import BaseAgent
|
||||
from application.agents.classic_agent import ClassicAgent
|
||||
from application.agents.research_agent import ResearchAgent
|
||||
from application.agents.workflows.schemas import AgentType
|
||||
|
||||
|
||||
class ToolFilterMixin:
    """Mixin that filters fetched tools to only those specified in tool_ids.

    It must come before the class that actually fetches tools in the MRO,
    because both overrides delegate to ``super()`` and then filter the result.
    """

    # IDs of the tools this agent may use; an empty list means "no tools".
    _allowed_tool_ids: List[str]

    def _filter_allowed_tools(
        self, tools: Dict[str, Dict[str, Any]]
    ) -> Dict[str, Dict[str, Any]]:
        """Keep only the tools whose stringified ``_id`` is allow-listed."""
        if not self._allowed_tool_ids:
            return {}
        # Build the lookup set once instead of scanning the list per tool.
        allowed = set(self._allowed_tool_ids)
        return {
            tool_id: tool
            for tool_id, tool in tools.items()
            if str(tool.get("_id", "")) in allowed
        }

    def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
        """Return the parent's tools for ``user``, restricted to the allow-list."""
        return self._filter_allowed_tools(super()._get_user_tools(user))

    def _get_tools(self, api_key: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
        """Return the parent's tools for ``api_key``, restricted to the allow-list."""
        return self._filter_allowed_tools(super()._get_tools(api_key))
|
||||
|
||||
|
||||
class _WorkflowNodeMixin:
    """Common __init__ for all workflow node agents."""

    def __init__(
        self,
        endpoint: str,
        llm_name: str,
        model_id: str,
        api_key: str,
        tool_ids: Optional[List[str]] = None,
        **kwargs,
    ):
        """Forward the agent arguments to the next ``__init__`` and store the tool allow-list.

        ``tool_ids`` is not forwarded; it is stored as ``_allowed_tool_ids``
        (an empty list when ``tool_ids`` is ``None`` or empty).
        """
        super().__init__(
            endpoint=endpoint,
            llm_name=llm_name,
            model_id=model_id,
            api_key=api_key,
            **kwargs,
        )
        self._allowed_tool_ids = tool_ids if tool_ids else []
|
||||
|
||||
|
||||
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
    """``ClassicAgent`` for workflow nodes: adds tool filtering and the shared node ``__init__``."""
|
||||
|
||||
|
||||
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
    """``AgenticAgent`` for workflow nodes: adds tool filtering and the shared node ``__init__``."""
|
||||
|
||||
|
||||
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
    """``ResearchAgent`` for workflow nodes: adds tool filtering and the shared node ``__init__``."""
|
||||
|
||||
|
||||
class WorkflowNodeAgentFactory:
    """Creates the workflow-node agent class registered for an ``AgentType``."""

    # Registry of agent type -> node agent class.
    _agents: Dict[AgentType, Type[BaseAgent]] = {
        AgentType.CLASSIC: WorkflowNodeClassicAgent,
        AgentType.REACT: WorkflowNodeClassicAgent,  # backwards compat
        AgentType.AGENTIC: WorkflowNodeAgenticAgent,
        AgentType.RESEARCH: WorkflowNodeResearchAgent,
    }

    @classmethod
    def create(
        cls,
        agent_type: AgentType,
        endpoint: str,
        llm_name: str,
        model_id: str,
        api_key: str,
        tool_ids: Optional[List[str]] = None,
        **kwargs,
    ) -> BaseAgent:
        """Instantiate the agent registered for ``agent_type``.

        All remaining arguments are passed to the agent constructor as
        keyword arguments.

        Raises:
            ValueError: If no agent class is registered for ``agent_type``.
        """
        agent_class = cls._agents.get(agent_type)
        if not agent_class:
            raise ValueError(f"Unsupported agent type: {agent_type}")

        constructor_arguments = {
            "endpoint": endpoint,
            "llm_name": llm_name,
            "model_id": model_id,
            "api_key": api_key,
            "tool_ids": tool_ids,
        }
        return agent_class(**constructor_arguments, **kwargs)
|
||||
@@ -1,237 +0,0 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class NodeType(str, Enum):
    """Kinds of node a workflow can contain; members are also plain strings."""

    START = "start"
    END = "end"
    AGENT = "agent"
    NOTE = "note"
    STATE = "state"
    CONDITION = "condition"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
    """Agent implementations selectable for an agent node; members are also plain strings."""

    CLASSIC = "classic"
    REACT = "react"
    AGENTIC = "agentic"
    RESEARCH = "research"
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
    """Lifecycle states recorded for a node execution or a workflow run."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class Position(BaseModel):
    """An ``x``/``y`` coordinate pair, defaulting to the origin.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    x: float = 0.0
    y: float = 0.0
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseModel):
    """Settings of an agent node.

    Every field has a default, and unknown keys are accepted
    (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    # Which agent and model to run.
    agent_type: AgentType = AgentType.CLASSIC
    llm_name: Optional[str] = None

    # Prompting and output handling.
    system_prompt: str = "You are a helpful assistant."
    prompt_template: str = ""
    output_variable: Optional[str] = None
    stream_to_user: bool = True

    # Tools and retrieval; note that ``chunks`` is declared as a string.
    tools: List[str] = Field(default_factory=list)
    sources: List[str] = Field(default_factory=list)
    chunks: str = "2"
    retriever: str = ""

    model_id: Optional[str] = None
    json_schema: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ConditionCase(BaseModel):
    """One case of a condition node: an expression and the handle it selects.

    ``source_handle`` is required and may be supplied as ``sourceHandle`` or
    by field name. Unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    name: Optional[str] = None
    expression: str = ""
    source_handle: str = Field(..., alias="sourceHandle")
|
||||
|
||||
|
||||
class ConditionNodeConfig(BaseModel):
    """Settings of a condition node: a mode and a list of cases.

    Unknown keys are accepted (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    mode: Literal["simple", "advanced"] = "simple"
    cases: List[ConditionCase] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StateOperation(BaseModel):
    """An expression together with the name of the variable it targets.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    expression: str = ""
    target_variable: str = ""
|
||||
|
||||
|
||||
class WorkflowEdgeCreate(BaseModel):
    """Payload describing a directed edge between two workflow nodes.

    The aliased fields (``source``, ``target``, ``sourceHandle``,
    ``targetHandle``) may be supplied by alias or by field name.
    """

    model_config = ConfigDict(populate_by_name=True)

    id: str
    workflow_id: str

    # Endpoints of the edge.
    source_id: str = Field(..., alias="source")
    target_id: str = Field(..., alias="target")

    # Optional handles on the source and target nodes.
    source_handle: Optional[str] = Field(None, alias="sourceHandle")
    target_handle: Optional[str] = Field(None, alias="targetHandle")
|
||||
|
||||
|
||||
class WorkflowEdge(WorkflowEdgeCreate):
    """A workflow edge plus its optional Mongo ``_id`` (exposed as ``mongo_id``)."""

    mongo_id: Optional[str] = Field(None, alias="_id")

    @field_validator("mongo_id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; return any other value unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the edge as a plain dict; ``mongo_id`` is not included."""
        stored_fields = (
            "id",
            "workflow_id",
            "source_id",
            "target_id",
            "source_handle",
            "target_handle",
        )
        return {name: getattr(self, name) for name in stored_fields}
|
||||
|
||||
|
||||
class WorkflowNodeCreate(BaseModel):
    """Payload describing a workflow node.

    Unknown keys are accepted (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    id: str
    workflow_id: str
    type: NodeType
    title: str = "Node"
    description: Optional[str] = None
    position: Position = Field(default_factory=Position)
    config: Dict[str, Any] = Field(default_factory=dict)

    @field_validator("position", mode="before")
    @classmethod
    def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
        """Build a ``Position`` from a dict; return any other value unchanged."""
        return Position(**v) if isinstance(v, dict) else v
|
||||
|
||||
|
||||
class WorkflowNode(WorkflowNodeCreate):
    """A workflow node plus its optional Mongo ``_id`` (exposed as ``mongo_id``)."""

    mongo_id: Optional[str] = Field(None, alias="_id")

    @field_validator("mongo_id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; return any other value unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the node as a plain dict; ``mongo_id`` is not included.

        ``type`` is stored as its string value and ``position`` as a dict.
        """
        document: Dict[str, Any] = {
            "id": self.id,
            "workflow_id": self.workflow_id,
        }
        document["type"] = self.type.value
        document["title"] = self.title
        document["description"] = self.description
        document["position"] = self.position.model_dump()
        document["config"] = self.config
        return document
|
||||
|
||||
|
||||
class WorkflowCreate(BaseModel):
    """Payload describing a workflow: name, optional description and user.

    Unknown keys are accepted (``extra="allow"``).
    """

    model_config = ConfigDict(extra="allow")

    name: str = "New Workflow"
    description: Optional[str] = None
    user: Optional[str] = None
|
||||
|
||||
|
||||
class Workflow(WorkflowCreate):
    """A workflow with its optional Mongo ``_id`` and creation/update timestamps.

    Both timestamps default to the current UTC time.
    """

    id: Optional[str] = Field(None, alias="_id")
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    @field_validator("id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; return any other value unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the workflow as a plain dict; ``id`` is not included."""
        stored_fields = ("name", "description", "user", "created_at", "updated_at")
        return {name: getattr(self, name) for name in stored_fields}
|
||||
|
||||
|
||||
class WorkflowGraph(BaseModel):
    """A workflow together with its nodes and edges, plus lookup helpers."""

    workflow: Workflow
    nodes: List[WorkflowNode] = Field(default_factory=list)
    edges: List[WorkflowEdge] = Field(default_factory=list)

    def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
        """Return the first node whose ``id`` equals ``node_id``, or ``None``."""
        return next(
            (candidate for candidate in self.nodes if candidate.id == node_id),
            None,
        )

    def get_start_node(self) -> Optional[WorkflowNode]:
        """Return the first node of type ``NodeType.START``, or ``None``."""
        return next(
            (candidate for candidate in self.nodes if candidate.type == NodeType.START),
            None,
        )

    def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
        """Return, in stored order, the edges whose source is ``node_id``."""
        outgoing: List[WorkflowEdge] = []
        for edge in self.edges:
            if edge.source_id == node_id:
                outgoing.append(edge)
        return outgoing
|
||||
|
||||
|
||||
class NodeExecutionLog(BaseModel):
    """Record of a single node execution.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    node_id: str
    node_type: str
    status: ExecutionStatus

    # Timing; ``completed_at`` is optional.
    started_at: datetime
    completed_at: Optional[datetime] = None

    error: Optional[str] = None
    state_snapshot: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRunCreate(BaseModel):
    """Payload for starting a run: the workflow id and string-valued inputs."""

    workflow_id: str
    inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkflowRun(BaseModel):
    """One execution of a workflow: status, inputs, outputs and per-node steps.

    Unknown keys are accepted (``extra="allow"``). ``created_at`` defaults to
    the current UTC time.
    """

    model_config = ConfigDict(extra="allow")

    id: Optional[str] = Field(None, alias="_id")
    workflow_id: str
    status: ExecutionStatus = ExecutionStatus.PENDING
    inputs: Dict[str, str] = Field(default_factory=dict)
    outputs: Dict[str, Any] = Field(default_factory=dict)
    steps: List[NodeExecutionLog] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None

    @field_validator("id", mode="before")
    @classmethod
    def convert_objectid(cls, v: Any) -> Optional[str]:
        """Stringify an ``ObjectId``; return any other value unchanged."""
        return str(v) if isinstance(v, ObjectId) else v

    def to_mongo_doc(self) -> Dict[str, Any]:
        """Return the run as a plain dict; ``id`` is not included.

        ``status`` is stored as its string value and each step as the dict
        produced by ``model_dump()``.
        """
        dumped_steps = []
        for step in self.steps:
            dumped_steps.append(step.model_dump())

        return {
            "workflow_id": self.workflow_id,
            "status": self.status.value,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "steps": dumped_steps,
            "created_at": self.created_at,
            "completed_at": self.completed_at,
        }
|
||||
@@ -1,470 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
|
||||
|
||||
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
|
||||
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
|
||||
from application.agents.workflows.schemas import (
|
||||
AgentNodeConfig,
|
||||
AgentType,
|
||||
ConditionNodeConfig,
|
||||
ExecutionStatus,
|
||||
NodeExecutionLog,
|
||||
NodeType,
|
||||
WorkflowGraph,
|
||||
WorkflowNode,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.error import sanitize_api_error
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
try:
|
||||
import jsonschema
|
||||
except ImportError: # pragma: no cover - optional dependency in some deployments.
|
||||
jsonschema = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from application.agents.base import BaseAgent
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
StateValue = Any
|
||||
WorkflowState = Dict[str, StateValue]
|
||||
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
|
||||
|
||||
|
||||
class WorkflowEngine:
    """Runs a ``WorkflowGraph`` one node at a time, streaming events as it goes.

    Values shared between nodes live in ``self.state``. One dict per executed
    node is appended to ``self.execution_log``; ``get_execution_summary``
    converts those dicts into ``NodeExecutionLog`` models.
    """

    # Hard cap on the number of nodes executed in a single run.
    MAX_EXECUTION_STEPS = 50

    def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
        """Store the graph and the parent agent and create empty run state."""
        self.graph = graph
        self.agent = agent
        self.state: WorkflowState = {}
        self.execution_log: List[Dict[str, Any]] = []
        # Handle chosen by the most recent condition node; consumed when the
        # next node is resolved.
        self._condition_result: Optional[str] = None
        # Set once an agent node with stream_to_user has run, so later
        # streamed answers are separated from it by a blank line.
        self._has_streamed = False
        self._template_engine = TemplateEngine()
        self._namespace_manager = NamespaceManager()

    def execute(
        self, initial_inputs: WorkflowState, query: str
    ) -> Generator[Dict[str, Any], None, None]:
        """Run the workflow from its start node and yield events.

        Yields ``workflow_step`` events (status ``running``, then ``completed``
        or ``failed``), whatever the node handlers yield, and ``error`` events.
        Execution stops at an end node, when a node has no next node, when a
        node fails or cannot be found, or after ``MAX_EXECUTION_STEPS`` nodes.
        """
        self._initialize_state(initial_inputs, query)

        start_node = self.graph.get_start_node()
        if not start_node:
            yield {"type": "error", "error": "No start node found in workflow."}
            return
        current_node_id: Optional[str] = start_node.id
        steps = 0

        while current_node_id and steps < self.MAX_EXECUTION_STEPS:
            node = self.graph.get_node_by_id(current_node_id)
            if not node:
                yield {"type": "error", "error": f"Node {current_node_id} not found."}
                break
            log_entry = self._create_log_entry(node)

            yield {
                "type": "workflow_step",
                "node_id": node.id,
                "node_type": node.type.value,
                "node_title": node.title,
                "status": "running",
            }

            try:
                yield from self._execute_node(node)
                log_entry["status"] = ExecutionStatus.COMPLETED.value
                log_entry["completed_at"] = datetime.now(timezone.utc)

                output_key = f"node_{node.id}_output"
                node_output = self.state.get(output_key)

                yield {
                    "type": "workflow_step",
                    "node_id": node.id,
                    "node_type": node.type.value,
                    "node_title": node.title,
                    "status": "completed",
                    "state_snapshot": dict(self.state),
                    "output": node_output,
                }
            except Exception as e:
                logger.error("Error executing node %s: %s", node.id, e, exc_info=True)
                log_entry["status"] = ExecutionStatus.FAILED.value
                log_entry["error"] = str(e)
                log_entry["completed_at"] = datetime.now(timezone.utc)
                log_entry["state_snapshot"] = dict(self.state)
                self.execution_log.append(log_entry)

                # Only the sanitized message is sent to the client.
                user_friendly_error = sanitize_api_error(e)
                yield {
                    "type": "workflow_step",
                    "node_id": node.id,
                    "node_type": node.type.value,
                    "node_title": node.title,
                    "status": "failed",
                    "state_snapshot": dict(self.state),
                    "error": user_friendly_error,
                }
                yield {"type": "error", "error": user_friendly_error}
                break
            log_entry["state_snapshot"] = dict(self.state)
            self.execution_log.append(log_entry)

            if node.type == NodeType.END:
                break
            current_node_id = self._get_next_node_id(current_node_id)
            if current_node_id is None and node.type != NodeType.END:
                logger.warning(
                    "Branch ended at node '%s' (%s) without reaching an end node",
                    node.title,
                    node.id,
                )
            steps += 1

        # Warn only when a node was still waiting to run; a branch that ends
        # naturally on the last allowed step did not hit the limit.
        if current_node_id and steps >= self.MAX_EXECUTION_STEPS:
            logger.warning(
                "Workflow reached max steps limit (%s)", self.MAX_EXECUTION_STEPS
            )

    def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
        """Seed the state with the inputs, the query and the stringified chat history."""
        self.state.update(initial_inputs)
        self.state["query"] = query
        self.state["chat_history"] = str(self.agent.chat_history)

    def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
        """Return a fresh log dict for ``node`` with status ``running``."""
        return {
            "node_id": node.id,
            "node_type": node.type.value,
            "started_at": datetime.now(timezone.utc),
            "completed_at": None,
            "status": ExecutionStatus.RUNNING.value,
            "error": None,
            "state_snapshot": {},
        }

    def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
        """Return the id of the node to run after ``current_node_id``, if any.

        After a condition node the edge whose ``source_handle`` equals the
        stored condition result is followed (``None`` when no edge matches).
        For every other node the first outgoing edge is followed.
        """
        node = self.graph.get_node_by_id(current_node_id)
        edges = self.graph.get_outgoing_edges(current_node_id)
        if not edges:
            return None

        if node and node.type == NodeType.CONDITION and self._condition_result:
            target_handle = self._condition_result
            # The result is single-use: clear it before following the edge.
            self._condition_result = None
            for edge in edges:
                if edge.source_handle == target_handle:
                    return edge.target_id
            return None

        return edges[0].target_id

    def _execute_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Dispatch ``node`` to the handler for its type; unknown types do nothing."""
        logger.info("Executing node %s (%s)", node.id, node.type.value)

        node_handlers = {
            NodeType.START: self._execute_start_node,
            NodeType.NOTE: self._execute_note_node,
            NodeType.AGENT: self._execute_agent_node,
            NodeType.STATE: self._execute_state_node,
            NodeType.CONDITION: self._execute_condition_node,
            NodeType.END: self._execute_end_node,
        }

        handler = node_handlers.get(node.type)
        if handler:
            yield from handler(node)

    def _execute_start_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Start nodes do no work and yield nothing."""
        yield from ()

    def _execute_note_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Note nodes do no work and yield nothing."""
        yield from ()

    def _execute_agent_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Run an agent node and store its output in the state.

        The output is written to ``node_<id>_output`` and, when configured, to
        the node's ``output_variable``. Answer events are forwarded to the
        caller only when ``stream_to_user`` is set.

        Raises:
            ValueError: If the JSON schema is invalid, the model reports no
                structured-output support, the response is not valid JSON
                when a schema requires it, or the output fails validation.
        """
        from application.core.model_utils import (
            get_api_key_for_provider,
            get_model_capabilities,
            get_provider_from_model_id,
        )

        node_config = AgentNodeConfig(**node.config.get("config", node.config))

        if node_config.prompt_template:
            formatted_prompt = self._format_template(node_config.prompt_template)
        else:
            formatted_prompt = self.state.get("query", "")
        node_json_schema = self._normalize_node_json_schema(
            node_config.json_schema, node.title
        )
        # Node-level settings win; the parent agent's values are the fallback.
        node_model_id = node_config.model_id or self.agent.model_id
        node_llm_name = (
            node_config.llm_name
            or get_provider_from_model_id(node_model_id or "")
            or self.agent.llm_name
        )
        node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key

        if node_json_schema and node_model_id:
            model_capabilities = get_model_capabilities(node_model_id)
            if model_capabilities and not model_capabilities.get(
                "supports_structured_output", False
            ):
                raise ValueError(
                    f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
                )

        factory_kwargs = {
            "agent_type": node_config.agent_type,
            "endpoint": self.agent.endpoint,
            "llm_name": node_llm_name,
            "model_id": node_model_id,
            "api_key": node_api_key,
            "tool_ids": node_config.tools,
            "prompt": node_config.system_prompt,
            "chat_history": self.agent.chat_history,
            "decoded_token": self.agent.decoded_token,
            "json_schema": node_json_schema,
        }

        # Agentic/research agents need retriever_config for on-demand search
        if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
            factory_kwargs["retriever_config"] = {
                "source": {"active_docs": node_config.sources} if node_config.sources else {},
                "retriever_name": node_config.retriever or "classic",
                "chunks": int(node_config.chunks) if node_config.chunks else 2,
                "model_id": node_model_id,
                "llm_name": node_llm_name,
                "api_key": node_api_key,
                "decoded_token": self.agent.decoded_token,
            }

        node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)

        full_response_parts: List[str] = []
        structured_response_parts: List[str] = []
        has_structured_response = False
        first_chunk = True
        for event in node_agent.gen(formatted_prompt):
            if "answer" in event:
                chunk = str(event["answer"])
                full_response_parts.append(chunk)
                if event.get("structured"):
                    has_structured_response = True
                    structured_response_parts.append(chunk)
                if node_config.stream_to_user:
                    # Separate this node's answer from an earlier streamed one.
                    if first_chunk and self._has_streamed:
                        yield {"answer": "\n\n"}
                        first_chunk = False
                    yield event

        if node_config.stream_to_user:
            self._has_streamed = True

        full_response = "".join(full_response_parts).strip()
        output_value: Any = full_response
        if has_structured_response:
            structured_response = "".join(structured_response_parts).strip()
            response_to_parse = structured_response or full_response
            parsed_success, parsed_structured = self._parse_structured_output(
                response_to_parse
            )
            # Fall back to the raw text when the structured chunks are not JSON.
            output_value = parsed_structured if parsed_success else response_to_parse
            if node_json_schema:
                self._validate_structured_output(node_json_schema, output_value)
        elif node_json_schema:
            parsed_success, parsed_structured = self._parse_structured_output(
                full_response
            )
            if not parsed_success:
                raise ValueError(
                    "Structured output was expected but response was not valid JSON"
                )
            output_value = parsed_structured
            self._validate_structured_output(node_json_schema, output_value)

        default_output_key = f"node_{node.id}_output"
        self.state[default_output_key] = output_value

        if node_config.output_variable:
            self.state[node_config.output_variable] = output_value

    def _execute_state_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Evaluate each operation's CEL expression and store it in its target variable.

        Operations missing an expression or a target variable are skipped.
        """
        config = node.config.get("config", node.config)
        for op in config.get("operations", []):
            expression = op.get("expression", "")
            target_variable = op.get("target_variable", "")
            if expression and target_variable:
                self.state[target_variable] = evaluate_cel(expression, self.state)
        yield from ()

    def _execute_condition_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Pick the handle of the first case whose expression is truthy.

        Blank expressions are skipped, and a case that fails to evaluate is
        logged and skipped. When no case matches the result is ``"else"``.
        """
        config = ConditionNodeConfig(**node.config.get("config", node.config))
        matched_handle = None

        for case in config.cases:
            if not case.expression.strip():
                continue
            try:
                if evaluate_cel(case.expression, self.state):
                    matched_handle = case.source_handle
                    break
            except CelEvaluationError as exc:
                # Best effort: a broken case must not fail the node, but it
                # should be visible in the logs instead of vanishing.
                logger.warning(
                    "Condition case %r on node %s failed to evaluate, skipping: %s",
                    case.name,
                    node.id,
                    exc,
                )
                continue

        self._condition_result = matched_handle or "else"
        yield from ()

    def _execute_end_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, Any], None, None]:
        """Yield the rendered ``output_template`` as an answer, if one is configured."""
        config = node.config.get("config", node.config)
        output_template = str(config.get("output_template", ""))
        if output_template:
            formatted_output = self._format_template(output_template)
            yield {"answer": formatted_output}

    def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
        """Parse ``raw_response`` as JSON.

        Returns:
            ``(True, value)`` on success; ``(False, None)`` when the text is
            empty after stripping or is not valid JSON.
        """
        normalized_response = raw_response.strip()
        if not normalized_response:
            return False, None

        try:
            return True, json.loads(normalized_response)
        except json.JSONDecodeError:
            logger.warning(
                "Workflow agent returned structured output that was not valid JSON"
            )
            return False, None

    def _normalize_node_json_schema(
        self, schema: Optional[Dict[str, Any]], node_title: str
    ) -> Optional[Dict[str, Any]]:
        """Normalize a node's JSON schema; ``None`` is passed through.

        Raises:
            ValueError: If normalization rejects the schema; the message
                names the node.
        """
        if schema is None:
            return None
        try:
            return normalize_json_schema_payload(schema)
        except JsonSchemaValidationError as exc:
            raise ValueError(
                f'Invalid JSON schema for node "{node_title}": {exc}'
            ) from exc

    def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
        """Validate ``output_value`` against ``schema``.

        Validation is skipped with a warning when the optional ``jsonschema``
        package is not installed.

        Raises:
            ValueError: If the schema is invalid or the value does not match it.
        """
        if jsonschema is None:
            logger.warning(
                "jsonschema package is not available, skipping structured output validation"
            )
            return

        try:
            normalized_schema = normalize_json_schema_payload(schema)
        except JsonSchemaValidationError as exc:
            raise ValueError(f"Invalid JSON schema: {exc}") from exc

        try:
            jsonschema.validate(instance=output_value, schema=normalized_schema)
        except jsonschema.exceptions.ValidationError as exc:
            raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
        except jsonschema.exceptions.SchemaError as exc:
            raise ValueError(f"Invalid JSON schema: {exc.message}") from exc

    def _format_template(self, template: str) -> str:
        """Render ``template`` with the current context.

        If rendering fails, a warning is logged and the raw template is returned.
        """
        context = self._build_template_context()
        try:
            return self._template_engine.render(template, context)
        except TemplateRenderError as e:
            logger.warning(
                "Workflow template rendering failed, using raw template: %s", str(e)
            )
            return template

    def _build_template_context(self) -> Dict[str, Any]:
        """Build the template context from the namespaces and the workflow state.

        The whole state is exposed under ``agent``. Each state key is also
        exposed at top level unless that name is already taken; keys that
        collide with a reserved namespace are exposed as ``agent_<key>``.
        """
        docs, docs_together = self._get_source_template_data()
        passthrough_data = (
            self.state.get("passthrough")
            if isinstance(self.state.get("passthrough"), dict)
            else None
        )
        tools_data = (
            self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
        )

        context = self._namespace_manager.build_context(
            user_id=getattr(self.agent, "user", None),
            request_id=getattr(self.agent, "request_id", None),
            passthrough_data=passthrough_data,
            docs=docs,
            docs_together=docs_together,
            tools_data=tools_data,
        )

        agent_context: Dict[str, Any] = {}
        for key, value in self.state.items():
            if not isinstance(key, str):
                continue
            normalized_key = key.strip()
            if not normalized_key:
                continue
            agent_context[normalized_key] = value

        context["agent"] = agent_context

        # Keep legacy top-level variables working while namespaced variables are adopted.
        for key, value in agent_context.items():
            if key in TEMPLATE_RESERVED_NAMESPACES:
                context[f"agent_{key}"] = value
                continue
            if key not in context:
                context[key] = value

        return context

    def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
        """Return the agent's retrieved docs and their combined text.

        Returns:
            ``(None, None)`` when the agent has no non-empty ``retrieved_docs``
            list. Otherwise the list itself and the texts of its dict entries
            joined by blank lines, each prefixed by its filename, title or
            source when one is present (``None`` if no entry has a string
            ``text``).
        """
        docs = getattr(self.agent, "retrieved_docs", None)
        if not isinstance(docs, list) or len(docs) == 0:
            return None, None

        docs_together_parts: List[str] = []
        for doc in docs:
            if not isinstance(doc, dict):
                continue
            text = doc.get("text")
            if not isinstance(text, str):
                continue

            filename = doc.get("filename") or doc.get("title") or doc.get("source")
            if isinstance(filename, str) and filename.strip():
                docs_together_parts.append(f"{filename}\n{text}")
            else:
                docs_together_parts.append(text)

        docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
        return docs, docs_together

    def get_execution_summary(self) -> List[NodeExecutionLog]:
        """Return the execution log as a list of ``NodeExecutionLog`` models."""
        return [
            NodeExecutionLog(
                node_id=log["node_id"],
                node_type=log["node_type"],
                status=ExecutionStatus(log["status"]),
                started_at=log["started_at"],
                completed_at=log.get("completed_at"),
                error=log.get("error"),
                state_snapshot=log.get("state_snapshot", {}),
            )
            for log in self.execution_log
        ]
|
||||
@@ -3,7 +3,6 @@ from flask import Blueprint
|
||||
from application.api import api
|
||||
from application.api.answer.routes.answer import AnswerResource
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.api.answer.routes.search import SearchResource
|
||||
from application.api.answer.routes.stream import StreamResource
|
||||
|
||||
|
||||
@@ -15,7 +14,6 @@ api.add_namespace(answer_ns)
|
||||
def init_answer_routes():
|
||||
api.add_resource(StreamResource, "/stream")
|
||||
api.add_resource(AnswerResource, "/api/answer")
|
||||
api.add_resource(SearchResource, "/api/search")
|
||||
|
||||
|
||||
init_answer_routes()
|
||||
|
||||
@@ -40,9 +40,9 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
@@ -54,14 +54,6 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -74,26 +66,22 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
processor.initialize()
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
@@ -130,5 +118,5 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
extra={"error": str(e), "traceback": traceback.format_exc()},
|
||||
)
|
||||
return make_response({"error": "An error occurred processing your request"}, 500)
|
||||
return make_response({"error": str(e)}, 500)
|
||||
return make_response(result, 200)
|
||||
|
||||
@@ -3,21 +3,15 @@ import json
|
||||
import logging
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask import Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.utils import check_required_fields, get_gpt_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,9 +25,8 @@ class BaseAnswerResource:
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.db = db
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
@@ -47,125 +40,11 @@ class BaseAnswerResource:
|
||||
return missing_fields
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _prepare_tool_calls_for_logging(
|
||||
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
prepared = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
prepared.append({"result": str(tool_call)[:max_chars]})
|
||||
continue
|
||||
|
||||
item = dict(tool_call)
|
||||
for key in ("result", "result_full"):
|
||||
value = item.get(key)
|
||||
if isinstance(value, str) and len(value) > max_chars:
|
||||
item[key] = value[:max_chars]
|
||||
prepared.append(item)
|
||||
return prepared
|
||||
|
||||
def check_usage(self, agent_config: Dict) -> Optional[Response]:
|
||||
"""Check if there is a usage limit and if it is exceeded
|
||||
|
||||
Args:
|
||||
agent_config: The config dict of agent instance
|
||||
|
||||
Returns:
|
||||
None or Response if either of limits exceeded.
|
||||
|
||||
"""
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key."}), 401
|
||||
)
|
||||
limited_token_mode_raw = agent.get("limited_token_mode", False)
|
||||
limited_request_mode_raw = agent.get("limited_request_mode", False)
|
||||
|
||||
limited_token_mode = (
|
||||
limited_token_mode_raw
|
||||
if isinstance(limited_token_mode_raw, bool)
|
||||
else limited_token_mode_raw == "True"
|
||||
)
|
||||
limited_request_mode = (
|
||||
limited_request_mode_raw
|
||||
if isinstance(limited_request_mode_raw, bool)
|
||||
else limited_request_mode_raw == "True"
|
||||
)
|
||||
|
||||
token_limit = int(
|
||||
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
||||
)
|
||||
request_limit = int(
|
||||
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||
)
|
||||
|
||||
token_usage_collection = self.db["token_usage"]
|
||||
|
||||
end_date = datetime.datetime.now()
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
|
||||
match_query = {
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
if limited_token_mode:
|
||||
token_pipeline = [
|
||||
{"$match": match_query},
|
||||
{
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
token_exceeded = (
|
||||
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
|
||||
)
|
||||
request_exceeded = (
|
||||
limited_request_mode
|
||||
and request_limit > 0
|
||||
and daily_request_usage >= request_limit
|
||||
)
|
||||
|
||||
if token_exceeded or request_exceeded:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Exceeding usage limit, please try again later.",
|
||||
}
|
||||
),
|
||||
429,
|
||||
)
|
||||
return None
|
||||
|
||||
def complete_stream(
|
||||
self,
|
||||
question: str,
|
||||
agent: Any,
|
||||
retriever: Any,
|
||||
conversation_id: Optional[str],
|
||||
user_api_key: Optional[str],
|
||||
decoded_token: Dict[str, Any],
|
||||
@@ -176,7 +55,6 @@ class BaseAnswerResource:
|
||||
agent_id: Optional[str] = None,
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
@@ -195,8 +73,6 @@ class BaseAnswerResource:
|
||||
agent_id: ID of agent used
|
||||
is_shared_usage: Flag for shared agent usage
|
||||
shared_token: Token for shared agent
|
||||
model_id: Model ID used for the request
|
||||
retrieved_docs: Pre-fetched documents for sources (optional)
|
||||
|
||||
Yields:
|
||||
Server-sent event strings
|
||||
@@ -206,12 +82,9 @@ class BaseAnswerResource:
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
|
||||
for line in agent.gen(query=question):
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
for line in agent.gen(query=question, retriever=retriever):
|
||||
if "answer" in line:
|
||||
response_full += str(line["answer"])
|
||||
if line.get("structured"):
|
||||
is_structured = True
|
||||
@@ -237,22 +110,14 @@ class BaseAnswerResource:
|
||||
yield f"data: {data}\n\n"
|
||||
elif "tool_calls" in line:
|
||||
tool_calls = line["tool_calls"]
|
||||
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "thought" in line:
|
||||
thought += line["thought"]
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
if line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
@@ -262,23 +127,15 @@ class BaseAnswerResource:
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=system_api_key,
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
if should_save_conversation:
|
||||
@@ -290,7 +147,7 @@ class BaseAnswerResource:
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
self.gpt_model,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
@@ -298,48 +155,23 @@ class BaseAnswerResource:
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
conversation_id = None
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
|
||||
getattr(agent, "tool_calls", tool_calls) or tool_calls
|
||||
)
|
||||
|
||||
retriever_params = retriever.get_params()
|
||||
log_data = {
|
||||
"action": "stream_answer",
|
||||
"level": "info",
|
||||
"user": decoded_token.get("sub"),
|
||||
"api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"question": question,
|
||||
"response": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls_for_logging,
|
||||
"retriever_params": retriever_params,
|
||||
"attachments": attachment_ids,
|
||||
"timestamp": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
@@ -347,73 +179,18 @@ class BaseAnswerResource:
|
||||
log_data["structured_output"] = True
|
||||
if schema_info:
|
||||
log_data["schema"] = schema_info
|
||||
# Clean up text fields to be no longer than 10000 characters
|
||||
|
||||
|
||||
# clean up text fields to be no longer than 10000 characters
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
# End of stream
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
except GeneratorExit:
|
||||
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
|
||||
# Save partial response
|
||||
|
||||
if should_save_conversation and response_full:
|
||||
try:
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
llm = LLMCreator.create_llm(
|
||||
settings.LLM_PROVIDER,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
self.conversation_service.save_conversation(
|
||||
conversation_id,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
index=index,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
attachment_ids=attachment_ids,
|
||||
metadata=query_metadata if query_metadata else None,
|
||||
)
|
||||
compression_meta = getattr(agent, "compression_metadata", None)
|
||||
compression_saved = getattr(agent, "compression_saved", False)
|
||||
if conversation_id and compression_meta and not compression_saved:
|
||||
try:
|
||||
self.conversation_service.update_compression_metadata(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
self.conversation_service.append_compression_message(
|
||||
conversation_id, compression_meta
|
||||
)
|
||||
agent.compression_saved = True
|
||||
logger.info(
|
||||
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to persist compression metadata (partial stream): {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error saving partial response: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {str(e)}", exc_info=True)
|
||||
data = json.dumps(
|
||||
@@ -457,7 +234,7 @@ class BaseAnswerResource:
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return None, None, None, None, event["error"], None
|
||||
return None, None, None, None, event["error"]
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
@@ -465,7 +242,8 @@ class BaseAnswerResource:
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
return None, None, None, None, "Stream ended unexpectedly"
|
||||
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
@@ -477,6 +255,7 @@ class BaseAnswerResource:
|
||||
|
||||
if is_structured:
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class SearchResource(Resource):
|
||||
"""Fast search endpoint for retrieving relevant documents"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
mongo = MongoDB.get_client()
|
||||
self.db = mongo[settings.MONGO_DB_NAME]
|
||||
self.agents_collection = self.db["agents"]
|
||||
|
||||
search_model = answer_ns.model(
|
||||
"SearchModel",
|
||||
{
|
||||
"question": fields.String(
|
||||
required=True, description="Search query"
|
||||
),
|
||||
"api_key": fields.String(
|
||||
required=True, description="API key for authentication"
|
||||
),
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=5, description="Number of results to return"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
|
||||
"""Get source IDs connected to the API key/agent.
|
||||
|
||||
"""
|
||||
agent_data = self.agents_collection.find_one({"key": api_key})
|
||||
if not agent_data:
|
||||
return []
|
||||
|
||||
source_ids = []
|
||||
|
||||
# Handle multiple sources (only if non-empty)
|
||||
sources = agent_data.get("sources", [])
|
||||
if sources and isinstance(sources, list) and len(sources) > 0:
|
||||
for source_ref in sources:
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
if source_ref == "default":
|
||||
continue
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
|
||||
# Handle single source (legacy) - check if sources was empty or didn't yield results
|
||||
if not source_ids:
|
||||
source = agent_data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
elif source and source != "default":
|
||||
source_ids.append(source)
|
||||
|
||||
return source_ids
|
||||
|
||||
def _search_vectorstores(
|
||||
self, query: str, source_ids: List[str], chunks: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search across vectorstores and return results"""
|
||||
if not source_ids:
|
||||
return []
|
||||
|
||||
results = []
|
||||
chunks_per_source = max(1, chunks // len(source_ids))
|
||||
seen_texts = set()
|
||||
|
||||
for source_id in source_ids:
|
||||
if not source_id or not source_id.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs = docsearch.search(query, k=chunks_per_source * 2)
|
||||
|
||||
for doc in docs:
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
|
||||
# Skip duplicates
|
||||
text_hash = hash(page_content[:200])
|
||||
if text_hash in seen_texts:
|
||||
continue
|
||||
seen_texts.add(text_hash)
|
||||
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", "")
|
||||
)
|
||||
if not isinstance(title, str):
|
||||
title = str(title) if title else ""
|
||||
|
||||
# Clean up title
|
||||
if title:
|
||||
title = title.split("/")[-1]
|
||||
else:
|
||||
# Use filename or first part of content as title
|
||||
title = metadata.get("filename", page_content[:50] + "...")
|
||||
|
||||
source = metadata.get("source", source_id)
|
||||
|
||||
results.append({
|
||||
"text": page_content,
|
||||
"title": title,
|
||||
"source": source,
|
||||
})
|
||||
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error searching vectorstore {source_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
return results[:chunks]
|
||||
|
||||
@answer_ns.expect(search_model)
|
||||
@answer_ns.doc(description="Search for relevant documents based on query")
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
|
||||
question = data.get("question")
|
||||
api_key = data.get("api_key")
|
||||
chunks = data.get("chunks", 5)
|
||||
|
||||
if not question:
|
||||
return make_response({"error": "question is required"}, 400)
|
||||
|
||||
if not api_key:
|
||||
return make_response({"error": "api_key is required"}, 400)
|
||||
|
||||
# Validate API key
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response({"error": "Invalid API key"}, 401)
|
||||
|
||||
try:
|
||||
# Get sources connected to this API key
|
||||
source_ids = self._get_sources_from_api_key(api_key)
|
||||
|
||||
if not source_ids:
|
||||
return make_response([], 200)
|
||||
|
||||
# Perform search
|
||||
results = self._search_vectorstores(question, source_ids, chunks)
|
||||
|
||||
return make_response(results, 200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/search - error: {str(e)}",
|
||||
extra={"error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response({"error": "Search failed"}, 500)
|
||||
@@ -40,9 +40,9 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
"chunks": fields.Integer(
|
||||
required=False, default=2, description="Number of chunks"
|
||||
),
|
||||
"token_limit": fields.Integer(required=False, description="Token limit"),
|
||||
"retriever": fields.String(required=False, description="Retriever type"),
|
||||
"api_key": fields.String(required=False, description="API key"),
|
||||
"agent_id": fields.String(required=False, description="Agent ID"),
|
||||
"active_docs": fields.String(
|
||||
required=False, description="Active documents"
|
||||
),
|
||||
@@ -57,17 +57,9 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
default=True,
|
||||
description="Whether to save the conversation",
|
||||
),
|
||||
"model_id": fields.String(
|
||||
required=False,
|
||||
description="Model ID to use for this request",
|
||||
),
|
||||
"attachments": fields.List(
|
||||
fields.String, required=False, description="List of attachment IDs"
|
||||
),
|
||||
"passthrough": fields.Raw(
|
||||
required=False,
|
||||
description="Dynamic parameters to inject into prompt template",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -80,20 +72,15 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
agent = processor.build_agent(data["question"])
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
processor.initialize()
|
||||
agent = processor.create_agent()
|
||||
retriever = processor.create_retriever()
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
retriever=retriever,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
@@ -101,10 +88,9 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
index=data.get("index"),
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
attachment_ids=data.get("attachments", []),
|
||||
agent_id=processor.agent_id,
|
||||
agent_id=data.get("agent_id"),
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
"""
|
||||
Compression module for managing conversation context compression.
|
||||
|
||||
"""
|
||||
|
||||
from application.api.answer.services.compression.orchestrator import (
|
||||
CompressionOrchestrator,
|
||||
)
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionResult,
|
||||
CompressionMetadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CompressionOrchestrator",
|
||||
"CompressionService",
|
||||
"CompressionResult",
|
||||
"CompressionMetadata",
|
||||
]
|
||||
@@ -1,234 +0,0 @@
|
||||
"""Message reconstruction utilities for compression."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
"""Builds message arrays from compressed context."""
|
||||
|
||||
@staticmethod
|
||||
def build_from_compressed_context(
|
||||
system_prompt: str,
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_tool_calls: bool = False,
|
||||
context_type: str = "pre_request",
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Build messages from compressed context.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Compressed summary (if any)
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
context_type: Type of context ('pre_request' or 'mid_execution')
|
||||
|
||||
Returns:
|
||||
List of message dicts ready for LLM
|
||||
"""
|
||||
# Append compression summary to system prompt if present
|
||||
if compressed_summary:
|
||||
system_prompt = MessageBuilder._append_compression_context(
|
||||
system_prompt, compressed_summary, context_type
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add recent history
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
messages.append({"role": "user", "content": query["prompt"]})
|
||||
messages.append({"role": "assistant", "content": query["response"]})
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _append_compression_context(
|
||||
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
|
||||
) -> str:
|
||||
"""
|
||||
Append compression context to system prompt.
|
||||
|
||||
Args:
|
||||
system_prompt: Original system prompt
|
||||
compressed_summary: Summary to append
|
||||
context_type: Type of compression context
|
||||
|
||||
Returns:
|
||||
Updated system prompt
|
||||
"""
|
||||
# Remove existing compression context if present
|
||||
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
|
||||
parts = system_prompt.split("\n\n---\n\n")
|
||||
system_prompt = parts[0]
|
||||
|
||||
# Build appropriate context message based on type
|
||||
if context_type == "mid_execution":
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"Context window limit reached during execution. "
|
||||
"Previous conversation has been compressed to fit within limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
else: # pre_request
|
||||
context_message = (
|
||||
"\n\n---\n\n"
|
||||
"This session is being continued from a previous conversation that "
|
||||
"has been compressed to fit within context limits. "
|
||||
"The conversation is summarized below:\n\n"
|
||||
f"{compressed_summary}"
|
||||
)
|
||||
|
||||
return system_prompt + context_message
|
||||
|
||||
@staticmethod
|
||||
def rebuild_messages_after_compression(
|
||||
messages: List[Dict],
|
||||
compressed_summary: Optional[str],
|
||||
recent_queries: List[Dict],
|
||||
include_current_execution: bool = False,
|
||||
include_tool_calls: bool = False,
|
||||
) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Rebuild the message list after compression so tool execution can continue.
|
||||
|
||||
Args:
|
||||
messages: Original message list
|
||||
compressed_summary: Compressed summary
|
||||
recent_queries: Recent uncompressed queries
|
||||
include_current_execution: Whether to preserve current execution messages
|
||||
include_tool_calls: Whether to include tool calls from history
|
||||
|
||||
Returns:
|
||||
Rebuilt message list or None if failed
|
||||
"""
|
||||
# Find the system message
|
||||
system_message = next(
|
||||
(msg for msg in messages if msg.get("role") == "system"), None
|
||||
)
|
||||
if not system_message:
|
||||
logger.warning("No system message found in messages list")
|
||||
return None
|
||||
|
||||
# Update system message with compressed summary
|
||||
if compressed_summary:
|
||||
content = system_message.get("content", "")
|
||||
system_message["content"] = MessageBuilder._append_compression_context(
|
||||
content, compressed_summary, "mid_execution"
|
||||
)
|
||||
logger.info(
|
||||
"Appended compression summary to system prompt (truncated): %s",
|
||||
(
|
||||
compressed_summary[:500] + "..."
|
||||
if len(compressed_summary) > 500
|
||||
else compressed_summary
|
||||
),
|
||||
)
|
||||
|
||||
rebuilt_messages = [system_message]
|
||||
|
||||
# Add recent history from compressed context
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": query["response"]}
|
||||
)
|
||||
|
||||
# Add tool calls from history if present
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
rebuilt_messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
rebuilt_messages.append({
|
||||
"role": "user",
|
||||
"content": "Please continue with the remaining tasks based on the context above."
|
||||
})
|
||||
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
|
||||
|
||||
if include_current_execution:
|
||||
# Preserve any messages that were added during the current execution cycle
|
||||
recent_msg_count = 1 # system message
|
||||
for query in recent_queries:
|
||||
if "prompt" in query and "response" in query:
|
||||
recent_msg_count += 2
|
||||
if "tool_calls" in query:
|
||||
recent_msg_count += len(query["tool_calls"]) * 2
|
||||
|
||||
if len(messages) > recent_msg_count:
|
||||
current_execution_messages = messages[recent_msg_count:]
|
||||
rebuilt_messages.extend(current_execution_messages)
|
||||
logger.info(
|
||||
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Messages rebuilt: {len(messages)} → {len(rebuilt_messages)} messages. "
|
||||
f"Ready to continue tool execution."
|
||||
)
|
||||
return rebuilt_messages
|
||||
@@ -1,233 +0,0 @@
|
||||
"""High-level compression orchestration."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.api.answer.services.compression.service import CompressionService
|
||||
from application.api.answer.services.compression.threshold_checker import (
|
||||
CompressionThresholdChecker,
|
||||
)
|
||||
from application.api.answer.services.compression.types import CompressionResult
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionOrchestrator:
    """
    Facade for compression operations.

    Coordinates the conversation service, the threshold checker and the
    compression service, and reports every outcome as a
    ``CompressionResult`` rather than letting exceptions escape.
    """

    def __init__(
        self,
        conversation_service: ConversationService,
        threshold_checker: Optional[CompressionThresholdChecker] = None,
    ):
        """
        Initialize orchestrator.

        Args:
            conversation_service: Service for DB operations
            threshold_checker: Custom threshold checker (optional); a default
                ``CompressionThresholdChecker`` is created when omitted
        """
        self.conversation_service = conversation_service
        self.threshold_checker = threshold_checker or CompressionThresholdChecker()

    def compress_if_needed(
        self,
        conversation_id: str,
        user_id: str,
        model_id: str,
        decoded_token: Dict[str, Any],
        current_query_tokens: int = 500,
    ) -> CompressionResult:
        """
        Check if compression is needed and perform it if so.

        This is the main entry point for compression operations.

        Args:
            conversation_id: Conversation ID
            user_id: User ID
            model_id: Model being used for conversation
            decoded_token: User's decoded JWT token
            current_query_tokens: Estimated tokens for current query

        Returns:
            CompressionResult with summary and recent queries. Any exception
            raised along the way is logged and converted into a failure
            result.
        """
        try:
            conversation = self.conversation_service.get_conversation(
                conversation_id, user_id
            )

            if not conversation:
                logger.warning(
                    f"Conversation {conversation_id} not found for user {user_id}"
                )
                return CompressionResult.failure("Conversation not found")

            if not self.threshold_checker.should_compress(
                conversation, model_id, current_query_tokens
            ):
                # Below the threshold: hand back the full, untouched history.
                queries = conversation.get("queries", [])
                return CompressionResult.success_no_compression(queries)

            return self._perform_compression(
                conversation_id,
                conversation,
                model_id,
                decoded_token,
                user_id=user_id,
            )

        except Exception as e:
            logger.error(
                f"Error in compress_if_needed: {str(e)}", exc_info=True
            )
            return CompressionResult.failure(str(e))

    def _perform_compression(
        self,
        conversation_id: str,
        conversation: Dict[str, Any],
        model_id: str,
        decoded_token: Dict[str, Any],
        user_id: Optional[str] = None,
    ) -> CompressionResult:
        """
        Perform the actual compression operation.

        Args:
            conversation_id: Conversation ID
            conversation: Conversation document
            model_id: Model ID for conversation
            decoded_token: User token
            user_id: User the conversation was loaded for. Used to reload the
                conversation after saving; falls back to the token's ``sub``
                claim when omitted.

        Returns:
            CompressionResult; exceptions are logged and converted into a
            failure result.
        """
        try:
            # Bail out before any provider lookup or LLM construction when
            # there is nothing to summarise.
            queries_count = len(conversation.get("queries") or [])
            compress_up_to = queries_count - 1

            if compress_up_to < 0:
                logger.warning("No queries to compress")
                return CompressionResult.success_no_compression([])

            # An explicit override wins over the conversation's own model.
            compression_model = settings.COMPRESSION_MODEL_OVERRIDE or model_id

            provider = get_provider_from_model_id(compression_model)
            api_key = get_api_key_for_provider(provider)

            compression_llm = LLMCreator.create_llm(
                provider,
                api_key=api_key,
                user_api_key=None,
                decoded_token=decoded_token,
                model_id=compression_model,
                agent_id=conversation.get("agent_id"),
            )

            # The service gets the conversation service so it can persist
            # the compression metadata itself.
            compression_service = CompressionService(
                llm=compression_llm,
                model_id=compression_model,
                conversation_service=self.conversation_service,
            )

            logger.info(
                f"Initiating compression for conversation {conversation_id}: "
                f"compressing all {queries_count} queries (0-{compress_up_to})"
            )

            metadata = compression_service.compress_and_save(
                conversation_id, conversation, compress_up_to
            )

            logger.info(
                f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
                f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
            )

            # Reload with the same identity that loaded the conversation in
            # the first place; only fall back to the token when the caller
            # did not supply one (and tolerate a missing token).
            reload_user_id = user_id or (decoded_token or {}).get("sub")
            refreshed = self.conversation_service.get_conversation(
                conversation_id, user_id=reload_user_id
            )

            if not refreshed:
                # The summary is already saved; do not report a failure just
                # because the re-read came back empty. Everything up to
                # compress_up_to is covered by the fresh summary.
                logger.warning(
                    f"Could not reload conversation {conversation_id} after compression - "
                    f"using freshly generated summary"
                )
                remaining = (conversation.get("queries") or [])[compress_up_to + 1 :]
                return CompressionResult.success_with_compression(
                    metadata.compressed_summary, remaining, metadata
                )

            compressed_summary, recent_queries = (
                compression_service.get_compressed_context(refreshed)
            )

            return CompressionResult.success_with_compression(
                compressed_summary, recent_queries, metadata
            )

        except Exception as e:
            logger.error(f"Error performing compression: {str(e)}", exc_info=True)
            return CompressionResult.failure(str(e))

    def compress_mid_execution(
        self,
        conversation_id: str,
        user_id: str,
        model_id: str,
        decoded_token: Dict[str, Any],
        current_conversation: Optional[Dict[str, Any]] = None,
    ) -> CompressionResult:
        """
        Perform compression during tool execution.

        Unlike ``compress_if_needed`` this does not consult the threshold
        checker; compression is always attempted.

        Args:
            conversation_id: Conversation ID
            user_id: User ID
            model_id: Model ID
            decoded_token: User token
            current_conversation: Pre-loaded conversation (optional); when
                missing or empty the conversation is loaded from the
                conversation service

        Returns:
            CompressionResult; exceptions are logged and converted into a
            failure result.
        """
        try:
            if current_conversation:
                conversation = current_conversation
            else:
                conversation = self.conversation_service.get_conversation(
                    conversation_id, user_id
                )

            if not conversation:
                logger.warning(
                    f"Could not load conversation {conversation_id} for mid-execution compression"
                )
                return CompressionResult.failure("Conversation not found")

            return self._perform_compression(
                conversation_id,
                conversation,
                model_id,
                decoded_token,
                user_id=user_id,
            )

        except Exception as e:
            logger.error(
                f"Error in mid-execution compression: {str(e)}", exc_info=True
            )
            return CompressionResult.failure(str(e))
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Compression prompt building logic."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionPromptBuilder:
    """Builds prompts for LLM compression calls."""

    def __init__(self, version: str = "v1.0"):
        """
        Initialize prompt builder.

        Args:
            version: Prompt template version to use

        Raises:
            FileNotFoundError: If the template for ``version`` doesn't exist
        """
        self.version = version
        self.system_prompt = self._load_prompt(version)

    def _load_prompt(self, version: str) -> str:
        """
        Load prompt template from file.

        The template is looked up at ``prompts/compression/<version>.txt``
        under the fifth ancestor directory of this file.

        Args:
            version: Version string (e.g., 'v1.0')

        Returns:
            Prompt template content

        Raises:
            FileNotFoundError: If prompt template file doesn't exist
        """
        current_dir = Path(__file__).resolve().parents[4]
        prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"

        try:
            # Decode explicitly so the result does not depend on the host
            # locale (templates are assumed to be stored as UTF-8).
            with open(prompt_path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError as err:
            logger.error(f"Compression prompt template not found: {prompt_path}")
            raise FileNotFoundError(
                f"Compression prompt template '{version}' not found at {prompt_path}. "
                f"Please ensure the template file exists."
            ) from err

    def build_prompt(
        self,
        queries: List[Dict[str, Any]],
        existing_compressions: Optional[List[Dict[str, Any]]] = None,
    ) -> List[Dict[str, str]]:
        """
        Build messages for compression LLM call.

        Args:
            queries: List of query objects to compress
            existing_compressions: List of previous compression points; each
                one's ``query_index`` and ``compressed_summary`` are quoted
                ahead of the conversation text

        Returns:
            A two-element list: the system message (the loaded template) and
            a user message carrying the conversation to summarize
        """
        conversation_text = self._format_conversation(queries)

        # Collect the pieces and join once instead of growing a string.
        context_parts: List[str] = []
        if existing_compressions:
            context_parts.append(
                "\n\nIMPORTANT: This conversation has been compressed before. "
                "Previous compression summaries:\n\n"
            )
            for number, comp in enumerate(existing_compressions, start=1):
                context_parts.append(
                    f"--- Compression {number} (up to message {comp.get('query_index', 'unknown')}) ---\n"
                    f"{comp.get('compressed_summary', '')}\n\n"
                )
            context_parts.append(
                "Your task is to create a NEW summary that incorporates the context from "
                "previous compressions AND the new messages below. The final summary should "
                "be comprehensive and include all important information from both previous "
                "compressions and new messages.\n\n"
            )
        existing_compression_context = "".join(context_parts)

        user_prompt = (
            f"{existing_compression_context}"
            f"Here is the conversation to summarize:\n\n"
            f"{conversation_text}"
        )

        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
        """
        Format conversation queries into readable text for compression.

        Each query becomes a numbered section holding the user prompt, any
        tool calls (with full results), the agent thought, the assistant
        response and a count of sources.

        Args:
            queries: List of query objects

        Returns:
            Formatted conversation text
        """
        conversation_lines = []

        for i, query in enumerate(queries):
            conversation_lines.append(f"--- Message {i + 1} ---")
            conversation_lines.append(f"User: {query.get('prompt', '')}")

            tool_calls = query.get("tool_calls", [])
            if tool_calls:
                conversation_lines.append("\nTool Calls:")
                for tc in tool_calls:
                    tool_name = tc.get("tool_name", "unknown")
                    action_name = tc.get("action_name", "unknown")
                    arguments = tc.get("arguments", {})
                    result = tc.get("result", "")
                    if result is None:
                        # Render a stored null as empty text, not "None".
                        result = ""
                    status = tc.get("status", "unknown")

                    # Include full tool result for complete compression context
                    conversation_lines.append(
                        f" - {tool_name}.{action_name}({arguments}) "
                        f"[{status}] → {result}"
                    )

            thought = query.get("thought", "")
            if thought:
                conversation_lines.append(f"\nAgent Thought: {thought}")

            conversation_lines.append(f"\nAssistant: {query.get('response', '')}")

            sources = query.get("sources", [])
            if sources:
                conversation_lines.append(f"\nSources Used: {len(sources)} documents")

            conversation_lines.append("")  # Empty line between messages

        return "\n".join(conversation_lines)
|
||||
@@ -1,306 +0,0 @@
|
||||
"""Core compression service with simplified responsibilities."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from application.api.answer.services.compression.prompt_builder import (
|
||||
CompressionPromptBuilder,
|
||||
)
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.compression.types import (
|
||||
CompressionMetadata,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionService:
    """
    Service for compressing conversation history.

    Generates summaries through the supplied LLM and, when a conversation
    service is provided, handles DB updates.
    """

    def __init__(
        self,
        llm,
        model_id: str,
        conversation_service=None,
        prompt_builder: Optional[CompressionPromptBuilder] = None,
    ):
        """
        Initialize compression service.

        Args:
            llm: LLM instance to use for compression
            model_id: Model ID for compression
            conversation_service: Service for DB operations (optional, for DB updates)
            prompt_builder: Custom prompt builder (optional); defaults to a
                builder for ``settings.COMPRESSION_PROMPT_VERSION``
        """
        self.llm = llm
        self.model_id = model_id
        self.conversation_service = conversation_service
        self.prompt_builder = prompt_builder or CompressionPromptBuilder(
            version=settings.COMPRESSION_PROMPT_VERSION
        )

    def compress_conversation(
        self,
        conversation: Dict[str, Any],
        compress_up_to_index: int,
    ) -> CompressionMetadata:
        """
        Compress conversation history up to specified index.

        Args:
            conversation: Full conversation document
            compress_up_to_index: Last query index to include in compression

        Returns:
            CompressionMetadata with compression details

        Raises:
            ValueError: If compress_up_to_index is invalid
            Exception: Any other error is logged and re-raised unchanged
        """
        try:
            # A stored null is treated like a missing list.
            queries = conversation.get("queries") or []

            if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
                raise ValueError(
                    f"Invalid compress_up_to_index: {compress_up_to_index} "
                    f"(conversation has {len(queries)} queries)"
                )

            queries_to_compress = queries[: compress_up_to_index + 1]

            # Earlier summaries are passed to the prompt builder so the new
            # summary can build on them.
            existing_compressions = (
                conversation.get("compression_metadata") or {}
            ).get("compression_points", [])

            if existing_compressions:
                logger.info(
                    f"Found {len(existing_compressions)} previous compression(s) - "
                    f"will incorporate into new summary"
                )

            original_tokens = TokenCounter.count_query_tokens(queries_to_compress)

            self._log_tool_call_stats(queries_to_compress)

            messages = self.prompt_builder.build_prompt(
                queries_to_compress, existing_compressions
            )

            logger.info(
                f"Starting compression: {len(queries_to_compress)} queries "
                f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
                f"using model {self.model_id}"
            )

            response = self.llm.gen(
                model=self.model_id, messages=messages, max_tokens=4000
            )

            compressed_summary = self._extract_summary(response)

            compressed_tokens = TokenCounter.count_message_tokens(
                [{"content": compressed_summary}]
            )

            # An empty summary counts zero tokens; report a ratio of 0
            # instead of dividing by zero.
            compression_ratio = (
                original_tokens / compressed_tokens if compressed_tokens > 0 else 0
            )

            logger.info(
                f"Compression complete: {original_tokens} → {compressed_tokens} tokens "
                f"({compression_ratio:.1f}x compression)"
            )

            return CompressionMetadata(
                timestamp=datetime.now(timezone.utc),
                query_index=compress_up_to_index,
                compressed_summary=compressed_summary,
                original_token_count=original_tokens,
                compressed_token_count=compressed_tokens,
                compression_ratio=compression_ratio,
                model_used=self.model_id,
                compression_prompt_version=self.prompt_builder.version,
            )

        except Exception as e:
            logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
            raise

    def compress_and_save(
        self,
        conversation_id: str,
        conversation: Dict[str, Any],
        compress_up_to_index: int,
    ) -> CompressionMetadata:
        """
        Compress conversation and save to database.

        Args:
            conversation_id: Conversation ID
            conversation: Full conversation document
            compress_up_to_index: Last query index to include

        Returns:
            CompressionMetadata

        Raises:
            ValueError: If conversation_service not provided or invalid index
        """
        if not self.conversation_service:
            raise ValueError(
                "conversation_service required for compress_and_save operation"
            )

        metadata = self.compress_conversation(conversation, compress_up_to_index)

        self.conversation_service.update_compression_metadata(
            conversation_id, metadata.to_dict()
        )

        logger.info(f"Compression metadata saved to database for {conversation_id}")

        return metadata

    def get_compressed_context(
        self, conversation: Dict[str, Any]
    ) -> tuple[Optional[str], List[Dict[str, Any]]]:
        """
        Get compressed summary + recent uncompressed messages.

        Falls back to ``(None, <all queries>)`` when the conversation is not
        marked compressed, has no compression points, or when reading the
        metadata raises.

        Args:
            conversation: Full conversation document

        Returns:
            (compressed_summary, recent_messages)
        """
        try:
            compression_metadata = conversation.get("compression_metadata") or {}

            if not compression_metadata.get("is_compressed"):
                logger.debug("No compression metadata found - using full history")
                return self._full_history(conversation)

            compression_points = compression_metadata.get("compression_points", [])

            if not compression_points:
                logger.debug("No compression points found - using full history")
                return self._full_history(conversation)

            # Only the newest compression point is used.
            latest_compression = compression_points[-1]
            compressed_summary = latest_compression.get("compressed_summary")
            last_compressed_index = latest_compression.get("query_index")
            compressed_tokens = latest_compression.get("compressed_token_count", 0)
            original_tokens = latest_compression.get("original_token_count", 0)

            # Everything after the compression point is still verbatim.
            queries = conversation.get("queries", [])
            total_queries = len(queries)
            recent_queries = queries[last_compressed_index + 1 :]

            logger.info(
                f"Using compressed context: summary ({compressed_tokens} tokens, "
                f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
                f"(messages {last_compressed_index + 1}-{total_queries - 1})"
            )

            return compressed_summary, recent_queries

        except Exception as e:
            logger.error(
                f"Error getting compressed context: {str(e)}", exc_info=True
            )
            return self._full_history(conversation, log_missing=False)

    @staticmethod
    def _full_history(
        conversation: Dict[str, Any], log_missing: bool = True
    ) -> tuple[Optional[str], List[Dict[str, Any]]]:
        """Return ``(None, queries)``, substituting ``[]`` for a null query list."""
        queries = conversation.get("queries", [])
        if queries is None:
            if log_missing:
                logger.error("Conversation queries is None - returning empty list")
            return None, []
        return None, queries

    def _extract_summary(self, llm_response: str) -> str:
        """
        Extract clean summary from LLM response.

        Prefers the text inside ``<summary>`` tags; otherwise strips any
        ``<analysis>`` sections and uses what is left.

        Args:
            llm_response: Raw LLM response

        Returns:
            Cleaned summary text, or the raw response if extraction raises
        """
        try:
            summary_match = re.search(
                r"<summary>(.*?)</summary>", llm_response, re.DOTALL
            )

            if summary_match:
                return summary_match.group(1).strip()

            # If no summary tags, remove analysis tags and use the rest
            return re.sub(
                r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
            ).strip()

        except Exception as e:
            logger.warning(f"Error extracting summary: {str(e)}, using full response")
            return llm_response

    def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
        """Log statistics about tool calls in queries.

        Nothing is logged when the queries contain no tool calls.
        """
        total_tool_calls = 0
        total_tool_result_chars = 0
        tool_call_breakdown = {}

        for q in queries:
            # `or []` keeps a stored `tool_calls: None` from aborting the
            # whole compression over a diagnostics line.
            for tc in q.get("tool_calls") or []:
                total_tool_calls += 1
                tool_name = tc.get("tool_name", "unknown")
                action_name = tc.get("action_name", "unknown")
                key = f"{tool_name}.{action_name}"
                tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1

                result = tc.get("result", "")
                if result:
                    total_tool_result_chars += len(str(result))

        if total_tool_calls > 0:
            tool_breakdown_str = ", ".join(
                f"{tool}({count})"
                for tool, count in sorted(tool_call_breakdown.items())
            )
            tool_result_kb = total_tool_result_chars / 1024
            logger.info(
                f"Tool call breakdown: {tool_breakdown_str} "
                f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
            )
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Compression threshold checking logic."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from application.core.model_utils import get_token_limit
|
||||
from application.core.settings import settings
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompressionThresholdChecker:
    """Determines if compression is needed based on token thresholds."""

    def __init__(self, threshold_percentage: float = None):
        """
        Initialize threshold checker.

        Args:
            threshold_percentage: Fraction of the context window to use as
                threshold (it is multiplied directly with the token limit).
                Defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE when
                None; an explicit 0.0 is honoured.
        """
        # Compare against None rather than relying on truthiness, so a
        # caller-supplied 0.0 is not silently replaced by the default.
        if threshold_percentage is None:
            threshold_percentage = settings.COMPRESSION_THRESHOLD_PERCENTAGE
        self.threshold_percentage = threshold_percentage

    def should_compress(
        self,
        conversation: Dict[str, Any],
        model_id: str,
        current_query_tokens: int = 500,
    ) -> bool:
        """
        Determine if compression is needed.

        Args:
            conversation: Full conversation document
            model_id: Target model for this request
            current_query_tokens: Estimated tokens for current query

        Returns:
            True if tokens >= threshold% of context window. Any error during
            the check is logged and reported as False.
        """
        try:
            total_tokens = TokenCounter.count_conversation_tokens(conversation)
            total_tokens += current_query_tokens

            context_limit = get_token_limit(model_id)

            threshold = int(context_limit * self.threshold_percentage)

            compression_needed = total_tokens >= threshold
            percentage_used = (total_tokens / context_limit) * 100

            if compression_needed:
                logger.warning(
                    f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
                    f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
                )
            else:
                logger.info(
                    f"Compression check: {total_tokens}/{context_limit} tokens "
                    f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
                )

            return compression_needed

        except Exception as e:
            logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
            return False

    def check_message_tokens(self, messages: list, model_id: str) -> bool:
        """
        Check if message list exceeds threshold.

        Args:
            messages: List of message dicts
            model_id: Target model

        Returns:
            True if at or above threshold. Any error during the check is
            logged and reported as False.
        """
        try:
            current_tokens = TokenCounter.count_message_tokens(messages)
            context_limit = get_token_limit(model_id)
            threshold = int(context_limit * self.threshold_percentage)

            if current_tokens >= threshold:
                logger.warning(
                    f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
                    f"({(current_tokens/context_limit)*100:.1f}%)"
                )
                return True

            return False

        except Exception as e:
            logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
            return False
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Token counting utilities for compression."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from application.utils import num_tokens_from_string
|
||||
from application.core.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenCounter:
    """Shared token-counting helpers for messages, queries and conversations."""

    @staticmethod
    def count_message_tokens(messages: List[Dict]) -> int:
        """
        Sum the tokens of every message's 'content'.

        String content is counted directly. List content is counted item by
        item, using the string form of each dict item; non-dict items and
        content of any other type contribute nothing.

        Args:
            messages: List of message dicts with 'content' field

        Returns:
            Total token count
        """

        def _content_tokens(content) -> int:
            if isinstance(content, str):
                return num_tokens_from_string(content)
            if isinstance(content, list):
                # Structured content (tool calls, etc.)
                return sum(
                    num_tokens_from_string(str(part))
                    for part in content
                    if isinstance(part, dict)
                )
            return 0

        return sum(_content_tokens(msg.get("content", "")) for msg in messages)

    @staticmethod
    def count_query_tokens(
        queries: List[Dict[str, Any]], include_tool_calls: bool = True
    ) -> int:
        """
        Sum the tokens of a list of query objects.

        For each query the prompt, response and thought are counted when the
        key is present, plus one rendered line per tool call when requested.

        Args:
            queries: List of query objects from conversation
            include_tool_calls: Whether to count tool call tokens

        Returns:
            Total token count
        """
        running = 0

        for entry in queries:
            for text_key in ("prompt", "response"):
                if text_key in entry:
                    running += num_tokens_from_string(entry[text_key])
            if "thought" in entry:
                running += num_tokens_from_string(entry.get("thought", ""))

            if not include_tool_calls or "tool_calls" not in entry:
                continue

            for call in entry["tool_calls"]:
                rendered = (
                    f"Tool: {call.get('tool_name')} | "
                    f"Action: {call.get('action_name')} | "
                    f"Args: {call.get('arguments')} | "
                    f"Response: {call.get('result')}"
                )
                running += num_tokens_from_string(rendered)

        return running

    @staticmethod
    def count_conversation_tokens(
        conversation: Dict[str, Any], include_system_prompt: bool = False
    ) -> int:
        """
        Count the tokens of all queries in a conversation.

        Args:
            conversation: Conversation document
            include_system_prompt: Whether to add the reserved system-prompt
                allowance from settings (500 when not configured)

        Returns:
            Total token count, or 0 if counting raises
        """
        try:
            count = TokenCounter.count_query_tokens(conversation.get("queries", []))

            if include_system_prompt:
                # Rough estimate for system prompt
                count += settings.RESERVED_TOKENS.get("system_prompt", 500)
        except Exception as e:
            logger.error(f"Error calculating conversation tokens: {str(e)}")
            return 0

        return count
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Type definitions for compression module."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
class CompressionMetadata:
    """Record describing a single compression run."""

    # When the compression was produced.
    timestamp: datetime
    # Index of the last query covered by the summary.
    query_index: int
    # Summary text standing in for the compressed queries.
    compressed_summary: str
    # Token counts before and after compression.
    original_token_count: int
    compressed_token_count: int
    # original_token_count / compressed_token_count as computed by the caller.
    compression_ratio: float
    # Model ID and prompt template version used to produce the summary.
    model_used: str
    compression_prompt_version: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, in declaration order, for DB storage."""
        exported = (
            "timestamp",
            "query_index",
            "compressed_summary",
            "original_token_count",
            "compressed_token_count",
            "compression_ratio",
            "model_used",
            "compression_prompt_version",
        )
        return {name: getattr(self, name) for name in exported}
|
||||
|
||||
|
||||
@dataclass
class CompressionResult:
    """Result of a compression operation."""

    # True when the operation finished without error (with or without an
    # actual compression taking place).
    success: bool
    # Summary text; set only by success_with_compression.
    compressed_summary: Optional[str] = None
    # Queries to be used verbatim alongside the summary.
    recent_queries: List[Dict[str, Any]] = field(default_factory=list)
    # Details of the compression run; set only by success_with_compression.
    metadata: Optional[CompressionMetadata] = None
    # Error message; set only by failure.
    error: Optional[str] = None
    # True only when a compression was actually carried out.
    compression_performed: bool = False

    @classmethod
    def success_with_compression(
        cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
    ) -> "CompressionResult":
        """Create a successful result with compression."""
        return cls(
            success=True,
            compressed_summary=summary,
            recent_queries=queries,
            metadata=metadata,
            compression_performed=True,
        )

    @classmethod
    def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
        """Create a successful result without compression needed."""
        return cls(
            success=True,
            recent_queries=queries,
            compression_performed=False,
        )

    @classmethod
    def failure(cls, error: str) -> "CompressionResult":
        """Create a failure result."""
        return cls(success=False, error=error, compression_performed=False)

    def as_history(self) -> List[Dict[str, str]]:
        """
        Convert recent queries to history format.

        Queries that lack a 'prompt' or a 'response' are skipped rather than
        raising KeyError, and a missing (None) query list yields an empty
        history.

        Returns:
            List of prompt/response dicts
        """
        return [
            {"prompt": q["prompt"], "response": q["response"]}
            for q in self.recent_queries or []
            if "prompt" in q and "response" in q
        ]
|
||||
@@ -52,7 +52,7 @@ class ConversationService:
|
||||
sources: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
llm: Any,
|
||||
model_id: str,
|
||||
gpt_model: str,
|
||||
decoded_token: Dict[str, Any],
|
||||
index: Optional[int] = None,
|
||||
api_key: Optional[str] = None,
|
||||
@@ -60,16 +60,13 @@ class ConversationService:
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save or update a conversation in the database"""
|
||||
if decoded_token is None:
|
||||
raise ValueError("Invalid or missing authentication token")
|
||||
user_id = decoded_token.get("sub")
|
||||
if not user_id:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
@@ -93,12 +90,6 @@ class ConversationService:
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
f"queries.{index}.model_id": model_id,
|
||||
**(
|
||||
{f"queries.{index}.metadata": metadata}
|
||||
if metadata
|
||||
else {}
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
@@ -129,8 +120,6 @@ class ConversationService:
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -144,9 +133,10 @@ class ConversationService:
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that creates concise conversation titles. "
|
||||
"Summarize conversations in 3 words or less using the same language as the user.",
|
||||
"role": "assistant",
|
||||
"content": "Summarise following conversation in no more than 3 "
|
||||
"words, respond ONLY with the summary, use the same "
|
||||
"language as the user query",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -157,30 +147,24 @@ class ConversationService:
|
||||
]
|
||||
|
||||
completion = llm.gen(
|
||||
model=model_id, messages=messages_summary, max_tokens=500
|
||||
model=gpt_model, messages=messages_summary, max_tokens=30
|
||||
)
|
||||
|
||||
if not completion or not completion.strip():
|
||||
completion = question[:50] if question else "New Conversation"
|
||||
|
||||
query_doc = {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
if metadata:
|
||||
query_doc["metadata"] = metadata
|
||||
|
||||
conversation_data = {
|
||||
"user": user_id,
|
||||
"date": current_time,
|
||||
"name": completion,
|
||||
"queries": [query_doc],
|
||||
"queries": [
|
||||
{
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
if api_key:
|
||||
@@ -194,103 +178,3 @@ class ConversationService:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
return str(result.inserted_id)
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Update conversation with compression metadata.
|
||||
|
||||
Uses $push with $slice to keep only the most recent compression points,
|
||||
preventing unbounded array growth. Since each compression incorporates
|
||||
previous compressions, older points become redundant.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
compression_metadata: Compression point data
|
||||
"""
|
||||
try:
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$set": {
|
||||
"compression_metadata.is_compressed": True,
|
||||
"compression_metadata.last_compression_at": compression_metadata.get(
|
||||
"timestamp"
|
||||
),
|
||||
},
|
||||
"$push": {
|
||||
"compression_metadata.compression_points": {
|
||||
"$each": [compression_metadata],
|
||||
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def append_compression_message(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Append a synthetic compression summary entry into the conversation history.
|
||||
This makes the summary visible in the DB alongside normal queries.
|
||||
"""
|
||||
try:
|
||||
summary = compression_metadata.get("compressed_summary", "")
|
||||
if not summary:
|
||||
return
|
||||
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
|
||||
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"timestamp": timestamp,
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error appending compression summary: {str(e)}", exc_info=True
|
||||
)
|
||||
|
||||
def get_compression_metadata(
|
||||
self, conversation_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get compression metadata for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
|
||||
Returns:
|
||||
Compression metadata dict or None
|
||||
"""
|
||||
try:
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
|
||||
)
|
||||
return conversation.get("compression_metadata") if conversation else None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compression metadata: {str(e)}", exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.templates.namespaces import NamespaceManager
|
||||
|
||||
from application.templates.template_engine import TemplateEngine, TemplateRenderError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptRenderer:
|
||||
"""Service for rendering prompts with dynamic context using namespaces"""
|
||||
|
||||
def __init__(self):
|
||||
self.template_engine = TemplateEngine()
|
||||
self.namespace_manager = NamespaceManager()
|
||||
|
||||
def render_prompt(
|
||||
self,
|
||||
prompt_content: str,
|
||||
user_id: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
passthrough_data: Optional[Dict[str, Any]] = None,
|
||||
docs: Optional[list] = None,
|
||||
docs_together: Optional[str] = None,
|
||||
tools_data: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Render prompt with full context from all namespaces.
|
||||
|
||||
Args:
|
||||
prompt_content: Raw prompt template string
|
||||
user_id: Current user identifier
|
||||
request_id: Unique request identifier
|
||||
passthrough_data: Parameters from web request
|
||||
docs: RAG retrieved documents
|
||||
docs_together: Concatenated document content
|
||||
tools_data: Pre-fetched tool results organized by tool name
|
||||
**kwargs: Additional parameters for namespace builders
|
||||
|
||||
Returns:
|
||||
Rendered prompt string with all variables substituted
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template rendering fails
|
||||
"""
|
||||
if not prompt_content:
|
||||
return ""
|
||||
|
||||
uses_template = self._uses_template_syntax(prompt_content)
|
||||
|
||||
if not uses_template:
|
||||
return self._apply_legacy_substitutions(prompt_content, docs_together)
|
||||
|
||||
try:
|
||||
context = self.namespace_manager.build_context(
|
||||
user_id=user_id,
|
||||
request_id=request_id,
|
||||
passthrough_data=passthrough_data,
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.template_engine.render(prompt_content, context)
|
||||
except TemplateRenderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Prompt rendering failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise TemplateRenderError(error_msg) from e
|
||||
|
||||
def _uses_template_syntax(self, prompt_content: str) -> bool:
|
||||
"""Check if prompt uses Jinja2 template syntax"""
|
||||
return "{{" in prompt_content and "}}" in prompt_content
|
||||
|
||||
def _apply_legacy_substitutions(
|
||||
self, prompt_content: str, docs_together: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Apply backward-compatible substitutions for old prompt format.
|
||||
|
||||
Handles legacy {summaries} and {query} placeholders during transition period.
|
||||
"""
|
||||
if docs_together:
|
||||
prompt_content = prompt_content.replace("{summaries}", docs_together)
|
||||
return prompt_content
|
||||
|
||||
def validate_template(self, prompt_content: str) -> bool:
|
||||
"""Validate prompt template syntax"""
|
||||
return self.template_engine.validate_template(prompt_content)
|
||||
|
||||
def extract_variables(self, prompt_content: str) -> set[str]:
|
||||
"""Extract all variable names from prompt template"""
|
||||
return self.template_engine.extract_variables(prompt_content)
|
||||
@@ -3,30 +3,18 @@ import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Set
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.compression import CompressionOrchestrator
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.api.answer.services.prompt_renderer import PromptRenderer
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
get_default_model_id,
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
calculate_doc_token_budget,
|
||||
limit_chat_history,
|
||||
)
|
||||
from application.utils import get_gpt_model, limit_chat_history
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,23 +26,13 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
|
||||
current_dir = Path(__file__).resolve().parents[3]
|
||||
prompts_dir = current_dir / "prompts"
|
||||
|
||||
# Maps for classic agent types
|
||||
CLASSIC_PRESETS = {
|
||||
preset_mapping = {
|
||||
"default": "chat_combine_default.txt",
|
||||
"creative": "chat_combine_creative.txt",
|
||||
"strict": "chat_combine_strict.txt",
|
||||
"reduce": "chat_reduce_prompt.txt",
|
||||
}
|
||||
|
||||
# Agentic counterparts — same styles, but with search tool instructions
|
||||
AGENTIC_PRESETS = {
|
||||
"default": "agentic/default.txt",
|
||||
"creative": "agentic/creative.txt",
|
||||
"strict": "agentic/strict.txt",
|
||||
}
|
||||
|
||||
preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
|
||||
|
||||
if prompt_id in preset_mapping:
|
||||
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
|
||||
try:
|
||||
@@ -91,60 +69,27 @@ class StreamProcessor:
|
||||
self.decoded_token.get("sub") if self.decoded_token is not None else None
|
||||
)
|
||||
self.conversation_id = self.data.get("conversation_id")
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
self.source = (
|
||||
{"active_docs": self.data["active_docs"]}
|
||||
if "active_docs" in self.data
|
||||
else {}
|
||||
)
|
||||
self.attachments = []
|
||||
self.history = []
|
||||
self.retrieved_docs = []
|
||||
self.agent_config = {}
|
||||
self.retriever_config = {}
|
||||
self.is_shared_usage = False
|
||||
self.shared_token = None
|
||||
self.agent_id = self.data.get("agent_id")
|
||||
self.agent_key = None
|
||||
self.model_id: Optional[str] = None
|
||||
self.gpt_model = get_gpt_model()
|
||||
self.conversation_service = ConversationService()
|
||||
self.compression_orchestrator = CompressionOrchestrator(
|
||||
self.conversation_service
|
||||
)
|
||||
self.prompt_renderer = PromptRenderer()
|
||||
self._prompt_content: Optional[str] = None
|
||||
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||
self.compressed_summary: Optional[str] = None
|
||||
self.compressed_summary_tokens: int = 0
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
self._configure_agent()
|
||||
self._validate_and_set_model()
|
||||
self._configure_source()
|
||||
self._configure_retriever()
|
||||
self._configure_agent()
|
||||
self._load_conversation_history()
|
||||
self._process_attachments()
|
||||
|
||||
def build_agent(self, question: str):
|
||||
"""One call to go from request data to a ready-to-run agent.
|
||||
|
||||
Combines initialize(), pre_fetch_docs(), pre_fetch_tools(), and
|
||||
create_agent() into a single convenience method.
|
||||
"""
|
||||
self.initialize()
|
||||
|
||||
agent_type = self.agent_config.get("agent_type", "classic")
|
||||
|
||||
# Agentic/research agents skip pre-fetch — the LLM searches on-demand via tools
|
||||
if agent_type in ("agentic", "research"):
|
||||
tools_data = self.pre_fetch_tools()
|
||||
return self.create_agent(tools_data=tools_data)
|
||||
|
||||
docs_together, docs_list = self.pre_fetch_docs(question)
|
||||
tools_data = self.pre_fetch_tools()
|
||||
return self.create_agent(
|
||||
docs_together=docs_together,
|
||||
docs=docs_list,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
def _load_conversation_history(self):
|
||||
"""Load conversation history either from DB or request"""
|
||||
if self.conversation_id and self.initial_user_id:
|
||||
@@ -153,84 +98,14 @@ class StreamProcessor:
|
||||
)
|
||||
if not conversation:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
|
||||
# Check if compression is enabled and needed
|
||||
if settings.ENABLE_CONVERSATION_COMPRESSION:
|
||||
self._handle_compression(conversation)
|
||||
else:
|
||||
# Original behavior - load all history (include metadata if present)
|
||||
self.history = [
|
||||
{
|
||||
"prompt": query["prompt"],
|
||||
"response": query["response"],
|
||||
**(
|
||||
{"metadata": query["metadata"]}
|
||||
if "metadata" in query
|
||||
else {}
|
||||
),
|
||||
}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), model_id=self.model_id
|
||||
)
|
||||
|
||||
def _handle_compression(self, conversation: Dict[str, Any]):
|
||||
"""Handle conversation compression logic using orchestrator."""
|
||||
try:
|
||||
result = self.compression_orchestrator.compress_if_needed(
|
||||
conversation_id=self.conversation_id,
|
||||
user_id=self.initial_user_id,
|
||||
model_id=self.model_id,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
logger.error(f"Compression failed: {result.error}, using full history")
|
||||
self.history = [
|
||||
{
|
||||
"prompt": query["prompt"],
|
||||
"response": query["response"],
|
||||
**({"metadata": query["metadata"]} if "metadata" in query else {}),
|
||||
}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
return
|
||||
|
||||
if result.compression_performed and result.compressed_summary:
|
||||
self.compressed_summary = result.compressed_summary
|
||||
self.compressed_summary_tokens = TokenCounter.count_message_tokens(
|
||||
[{"content": result.compressed_summary}]
|
||||
)
|
||||
logger.info(
|
||||
f"Using compressed summary ({self.compressed_summary_tokens} tokens) "
|
||||
f"+ {len(result.recent_queries)} recent messages"
|
||||
)
|
||||
|
||||
self.history = result.as_history()
|
||||
# Preserve metadata from recent queries (as_history only has prompt/response)
|
||||
recent = result.recent_queries if result.recent_queries else conversation.get("queries", [])
|
||||
for i, entry in enumerate(self.history):
|
||||
# Match by index from the end of recent queries
|
||||
offset = len(recent) - len(self.history)
|
||||
qi = offset + i
|
||||
if 0 <= qi < len(recent) and "metadata" in recent[qi]:
|
||||
entry["metadata"] = recent[qi]["metadata"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error handling compression, falling back to standard history: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
self.history = [
|
||||
{
|
||||
"prompt": query["prompt"],
|
||||
"response": query["response"],
|
||||
**({"metadata": query["metadata"]} if "metadata" in query else {}),
|
||||
}
|
||||
{"prompt": query["prompt"], "response": query["response"]}
|
||||
for query in conversation.get("queries", [])
|
||||
]
|
||||
else:
|
||||
self.history = limit_chat_history(
|
||||
json.loads(self.data.get("history", "[]")), gpt_model=self.gpt_model
|
||||
)
|
||||
|
||||
def _process_attachments(self):
|
||||
"""Process any attachments in the request"""
|
||||
@@ -240,6 +115,9 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
def _get_attachments_content(self, attachment_ids, user_id):
|
||||
"""
|
||||
Retrieve content from attachment documents based on their IDs.
|
||||
"""
|
||||
if not attachment_ids:
|
||||
return []
|
||||
attachments = []
|
||||
@@ -248,6 +126,7 @@ class StreamProcessor:
|
||||
attachment_doc = self.attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id), "user": user_id}
|
||||
)
|
||||
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
except Exception as e:
|
||||
@@ -256,33 +135,6 @@ class StreamProcessor:
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
"""Validate and set model_id from request"""
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
requested_model = self.data.get("model_id")
|
||||
|
||||
if requested_model:
|
||||
if not validate_model_id(requested_model):
|
||||
registry = ModelRegistry.get_instance()
|
||||
available_models = [m.id for m in registry.get_enabled_models()]
|
||||
raise ValueError(
|
||||
f"Invalid model_id '{requested_model}'. "
|
||||
f"Available models: {', '.join(available_models[:5])}"
|
||||
+ (
|
||||
f" and {len(available_models) - 5} more"
|
||||
if len(available_models) > 5
|
||||
else ""
|
||||
)
|
||||
)
|
||||
self.model_id = requested_model
|
||||
else:
|
||||
agent_default_model = self.agent_config.get("default_model_id", "")
|
||||
if agent_default_model and validate_model_id(agent_default_model):
|
||||
self.model_id = agent_default_model
|
||||
else:
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control"""
|
||||
if not agent_id:
|
||||
@@ -319,581 +171,107 @@ class StreamProcessor:
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
elif source == "default":
|
||||
data["source"] = "default"
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
|
||||
sources = data.get("sources", [])
|
||||
if sources and isinstance(sources, list):
|
||||
sources_list = []
|
||||
for i, source_ref in enumerate(sources):
|
||||
if source_ref == "default":
|
||||
processed_source = {
|
||||
"id": "default",
|
||||
"retriever": "classic",
|
||||
"chunks": data.get("chunks", "2"),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
processed_source = {
|
||||
"id": str(source_doc["_id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
|
||||
data["default_model_id"] = data.get("default_model_id", "")
|
||||
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
"""Configure the source based on agent data"""
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if api_key:
|
||||
agent_data = self._get_data_from_api_key(api_key)
|
||||
|
||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||
source_ids = [
|
||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||
]
|
||||
if source_ids:
|
||||
self.source = {"active_docs": source_ids}
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = agent_data["sources"]
|
||||
elif agent_data.get("source"):
|
||||
self.source = {"active_docs": agent_data["source"]}
|
||||
self.all_sources = [
|
||||
{
|
||||
"id": agent_data["source"],
|
||||
"retriever": agent_data.get("retriever", "classic"),
|
||||
}
|
||||
]
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
return
|
||||
if "active_docs" in self.data:
|
||||
self.source = {"active_docs": self.data["active_docs"]}
|
||||
return
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
|
||||
def _resolve_agent_id(self) -> Optional[str]:
|
||||
"""Resolve agent_id from request, then fall back to conversation context."""
|
||||
request_agent_id = self.data.get("agent_id")
|
||||
if request_agent_id:
|
||||
return str(request_agent_id)
|
||||
|
||||
if not self.conversation_id or not self.initial_user_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
self.conversation_id, self.initial_user_id
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
conversation_agent_id = conversation.get("agent_id")
|
||||
if conversation_agent_id:
|
||||
return str(conversation_agent_id)
|
||||
|
||||
return None
|
||||
|
||||
def _configure_agent(self):
|
||||
"""Configure the agent based on request data.
|
||||
|
||||
Unified flow: resolve the effective API key, then extract config once.
|
||||
"""
|
||||
agent_id = self._resolve_agent_id()
|
||||
|
||||
"""Configure the agent based on request data"""
|
||||
agent_id = self.data.get("agent_id")
|
||||
self.agent_key, self.is_shared_usage, self.shared_token = self._get_agent_key(
|
||||
agent_id, self.initial_user_id
|
||||
)
|
||||
self.agent_id = str(agent_id) if agent_id else None
|
||||
|
||||
# Determine the effective API key (explicit > agent-derived)
|
||||
effective_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if effective_key:
|
||||
data_key = self._get_data_from_api_key(effective_key)
|
||||
if data_key.get("_id"):
|
||||
self.agent_id = str(data_key.get("_id"))
|
||||
|
||||
api_key = self.data.get("api_key")
|
||||
if api_key:
|
||||
data_key = self._get_data_from_api_key(api_key)
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": effective_key,
|
||||
"user_api_key": api_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
"default_model_id": data_key.get("default_model_id", ""),
|
||||
"models": data_key.get("models", []),
|
||||
}
|
||||
)
|
||||
|
||||
# Set identity context
|
||||
if self.data.get("api_key"):
|
||||
# External API key: use the key owner's identity
|
||||
self.initial_user_id = data_key.get("user")
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
elif self.is_shared_usage:
|
||||
# Shared agent: keep the caller's identity
|
||||
pass
|
||||
else:
|
||||
# Owner using their own agent
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
self.initial_user_id = data_key.get("user")
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("workflow"):
|
||||
self.agent_config["workflow"] = data_key["workflow"]
|
||||
self.agent_config["workflow_owner"] = data_key.get("user")
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
elif self.agent_key:
|
||||
data_key = self._get_data_from_api_key(self.agent_key)
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": self.agent_key,
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
}
|
||||
)
|
||||
self.decoded_token = (
|
||||
self.decoded_token
|
||||
if self.is_shared_usage
|
||||
else {"sub": data_key.get("user")}
|
||||
)
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
self.retriever_config["chunks"] = data_key["chunks"]
|
||||
else:
|
||||
# No API key — default/workflow configuration
|
||||
agent_type = settings.AGENT_NAME
|
||||
if self.data.get("workflow") and isinstance(
|
||||
self.data.get("workflow"), dict
|
||||
):
|
||||
agent_type = "workflow"
|
||||
self.agent_config["workflow"] = self.data["workflow"]
|
||||
if isinstance(self.decoded_token, dict):
|
||||
self.agent_config["workflow_owner"] = self.decoded_token.get("sub")
|
||||
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": self.data.get("prompt_id", "default"),
|
||||
"agent_type": agent_type,
|
||||
"agent_type": settings.AGENT_NAME,
|
||||
"user_api_key": None,
|
||||
"json_schema": None,
|
||||
"default_model_id": "",
|
||||
}
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
|
||||
|
||||
"""Configure the retriever based on request data"""
|
||||
self.retriever_config = {
|
||||
"retriever_name": self.data.get("retriever", "classic"),
|
||||
"chunks": int(self.data.get("chunks", 2)),
|
||||
"doc_token_limit": doc_token_limit,
|
||||
"token_limit": self.data.get("token_limit", settings.DEFAULT_MAX_HISTORY),
|
||||
}
|
||||
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
if "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
|
||||
def create_agent(self):
|
||||
"""Create and return the configured agent"""
|
||||
return AgentCreator.create_agent(
|
||||
self.agent_config["agent_type"],
|
||||
endpoint="stream",
|
||||
llm_name=settings.LLM_PROVIDER,
|
||||
gpt_model=self.gpt_model,
|
||||
api_key=settings.API_KEY,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chat_history=self.history,
|
||||
decoded_token=self.decoded_token,
|
||||
attachments=self.attachments,
|
||||
json_schema=self.agent_config.get("json_schema"),
|
||||
)
|
||||
|
||||
def create_retriever(self):
|
||||
"""Create and return the configured retriever"""
|
||||
return RetrieverCreator.create_retriever(
|
||||
self.retriever_config["retriever_name"],
|
||||
source=self.source,
|
||||
chat_history=self.history,
|
||||
prompt=get_prompt(self.agent_config["prompt_id"], self.prompts_collection),
|
||||
chunks=self.retriever_config["chunks"],
|
||||
doc_token_limit=self.retriever_config.get("doc_token_limit", 50000),
|
||||
model_id=self.model_id,
|
||||
user_api_key=self.agent_config["user_api_key"],
|
||||
agent_id=self.agent_id,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
|
||||
def pre_fetch_docs(self, question: str) -> tuple[Optional[str], Optional[list]]:
|
||||
"""Pre-fetch documents for template rendering before agent creation"""
|
||||
if self.data.get("isNoneDoc", False) and not self.agent_id:
|
||||
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
||||
return None, None
|
||||
try:
|
||||
retriever = self.create_retriever()
|
||||
logger.info(
|
||||
f"Pre-fetching docs with chunks={retriever.chunks}, doc_token_limit={retriever.doc_token_limit}"
|
||||
)
|
||||
docs = retriever.search(question)
|
||||
logger.info(f"Pre-fetch retrieved {len(docs) if docs else 0} documents")
|
||||
|
||||
if not docs:
|
||||
logger.info("Pre-fetch: No documents returned from search")
|
||||
return None, None
|
||||
self.retrieved_docs = docs
|
||||
|
||||
docs_with_filenames = []
|
||||
for doc in docs:
|
||||
filename = doc.get("filename") or doc.get("title") or doc.get("source")
|
||||
if filename:
|
||||
chunk_header = str(filename)
|
||||
docs_with_filenames.append(f"{chunk_header}\n{doc['text']}")
|
||||
else:
|
||||
docs_with_filenames.append(doc["text"])
|
||||
docs_together = "\n\n".join(docs_with_filenames)
|
||||
|
||||
logger.info(f"Pre-fetch docs_together size: {len(docs_together)} chars")
|
||||
|
||||
return docs_together, docs
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to pre-fetch docs: {str(e)}", exc_info=True)
|
||||
return None, None
|
||||
|
||||
def pre_fetch_tools(self) -> Optional[Dict[str, Any]]:
    """Collect tool output for prompt-template rendering before agent creation.

    Enabled tools of the current user (``"local"`` when no user id is set)
    are loaded from the ``user_tools`` collection and run through
    ``_fetch_tool_data``. When ``_get_required_tool_actions`` returns a
    mapping, only tools referenced in it (by name or by id) are executed.

    Returns:
        A dict holding each tool's results under both its name and its
        string id, or ``None`` when pre-fetching is disabled, no tool
        produced data, or an error occurred.
    """
    if not settings.ENABLE_TOOL_PREFETCH:
        logger.info(
            "Tool pre-fetching disabled globally via ENABLE_TOOL_PREFETCH setting"
        )
        return None

    if self.data.get("disable_tool_prefetch", False):
        logger.info("Tool pre-fetching disabled for this request")
        return None

    # None means "no filtering"; a dict (even an empty one) restricts the run.
    wanted = self._get_required_tool_actions()

    try:
        owner = self.initial_user_id or "local"
        enabled_tools = list(
            self.db["user_tools"].find({"user": owner, "status": True})
        )
        if not enabled_tools:
            return None

        prefetched = {}
        for record in enabled_tools:
            name = record.get("name")
            ident = str(record.get("_id"))

            selected = None
            if wanted is not None:
                # A tool may be referenced by its name or by its id.
                selected = wanted.get(name, set()) | wanted.get(ident, set())
                if not selected:
                    continue

            payload = self._fetch_tool_data(record, selected)
            if payload:
                prefetched[name] = payload
                prefetched[ident] = payload

        return prefetched or None
    except Exception as e:
        logger.warning(f"Failed to pre-fetch tools: {type(e).__name__}")
        return None
|
||||
|
||||
def _fetch_tool_data(
    self,
    tool_doc: Dict[str, Any],
    required_actions: Optional[Set[Optional[str]]],
) -> Optional[Dict[str, Any]]:
    """Load one tool and execute its actions with their saved parameters.

    Args:
        tool_doc: The stored tool document; ``name``, ``_id``, ``config``
            and ``actions`` are read from it.
        required_actions: ``None`` to run every action. Otherwise the set of
            action names to run; a ``None`` member also selects every
            action.

    Returns:
        A mapping of action name to the value returned by
        ``execute_action``, or ``None`` when the tool cannot be loaded,
        exposes no actions, no action succeeded, or an error occurred.
        Individual action failures are logged at debug level and skipped.
    """
    # Bound before the ``try`` so the outer handler can always format its
    # message, even when the import or the first lookup is what failed.
    tool_name = None
    try:
        tool_name = tool_doc.get("name")

        from application.agents.tools.tool_manager import ToolManager

        tool_config = tool_doc.get("config", {}).copy()
        tool_config["tool_id"] = str(tool_doc["_id"])

        tool_manager = ToolManager(config={tool_name: tool_config})
        user_id = self.initial_user_id or "local"
        tool = tool_manager.load_tool(tool_name, tool_config, user_id=user_id)

        if not tool:
            logger.debug(f"Tool '{tool_name}' failed to load")
            return None

        tool_actions = tool.get_actions_metadata()
        if not tool_actions:
            logger.debug(f"Tool '{tool_name}' has no actions")
            return None

        saved_actions = tool_doc.get("actions", [])

        include_all_actions = required_actions is None or None in required_actions
        allowed_actions: Set[str] = (
            {action for action in required_actions if isinstance(action, str)}
            if required_actions
            else set()
        )

        action_results = {}
        for action_meta in tool_actions:
            action_name = action_meta.get("name")
            if action_name is None:
                continue
            if (
                not include_all_actions
                and allowed_actions
                and action_name not in allowed_actions
            ):
                continue

            try:
                # First saved entry with a matching name, if any.
                saved_action = next(
                    (sa for sa in saved_actions if sa.get("name") == action_name),
                    None,
                )

                properties = action_meta.get("parameters", {}).get("properties", {})

                # Looked up once per action instead of once per parameter.
                saved_props = {}
                if saved_action and properties:
                    saved_props = saved_action.get("parameters", {}).get(
                        "properties", {}
                    )

                kwargs = {}
                for param_name, param_spec in properties.items():
                    # Precedence: saved value, then tool config, then default.
                    if param_name in saved_props:
                        param_value = saved_props[param_name].get("value")
                        if param_value is not None:
                            kwargs[param_name] = param_value
                            continue

                    if param_name in tool_config:
                        kwargs[param_name] = tool_config[param_name]
                    elif "default" in param_spec:
                        kwargs[param_name] = param_spec["default"]

                action_results[action_name] = tool.execute_action(
                    action_name, **kwargs
                )
            except Exception as e:
                logger.debug(
                    f"Action '{action_name}' execution failed: {type(e).__name__}"
                )
                continue

        return action_results if action_results else None

    except Exception as e:
        logger.debug(f"Tool pre-fetch failed for '{tool_name}': {type(e).__name__}")
        return None
|
||||
|
||||
def _get_prompt_content(self) -> Optional[str]:
    """Return the raw prompt text for the agent's ``prompt_id``.

    A successful lookup is kept on ``self._prompt_content`` and reused.
    ``None`` is returned when ``agent_config`` is not a dict, has no
    ``prompt_id``, or the lookup raises; a failed lookup leaves the stored
    value as ``None``.
    """
    cached = self._prompt_content
    if cached is not None:
        return cached

    if not isinstance(self.agent_config, dict):
        return None
    prompt_id = self.agent_config.get("prompt_id")
    if not prompt_id:
        return None

    try:
        self._prompt_content = get_prompt(prompt_id, self.prompts_collection)
    except ValueError as e:
        logger.debug(f"Invalid prompt ID '{prompt_id}': {str(e)}")
        self._prompt_content = None
    except Exception as e:
        logger.debug(f"Failed to fetch prompt '{prompt_id}': {type(e).__name__}")
        self._prompt_content = None
    return self._prompt_content
|
||||
|
||||
def _get_required_tool_actions(self) -> Optional[Dict[str, Set[Optional[str]]]]:
    """Work out which tool actions the prompt template refers to.

    Returns:
        ``None`` when no prompt content is available. Otherwise the value of
        ``TemplateEngine.extract_tool_usages`` for the prompt, or an empty
        dict when the prompt contains no ``{{`` / ``}}`` pair or extraction
        fails. Any non-``None`` result is stored on
        ``self._required_tool_actions`` and reused.
    """
    if self._required_tool_actions is not None:
        return self._required_tool_actions

    template_text = self._get_prompt_content()
    if template_text is None:
        return None

    has_placeholders = "{{" in template_text and "}}" in template_text
    if not has_placeholders:
        # Nothing to render, so no tool can be referenced.
        self._required_tool_actions = {}
        return self._required_tool_actions

    try:
        from application.templates.template_engine import TemplateEngine

        self._required_tool_actions = TemplateEngine().extract_tool_usages(
            template_text
        )
    except Exception as e:
        logger.debug(f"Failed to extract tool usages: {type(e).__name__}")
        self._required_tool_actions = {}
    return self._required_tool_actions
|
||||
|
||||
def _fetch_memory_tool_data(
    self, tool_doc: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """Read the memory tool's root view for pre-injection into the prompt.

    Args:
        tool_doc: The stored tool document; ``config`` and ``_id`` are read.

    Returns:
        ``{"root": <view text>, "available": True}``, or ``None`` when the
        view is blank, contains ``"Error:"``, or any step raises.
    """
    try:
        settings_for_tool = tool_doc.get("config", {}).copy()
        settings_for_tool["tool_id"] = str(tool_doc["_id"])

        from application.agents.tools.memory import MemoryTool

        memory = MemoryTool(settings_for_tool, self.initial_user_id)
        listing = memory.execute_action("view", path="/")

        usable = "Error:" not in listing and listing.strip()
        if not usable:
            return None
        return {"root": listing, "available": True}
    except Exception as e:
        logger.warning(f"Failed to fetch memory tool data: {str(e)}")
        return None
|
||||
|
||||
def create_agent(
    self,
    docs_together: Optional[str] = None,
    docs: Optional[list] = None,
    tools_data: Optional[Dict[str, Any]] = None,
):
    """Build the configured agent with its prompt already rendered.

    Args:
        docs_together: Joined document text passed to the prompt renderer.
        docs: Retrieved documents passed to the prompt renderer.
        tools_data: Pre-fetched tool output passed to the prompt renderer.

    Returns:
        The agent produced by ``AgentCreator.create_agent``, with
        ``conversation_id`` and ``initial_user_id`` set on it.
    """
    agent_type = self.agent_config["agent_type"]
    is_agentic = agent_type in ("agentic", "research")

    # When no prompt content was resolved, fall back to a preset lookup.
    # Agentic/research agents get the "agentic_" variant of the standard
    # presets (search tool instructions instead of {summaries}); any other
    # prompt id is fetched as-is.
    raw_prompt = self._get_prompt_content()
    if raw_prompt is None:
        prompt_id = self.agent_config.get("prompt_id", "default")
        agentic_presets = {"default", "creative", "strict"}
        lookup_id = (
            f"agentic_{prompt_id}"
            if is_agentic and prompt_id in agentic_presets
            else prompt_id
        )
        raw_prompt = get_prompt(lookup_id, self.prompts_collection)
        self._prompt_content = raw_prompt

    rendered_prompt = self.prompt_renderer.render_prompt(
        prompt_content=raw_prompt,
        user_id=self.initial_user_id,
        request_id=self.data.get("request_id"),
        passthrough_data=self.data.get("passthrough"),
        docs=docs,
        docs_together=docs_together,
        tools_data=tools_data,
    )

    if self.model_id:
        provider = get_provider_from_model_id(self.model_id)
    else:
        provider = settings.LLM_PROVIDER
    # The provider lookup may come back empty; use the configured default then.
    llm_name = provider or settings.LLM_PROVIDER
    system_api_key = get_api_key_for_provider(llm_name)

    # Create LLM and handler (dependency injection)
    from application.llm.llm_creator import LLMCreator
    from application.llm.handlers.handler_creator import LLMHandlerCreator
    from application.agents.tool_executor import ToolExecutor

    # Backup models: the agent's configured models minus the active one.
    backup_models = [
        candidate
        for candidate in self.agent_config.get("models", [])
        if candidate != self.model_id
    ]

    llm = LLMCreator.create_llm(
        llm_name,
        api_key=system_api_key,
        token_limit=self.retriever_config["token_limit"],
        gpt_model=self.gpt_model,
        user_api_key=self.agent_config["user_api_key"],
        decoded_token=self.decoded_token,
        model_id=self.model_id,
        agent_id=self.agent_id,
        backup_models=backup_models,
    )
    llm_handler = LLMHandlerCreator.create_handler(provider or "default")

    user = self.decoded_token.get("sub") if self.decoded_token else None
    tool_executor = ToolExecutor(
        user_api_key=self.agent_config["user_api_key"],
        user=user,
        decoded_token=self.decoded_token,
    )
    tool_executor.conversation_id = self.conversation_id

    # Keyword arguments shared by every agent type.
    agent_kwargs = {
        "endpoint": "stream",
        "llm_name": llm_name,
        "model_id": self.model_id,
        "api_key": system_api_key,
        "agent_id": self.agent_id,
        "user_api_key": self.agent_config["user_api_key"],
        "prompt": rendered_prompt,
        "chat_history": self.history,
        "retrieved_docs": self.retrieved_docs,
        "decoded_token": self.decoded_token,
        "attachments": self.attachments,
        "json_schema": self.agent_config.get("json_schema"),
        "compressed_summary": self.compressed_summary,
        "llm": llm,
        "llm_handler": llm_handler,
        "tool_executor": tool_executor,
    }

    # Type-specific additions.
    if is_agentic:
        agent_kwargs["retriever_config"] = {
            "source": self.source,
            "retriever_name": self.retriever_config.get(
                "retriever_name", "classic"
            ),
            "chunks": self.retriever_config.get("chunks", 2),
            "doc_token_limit": self.retriever_config.get(
                "doc_token_limit", 50000
            ),
            "model_id": self.model_id,
            "user_api_key": self.agent_config["user_api_key"],
            "agent_id": self.agent_id,
            "llm_name": llm_name,
            "api_key": system_api_key,
            "decoded_token": self.decoded_token,
        }
    elif agent_type == "workflow":
        workflow_config = self.agent_config.get("workflow")
        # A string is passed on as a workflow id, a dict as an inline workflow.
        if isinstance(workflow_config, str):
            agent_kwargs["workflow_id"] = workflow_config
        elif isinstance(workflow_config, dict):
            agent_kwargs["workflow"] = workflow_config
        workflow_owner = self.agent_config.get("workflow_owner")
        if workflow_owner:
            agent_kwargs["workflow_owner"] = workflow_owner

    agent = AgentCreator.create_agent(agent_type, **agent_kwargs)
    agent.conversation_id = self.conversation_id
    agent.initial_user_id = self.initial_user_id
    return agent
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import base64
|
||||
import datetime
|
||||
import html
|
||||
import json
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
import logging
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
@@ -17,6 +14,8 @@ from flask import (
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
|
||||
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
@@ -25,9 +24,15 @@ from application.core.settings import settings
|
||||
from application.api import api
|
||||
|
||||
|
||||
from application.utils import (
|
||||
check_required_fields
|
||||
)
|
||||
|
||||
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
@@ -37,18 +42,185 @@ connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
|
||||
# Fixed callback status path to prevent open redirect
|
||||
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
|
||||
|
||||
|
||||
def build_callback_redirect(params: dict) -> str:
|
||||
"""Build a safe redirect URL to the callback status page.
|
||||
@connectors_ns.route("/api/connectors/upload")
|
||||
class UploadConnector(Resource):
|
||||
@api.expect(
|
||||
api.model(
|
||||
"ConnectorUploadModel",
|
||||
{
|
||||
"user": fields.String(required=True, description="User ID"),
|
||||
"source": fields.String(
|
||||
required=True, description="Source type (google_drive, github, etc.)"
|
||||
),
|
||||
"name": fields.String(required=True, description="Job name"),
|
||||
"data": fields.String(required=True, description="Configuration data"),
|
||||
"repo_url": fields.String(description="GitHub repository URL"),
|
||||
},
|
||||
)
|
||||
)
|
||||
@api.doc(
|
||||
description="Uploads connector source for vectorization",
|
||||
)
|
||||
def post(self):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
data = request.form
|
||||
required_fields = ["user", "source", "name", "data"]
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
config = json.loads(data["data"])
|
||||
source_data = None
|
||||
sync_frequency = config.get("sync_frequency", "never")
|
||||
|
||||
Uses a fixed path and properly URL-encodes all parameters
|
||||
to prevent URL injection and open redirect vulnerabilities.
|
||||
"""
|
||||
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
|
||||
if data["source"] == "github":
|
||||
source_data = config.get("repo_url")
|
||||
elif data["source"] in ["crawler", "url"]:
|
||||
source_data = config.get("url")
|
||||
elif data["source"] == "reddit":
|
||||
source_data = config
|
||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||
session_token = config.get("session_token")
|
||||
if not session_token:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"error": f"Missing session_token in {data['source']} configuration"
|
||||
}), 400)
|
||||
|
||||
file_ids = config.get("file_ids", [])
|
||||
if isinstance(file_ids, str):
|
||||
file_ids = [id.strip() for id in file_ids.split(',') if id.strip()]
|
||||
elif not isinstance(file_ids, list):
|
||||
file_ids = []
|
||||
|
||||
folder_ids = config.get("folder_ids", [])
|
||||
if isinstance(folder_ids, str):
|
||||
folder_ids = [id.strip() for id in folder_ids.split(',') if id.strip()]
|
||||
elif not isinstance(folder_ids, list):
|
||||
folder_ids = []
|
||||
|
||||
config["file_ids"] = file_ids
|
||||
config["folder_ids"] = folder_ids
|
||||
|
||||
task = ingest_connector_task.delay(
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
source_type=data["source"],
|
||||
session_token=session_token,
|
||||
file_ids=file_ids,
|
||||
folder_ids=folder_ids,
|
||||
recursive=config.get("recursive", False),
|
||||
retriever=config.get("retriever", "classic"),
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
task = ingest_connector_task.delay(
|
||||
source_data=source_data,
|
||||
job_name=data["name"],
|
||||
user=decoded_token.get("sub"),
|
||||
loader=data["source"],
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error uploading connector source: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/task_status")
class ConnectorTaskStatus(Resource):
    """Report the Celery status and result metadata of a connector task."""

    task_status_model = api.model(
        "ConnectorTaskStatusModel",
        {"task_id": fields.String(required=True, description="Task ID")},
    )

    @api.expect(task_status_model)
    @api.doc(description="Get connector task status")
    def get(self):
        """Return ``{"status", "result"}`` for the ``task_id`` query parameter.

        Responds 400 when ``task_id`` is missing or the lookup fails.
        """
        task_id = request.args.get("task_id")
        if not task_id:
            return make_response(
                jsonify({"success": False, "message": "Task ID is required"}), 400
            )
        try:
            from application.celery_init import celery

            task = celery.AsyncResult(task_id)
            task_meta = task.info
            # Read the status once, inside the try, so a backend error is
            # handled here and the logged and returned values agree.
            task_status = task.status
            current_app.logger.debug(f"Task status: {task_status}")
            # Result metadata that is not a plain JSON-friendly value (for
            # example an exception instance) is returned as its string form.
            if not isinstance(
                task_meta, (dict, list, str, int, float, bool, type(None))
            ):
                task_meta = str(task_meta)
        except Exception as err:
            current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"status": task_status, "result": task_meta}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sources")
class ConnectorSources(Resource):
    """List the connector-type sources owned by the authenticated user."""

    @api.doc(description="Get connector sources")
    def get(self):
        """Return the user's connector sources, sorted by ``date`` descending.

        Responds 401 without a decoded token and 400 when the query fails.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)

        owner = decoded_token.get("sub")
        try:
            cursor = sources_collection.find(
                {"user": owner, "type": "connector"}
            ).sort("date", -1)
            connector_sources = [
                {
                    "id": str(record["_id"]),
                    "name": record.get("name"),
                    "date": record.get("date"),
                    "type": record.get("type"),
                    "source": record.get("source"),
                    "tokens": record.get("tokens", ""),
                    "retriever": record.get("retriever", "classic"),
                    "syncFrequency": record.get("sync_frequency", ""),
                }
                for record in cursor
            ]
        except Exception as err:
            current_app.logger.error(f"Error retrieving connector sources: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify(connector_sources), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/delete")
class DeleteConnectorSource(Resource):
    """Delete one connector source belonging to the authenticated user."""

    @api.doc(
        description="Delete a connector source",
        params={"source_id": "The source ID to delete"},
    )
    def delete(self):
        """Delete the source named by the ``source_id`` query parameter.

        Responds 401 without a decoded token, 400 when ``source_id`` is
        missing or the delete raises, and 404 when nothing matched.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)

        source_id = request.args.get("source_id")
        if not source_id:
            return make_response(
                jsonify({"success": False, "message": "source_id is required"}), 400
            )

        try:
            # The user filter restricts the delete to the caller's own sources.
            selector = {"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
            outcome = sources_collection.delete_one(selector)
            if outcome.deleted_count == 0:
                return make_response(
                    jsonify({"success": False, "message": "Source not found"}), 404
                )
        except Exception as err:
            current_app.logger.error(
                f"Error deleting connector source: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/auth")
|
||||
@@ -63,24 +235,8 @@ class ConnectorAuth(Resource):
|
||||
if not ConnectorCreator.is_supported(provider):
|
||||
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
result = sessions_collection.insert_one({
|
||||
"provider": provider,
|
||||
"user": user_id,
|
||||
"status": "pending",
|
||||
"created_at": now
|
||||
})
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": str(result.inserted_id)
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
import uuid
|
||||
state = str(uuid.uuid4())
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
authorization_url = auth.get_authorization_url(state=state)
|
||||
return make_response(jsonify({
|
||||
@@ -89,8 +245,8 @@ class ConnectorAuth(Resource):
|
||||
"state": state
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
|
||||
current_app.logger.error(f"Error generating connector auth URL: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/callback")
|
||||
@@ -101,123 +257,102 @@ class ConnectorsCallback(Resource):
|
||||
try:
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from flask import request, redirect
|
||||
import uuid
|
||||
|
||||
provider = request.args.get('provider', 'google_drive')
|
||||
authorization_code = request.args.get('code')
|
||||
state = request.args.get('state')
|
||||
_ = request.args.get('state')
|
||||
error = request.args.get('error')
|
||||
|
||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
provider = state_dict.get("provider")
|
||||
state_object_id = state_dict.get("object_id")
|
||||
|
||||
# Validate provider
|
||||
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Invalid provider"
|
||||
}))
|
||||
|
||||
if error:
|
||||
if error == "access_denied":
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "cancelled",
|
||||
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
|
||||
"provider": provider
|
||||
}))
|
||||
else:
|
||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider={provider}")
|
||||
|
||||
if not authorization_code:
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authorization+code+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider={provider}")
|
||||
|
||||
try:
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
token_info = auth.exchange_code_for_tokens(authorization_code)
|
||||
|
||||
session_token = str(uuid.uuid4())
|
||||
|
||||
|
||||
try:
|
||||
if provider == "google_drive":
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
else:
|
||||
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
|
||||
|
||||
credentials = auth.create_credentials_from_token_info(token_info)
|
||||
service = auth.build_drive_service(credentials)
|
||||
user_info = service.about().get(fields="user").execute()
|
||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Could not get user info: {e}")
|
||||
user_email = 'Connected User'
|
||||
|
||||
sanitized_token_info = auth.sanitize_token_info(token_info)
|
||||
sanitized_token_info = {
|
||||
"access_token": token_info.get("access_token"),
|
||||
"refresh_token": token_info.get("refresh_token"),
|
||||
"token_uri": token_info.get("token_uri"),
|
||||
"expiry": token_info.get("expiry"),
|
||||
"scopes": token_info.get("scopes")
|
||||
}
|
||||
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
{
|
||||
"$set": {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized"
|
||||
}
|
||||
}
|
||||
)
|
||||
user_id = request.decoded_token.get("sub") if getattr(request, "decoded_token", None) else None
|
||||
sessions_collection.insert_one({
|
||||
"session_token": session_token,
|
||||
"user": user_id,
|
||||
"token_info": sanitized_token_info,
|
||||
"created_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
"user_email": user_email,
|
||||
"provider": provider
|
||||
})
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "success",
|
||||
"message": "Authentication successful",
|
||||
"provider": provider,
|
||||
"session_token": session_token,
|
||||
"user_email": user_email
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||
"provider": provider
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+exchange+authorization+code+for+tokens:+{str(e)}&provider={provider}")
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||
return redirect(build_callback_redirect({
|
||||
"status": "error",
|
||||
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
|
||||
}))
|
||||
return redirect(f"/api/connectors/callback-status?status=error&message=Failed+to+complete+connector+authentication:+{str(e)}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.")
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/refresh")
class ConnectorRefresh(Resource):
    """Exchange a connector refresh token for new token information."""

    @api.expect(api.model("ConnectorRefreshModel", {"provider": fields.String(required=True), "refresh_token": fields.String(required=True)}))
    @api.doc(description="Refresh connector access token")
    def post(self):
        """Refresh the access token for ``provider`` using ``refresh_token``.

        Responds 400 when either field is missing (including a missing or
        non-object JSON body) and 500 when the refresh fails.
        """
        try:
            # silent=True turns an absent or malformed body into None, so the
            # request ends in the 400 below rather than a 500.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                data = {}
            provider = data.get('provider')
            refresh_token = data.get('refresh_token')

            if not provider or not refresh_token:
                return make_response(jsonify({"success": False, "error": "provider and refresh_token are required"}), 400)

            auth = ConnectorCreator.create_auth(provider)
            token_info = auth.refresh_access_token(refresh_token)
            return make_response(jsonify({"success": True, "token_info": token_info}), 200)
        except Exception as e:
            # Keep the details in the server log only; exception text can
            # carry provider internals that should not reach the client.
            current_app.logger.error(f"Error refreshing token for connector: {e}", exc_info=True)
            return make_response(jsonify({"success": False, "error": "Failed to refresh token"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/files")
|
||||
class ConnectorFiles(Resource):
|
||||
@api.expect(api.model("ConnectorFilesModel", {
|
||||
"provider": fields.String(required=True),
|
||||
"session_token": fields.String(required=True),
|
||||
"folder_id": fields.String(required=False),
|
||||
"limit": fields.Integer(required=False),
|
||||
"page_token": fields.String(required=False),
|
||||
"search_query": fields.String(required=False),
|
||||
}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||
@api.expect(api.model("ConnectorFilesModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True), "folder_id": fields.String(required=False), "limit": fields.Integer(required=False), "page_token": fields.String(required=False)}))
|
||||
@api.doc(description="List files from a connector provider (supports pagination)")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
provider = data.get('provider')
|
||||
session_token = data.get('session_token')
|
||||
folder_id = data.get('folder_id')
|
||||
limit = data.get('limit', 10)
|
||||
|
||||
page_token = data.get('page_token')
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
@@ -227,14 +362,13 @@ class ConnectorFiles(Resource):
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
|
||||
generic_keys = {'provider', 'session_token'}
|
||||
input_config = {
|
||||
k: v for k, v in data.items() if k not in generic_keys
|
||||
}
|
||||
input_config['list_only'] = True
|
||||
|
||||
documents = loader.load_data(input_config)
|
||||
documents = loader.load_data({
|
||||
'limit': limit,
|
||||
'list_only': True,
|
||||
'session_token': session_token,
|
||||
'folder_id': folder_id,
|
||||
'page_token': page_token
|
||||
})
|
||||
|
||||
files = []
|
||||
for doc in documents[:limit]:
|
||||
@@ -252,29 +386,22 @@ class ConnectorFiles(Resource):
|
||||
'name': metadata.get('file_name', 'Unknown File'),
|
||||
'type': metadata.get('mime_type', 'unknown'),
|
||||
'size': metadata.get('size', None),
|
||||
'modifiedTime': formatted_time,
|
||||
'isFolder': metadata.get('is_folder', False)
|
||||
'modifiedTime': formatted_time
|
||||
})
|
||||
|
||||
next_token = getattr(loader, 'next_page_token', None)
|
||||
has_more = bool(next_token)
|
||||
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"files": files,
|
||||
"total": len(files),
|
||||
"next_page_token": next_token,
|
||||
"has_more": has_more
|
||||
}), 200)
|
||||
return make_response(jsonify({"success": True, "files": files, "total": len(files), "next_page_token": next_token, "has_more": has_more}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
|
||||
current_app.logger.error(f"Error loading connector files: {e}")
|
||||
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/validate-session")
|
||||
class ConnectorValidateSession(Resource):
|
||||
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
|
||||
@api.doc(description="Validate connector session token and return user info and access token")
|
||||
@api.doc(description="Validate connector session token and return user info")
|
||||
def post(self):
|
||||
try:
|
||||
data = request.get_json()
|
||||
@@ -283,6 +410,7 @@ class ConnectorValidateSession(Resource):
|
||||
if not provider or not session_token:
|
||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||
|
||||
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
@@ -296,41 +424,14 @@ class ConnectorValidateSession(Resource):
|
||||
auth = ConnectorCreator.create_auth(provider)
|
||||
is_expired = auth.is_token_expired(token_info)
|
||||
|
||||
if is_expired and token_info.get('refresh_token'):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
|
||||
sessions_collection.update_one(
|
||||
{"session_token": session_token},
|
||||
{"$set": {"token_info": sanitized_token_info}}
|
||||
)
|
||||
token_info = sanitized_token_info
|
||||
is_expired = False
|
||||
except Exception as refresh_error:
|
||||
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
|
||||
|
||||
if is_expired:
|
||||
return make_response(jsonify({
|
||||
"success": False,
|
||||
"expired": True,
|
||||
"error": "Session token has expired. Please reconnect."
|
||||
}), 401)
|
||||
|
||||
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
|
||||
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
|
||||
|
||||
response_data = {
|
||||
return make_response(jsonify({
|
||||
"success": True,
|
||||
"expired": False,
|
||||
"user_email": session.get('user_email', 'Connected User'),
|
||||
"access_token": token_info.get('access_token'),
|
||||
**provider_extras,
|
||||
}
|
||||
|
||||
return make_response(jsonify(response_data), 200)
|
||||
"expired": is_expired,
|
||||
"user_email": session.get('user_email', 'Connected User')
|
||||
}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
|
||||
current_app.logger.error(f"Error validating connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/disconnect")
|
||||
@@ -351,8 +452,8 @@ class ConnectorDisconnect(Resource):
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}")
|
||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
||||
|
||||
|
||||
@connectors_ns.route("/api/connectors/sync")
|
||||
@@ -458,8 +559,8 @@ class ConnectorSync(Resource):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Failed to sync connector source"
|
||||
}),
|
||||
"error": str(err)
|
||||
}),
|
||||
400
|
||||
)
|
||||
|
||||
@@ -470,55 +571,36 @@ class ConnectorCallbackStatus(Resource):
|
||||
def get(self):
|
||||
"""Return HTML page with connector authentication status"""
|
||||
try:
|
||||
# Validate and sanitize status to a known value
|
||||
status_raw = request.args.get('status', 'error')
|
||||
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
|
||||
|
||||
# Escape all user-controlled values for HTML context
|
||||
message = html.escape(request.args.get('message', ''))
|
||||
provider_raw = request.args.get('provider', 'connector')
|
||||
provider = html.escape(provider_raw.replace('_', ' ').title())
|
||||
status = request.args.get('status', 'error')
|
||||
message = request.args.get('message', '')
|
||||
provider = request.args.get('provider', 'connector')
|
||||
session_token = request.args.get('session_token', '')
|
||||
user_email = html.escape(request.args.get('user_email', ''))
|
||||
|
||||
def safe_js_string(value: str) -> str:
|
||||
"""Safely encode a string for embedding in inline JavaScript."""
|
||||
js_encoded = json.dumps(value)
|
||||
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
|
||||
|
||||
js_status = safe_js_string(status)
|
||||
js_session_token = safe_js_string(session_token)
|
||||
js_user_email = safe_js_string(user_email)
|
||||
js_provider_type = safe_js_string(provider_raw)
|
||||
|
||||
user_email = request.args.get('user_email', '')
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>{provider} Authentication</title>
|
||||
<title>{provider.replace('_', ' ').title()} Authentication</title>
|
||||
<style>
|
||||
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
.success {{ color: #4CAF50; }}
|
||||
.error {{ color: #F44336; }}
|
||||
.cancelled {{ color: #FF9800; }}
|
||||
</style>
|
||||
<script>
|
||||
window.onload = function() {{
|
||||
const status = {js_status};
|
||||
const sessionToken = {js_session_token};
|
||||
const userEmail = {js_user_email};
|
||||
const providerType = {js_provider_type};
|
||||
|
||||
const status = "{status}";
|
||||
const sessionToken = "{session_token}";
|
||||
const userEmail = "{user_email}";
|
||||
|
||||
if (status === "success" && window.opener) {{
|
||||
window.opener.postMessage({{
|
||||
type: providerType + '_auth_success',
|
||||
type: '{provider}_auth_success',
|
||||
session_token: sessionToken,
|
||||
user_email: userEmail
|
||||
}}, '*');
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}} else if (status === "cancelled" || status === "error") {{
|
||||
|
||||
setTimeout(() => window.close(), 3000);
|
||||
}}
|
||||
}};
|
||||
@@ -526,17 +608,17 @@ class ConnectorCallbackStatus(Resource):
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>{provider} Authentication</h2>
|
||||
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
|
||||
<div class="{status}">
|
||||
<p>{message}</p>
|
||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||
</div>
|
||||
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else ''}</small></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import datetime
|
||||
import json
|
||||
from flask import Blueprint, request, send_from_directory, jsonify
|
||||
from flask import Blueprint, request, send_from_directory
|
||||
from werkzeug.utils import secure_filename
|
||||
from bson.objectid import ObjectId
|
||||
import logging
|
||||
@@ -24,16 +24,6 @@ current_dir = os.path.dirname(
|
||||
internal = Blueprint("internal", __name__)
|
||||
|
||||
|
||||
@internal.before_request
def verify_internal_key():
    """Reject internal-API requests that lack a valid ``X-Internal-Key`` header.

    Registered as a ``before_request`` hook on the ``internal`` blueprint.
    When ``settings.INTERNAL_KEY`` is falsy no check is performed and the
    request proceeds.

    Returns:
        ``None`` to let the request continue, or a ``(response, 401)`` tuple
        when the header is missing or does not match the configured key.
    """
    import hmac  # local import: only this hook needs it

    expected_key = settings.INTERNAL_KEY
    if not expected_key:
        return None

    internal_key = request.headers.get("X-Internal-Key")
    # Compare as bytes with a constant-time routine so response timing does
    # not leak how much of the key matched; bytes also avoid the TypeError
    # that compare_digest raises for non-ASCII str arguments.
    # NOTE(review): assumes settings.INTERNAL_KEY is a str — confirm.
    key_matches = bool(internal_key) and hmac.compare_digest(
        internal_key.encode("utf-8"), expected_key.encode("utf-8")
    )
    if not key_matches:
        logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
        return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
    return None
|
||||
|
||||
|
||||
@internal.route("/api/download", methods=["get"])
|
||||
def download_file():
|
||||
user = secure_filename(request.args.get("user"))
|
||||
@@ -61,7 +51,6 @@ def upload_index_files():
|
||||
|
||||
file_path = request.form.get("file_path")
|
||||
directory_structure = request.form.get("directory_structure")
|
||||
file_name_map = request.form.get("file_name_map")
|
||||
|
||||
if directory_structure:
|
||||
try:
|
||||
@@ -71,14 +60,6 @@ def upload_index_files():
|
||||
directory_structure = {}
|
||||
else:
|
||||
directory_structure = {}
|
||||
if file_name_map:
|
||||
try:
|
||||
file_name_map = json.loads(file_name_map)
|
||||
except Exception:
|
||||
logger.error("Error parsing file_name_map")
|
||||
file_name_map = None
|
||||
else:
|
||||
file_name_map = None
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
index_base_path = f"indexes/{id}"
|
||||
@@ -106,43 +87,41 @@ def upload_index_files():
|
||||
|
||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
||||
if existing_entry:
|
||||
update_fields = {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
update_fields["file_name_map"] = file_name_map
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(id)},
|
||||
{"$set": update_fields},
|
||||
{
|
||||
"$set": {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
insert_doc = {
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
insert_doc["file_name_map"] = file_name_map
|
||||
sources_collection.insert_one(insert_doc)
|
||||
sources_collection.insert_one(
|
||||
{
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""User API module - provides all user-related API endpoints"""
|
||||
|
||||
from .routes import user
|
||||
|
||||
__all__ = ["user"]
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Agents module."""
|
||||
|
||||
from .routes import agents_ns
|
||||
from .sharing import agents_sharing_ns
|
||||
from .webhooks import agents_webhooks_ns
|
||||
from .folders import agents_folders_ns
|
||||
|
||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]
|
||||
@@ -1,266 +0,0 @@
|
||||
"""
|
||||
Agent folders management routes.
|
||||
Provides virtual folder organization for agents (Google Drive-like structure).
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agent_folders_collection,
|
||||
agents_collection,
|
||||
)
|
||||
|
||||
agents_folders_ns = Namespace(
|
||||
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||
)
|
||||
|
||||
|
||||
def _folder_error_response(message: str, err: Exception):
    """Log a folder-operation failure and build the JSON error reply.

    Args:
        message: Summary of what failed; it is both logged and returned to
            the client.
        err: The caught exception; it is only written to the log (with
            traceback), never sent to the client.

    Returns:
        A Flask response carrying ``{"success": False, "message": message}``
        with HTTP status 400.
    """
    current_app.logger.error(f"{message}: {err}", exc_info=True)
    payload = {"success": False, "message": message}
    error_body = jsonify(payload)
    return make_response(error_body, 400)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/")
class AgentFolders(Resource):
    """Collection endpoint: list the caller's agent folders or create a new one."""

    @api.doc(description="Get all folders for the user")
    def get(self):
        """Return every folder owned by the authenticated user.

        Returns:
            401 when the request has no decoded token, 200 with
            ``{"folders": [...]}`` on success, or the 400 reply built by
            ``_folder_error_response`` when the lookup fails.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            folders = list(agent_folders_collection.find({"user": user}))
            result = [
                {
                    "id": str(f["_id"]),
                    "name": f["name"],
                    "parent_id": f.get("parent_id"),
                    # Timestamps may be absent on a document; emit None then.
                    "created_at": f["created_at"].isoformat() if f.get("created_at") else None,
                    "updated_at": f["updated_at"].isoformat() if f.get("updated_at") else None,
                }
                for f in folders
            ]
            return make_response(jsonify({"folders": result}), 200)
        except Exception as err:
            return _folder_error_response("Failed to fetch folders", err)

    @api.doc(description="Create a new folder")
    @api.expect(
        api.model(
            "CreateFolder",
            {
                "name": fields.String(required=True, description="Folder name"),
                "parent_id": fields.String(required=False, description="Parent folder ID"),
            },
        )
    )
    def post(self):
        """Create a folder for the authenticated user.

        The JSON body must contain a non-empty ``name``; an optional
        ``parent_id`` must reference a folder owned by the same user.

        Returns:
            401 without a decoded token, 400 when ``name`` is missing,
            404 when the parent folder is not found, 201 with the new
            folder's ``id``/``name``/``parent_id`` on success, or the 400
            reply built by ``_folder_error_response`` on any failure
            (including a malformed ``parent_id``).
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        if not data or not data.get("name"):
            return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)

        parent_id = data.get("parent_id")

        try:
            # The parent lookup lives inside the try so that a malformed
            # parent_id (ObjectId() raising) is reported like every other
            # failure in this module instead of escaping as an unhandled error.
            if parent_id:
                parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
                if not parent:
                    return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)

            now = datetime.datetime.now(datetime.timezone.utc)
            folder = {
                "user": user,
                "name": data["name"],
                "parent_id": parent_id,
                "created_at": now,
                "updated_at": now,
            }
            result = agent_folders_collection.insert_one(folder)
            return make_response(
                jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
                201,
            )
        except Exception as err:
            return _folder_error_response("Failed to create folder", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/<string:folder_id>")
class AgentFolder(Resource):
    """Item endpoint: read, update or delete one folder owned by the caller."""

    @api.doc(description="Get a specific folder with its agents")
    def get(self, folder_id):
        """Return a folder together with its agents and direct subfolders.

        Agents and subfolders are matched on the string ``folder_id`` stored
        in their ``folder_id`` / ``parent_id`` fields.

        Returns:
            401 without a decoded token, 404 when the folder is not found
            for this user, 200 with the folder payload on success, or the
            400 reply built by ``_folder_error_response`` on failure.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
            if not folder:
                return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)

            agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
            agents_list = [
                {"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
                for a in agents
            ]
            subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
            subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]

            return make_response(
                jsonify({
                    "id": str(folder["_id"]),
                    "name": folder["name"],
                    "parent_id": folder.get("parent_id"),
                    "agents": agents_list,
                    "subfolders": subfolders_list,
                }),
                200,
            )
        except Exception as err:
            return _folder_error_response("Failed to fetch folder", err)

    @api.doc(description="Update a folder")
    def put(self, folder_id):
        """Rename a folder and/or move it under another parent.

        Only the ``name`` and ``parent_id`` keys of the JSON body are used;
        ``updated_at`` is always refreshed.

        Returns:
            401 without a decoded token, 400 when the body is empty or the
            folder is set as its own parent, 404 when the new parent or the
            folder itself is not found for this user, 200 on success, or
            the 400 reply built by ``_folder_error_response`` on failure.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        if not data:
            return make_response(jsonify({"success": False, "message": "No data provided"}), 400)

        try:
            update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
            if "name" in data:
                update_fields["name"] = data["name"]
            if "parent_id" in data:
                new_parent_id = data["parent_id"]
                if new_parent_id == folder_id:
                    return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
                # Mirror the check done on creation: a non-empty parent must
                # be an existing folder owned by the same user. A falsy value
                # (e.g. null) is stored as-is and skips the lookup.
                # NOTE: longer cycles (A -> B -> A) are not detected here.
                if new_parent_id:
                    parent = agent_folders_collection.find_one(
                        {"_id": ObjectId(new_parent_id), "user": user}
                    )
                    if not parent:
                        return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
                update_fields["parent_id"] = new_parent_id

            result = agent_folders_collection.update_one(
                {"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
            )
            if result.matched_count == 0:
                return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
            return make_response(jsonify({"success": True}), 200)
        except Exception as err:
            return _folder_error_response("Failed to update folder", err)

    @api.doc(description="Delete a folder")
    def delete(self, folder_id):
        """Delete a folder after detaching its agents and direct subfolders.

        Agents in the folder lose their ``folder_id`` and direct child
        folders lose their ``parent_id``; nothing is deleted recursively.

        Returns:
            401 without a decoded token, 404 when no folder was deleted,
            200 on success, or the 400 reply built by
            ``_folder_error_response`` on failure.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        try:
            # Detach dependants first so none of them is left pointing at a
            # folder that no longer exists.
            agents_collection.update_many(
                {"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
            )
            agent_folders_collection.update_many(
                {"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
            )
            result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
            if result.deleted_count == 0:
                return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
            return make_response(jsonify({"success": True}), 200)
        except Exception as err:
            return _folder_error_response("Failed to delete folder", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/move_agent")
class MoveAgentToFolder(Resource):
    """Endpoint that assigns a single agent to a folder or clears its folder."""

    @api.doc(description="Move an agent to a folder or remove from folder")
    @api.expect(
        api.model(
            "MoveAgent",
            {
                "agent_id": fields.String(required=True, description="Agent ID to move"),
                "folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
            },
        )
    )
    def post(self):
        """Move one of the caller's agents into a folder, or out of any folder.

        A truthy ``folder_id`` must reference a folder owned by the caller;
        a missing or falsy ``folder_id`` removes the agent's ``folder_id``.

        Returns:
            401 without a decoded token, 400 when ``agent_id`` is missing,
            404 when the agent or the target folder is not found for this
            user, 200 on success, or the 400 reply built by
            ``_folder_error_response`` on failure.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        if not data or not data.get("agent_id"):
            return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)

        agent_id = data["agent_id"]
        folder_id = data.get("folder_id")

        try:
            # Same filter for the ownership check and the write, so the
            # update can only ever touch an agent owned by the caller
            # (consistent with the bulk-move endpoint's filter).
            owned_agent_filter = {"_id": ObjectId(agent_id), "user": user}
            agent = agents_collection.find_one(owned_agent_filter)
            if not agent:
                return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)

            if folder_id:
                folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
                if not folder:
                    return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
                agents_collection.update_one(
                    owned_agent_filter, {"$set": {"folder_id": folder_id}}
                )
            else:
                agents_collection.update_one(
                    owned_agent_filter, {"$unset": {"folder_id": ""}}
                )

            return make_response(jsonify({"success": True}), 200)
        except Exception as err:
            return _folder_error_response("Failed to move agent", err)
|
||||
|
||||
|
||||
@agents_folders_ns.route("/bulk_move")
class BulkMoveAgents(Resource):
    """Endpoint that moves several agents into, or out of, a folder at once."""

    @api.doc(description="Move multiple agents to a folder")
    @api.expect(
        api.model(
            "BulkMoveAgents",
            {
                "agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
                "folder_id": fields.String(required=False, description="Target folder ID"),
            },
        )
    )
    def post(self):
        """Set or clear ``folder_id`` on every listed agent owned by the caller.

        Returns:
            401 without a decoded token, 400 when ``agent_ids`` is missing
            or empty, 404 when a truthy ``folder_id`` does not match a
            folder of this user, 200 on success, or the 400 reply built by
            ``_folder_error_response`` on failure.
        """
        token_payload = request.decoded_token
        if not token_payload:
            return make_response(jsonify({"success": False}), 401)
        user = token_payload.get("sub")

        body = request.get_json()
        if not body or not body.get("agent_ids"):
            return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)

        requested_ids = body["agent_ids"]
        folder_id = body.get("folder_id")

        try:
            if folder_id:
                target_folder = agent_folders_collection.find_one(
                    {"_id": ObjectId(folder_id), "user": user}
                )
                if not target_folder:
                    return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)

            object_ids = list(map(ObjectId, requested_ids))
            # One write either way; only the update operator differs.
            if folder_id:
                change = {"$set": {"folder_id": folder_id}}
            else:
                change = {"$unset": {"folder_id": ""}}
            agents_collection.update_many(
                {"_id": {"$in": object_ids}, "user": user},
                change,
            )
            return make_response(jsonify({"success": True}), 200)
        except Exception as err:
            return _folder_error_response("Failed to move agents", err)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,263 +0,0 @@
|
||||
"""Agent management sharing functionality."""
|
||||
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from bson import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
resolve_tool_details,
|
||||
user_tools_collection,
|
||||
users_collection,
|
||||
)
|
||||
from application.utils import generate_image_url
|
||||
|
||||
agents_sharing_ns = Namespace(
|
||||
"agents", description="Agent management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
class SharedAgent(Resource):
    """Endpoint that resolves a publicly shared agent from its share token."""

    @api.doc(
        params={
            "token": "Shared token of the agent",
        },
        description="Get a shared agent by token or ID",
    )
    def get(self):
        """Return the publicly shared agent matching the ``token`` query argument.

        When the request carries a decoded token for a user other than the
        agent's owner, the agent id is added to that user's
        ``agent_preferences.shared_with_me`` set.

        Returns:
            400 when ``token`` is missing, 404 when no publicly shared agent
            has that token, 200 with the agent payload on success, or a bare
            ``{"success": False}`` with status 400 on any exception.
        """
        shared_token = request.args.get("token")
        if not shared_token:
            return make_response(
                jsonify({"success": False, "message": "Token or ID is required"}), 400
            )

        try:
            shared_agent = agents_collection.find_one(
                {"shared_publicly": True, "shared_token": shared_token}
            )
            if not shared_agent:
                return make_response(
                    jsonify({"success": False, "message": "Shared agent not found"}),
                    404,
                )

            agent_id = str(shared_agent["_id"])

            image_url = ""
            if shared_agent.get("image"):
                image_url = generate_image_url(shared_agent["image"])

            # Only a DBRef source that still dereferences yields an id.
            source_id = ""
            source_ref = shared_agent.get("source")
            if isinstance(source_ref, DBRef):
                source_doc = db.dereference(source_ref)
                if source_doc:
                    source_id = str(source_doc["_id"])

            data = {
                "id": agent_id,
                "user": shared_agent.get("user", ""),
                "name": shared_agent.get("name", ""),
                "image": image_url,
                "description": shared_agent.get("description", ""),
                "source": source_id,
                "chunks": shared_agent.get("chunks", "0"),
                "retriever": shared_agent.get("retriever", "classic"),
                "prompt_id": shared_agent.get("prompt_id", "default"),
                "tools": shared_agent.get("tools", []),
                "tool_details": resolve_tool_details(shared_agent.get("tools", [])),
                "agent_type": shared_agent.get("agent_type", ""),
                "status": shared_agent.get("status", ""),
                "json_schema": shared_agent.get("json_schema"),
                "limited_token_mode": shared_agent.get("limited_token_mode", False),
                "token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
                "limited_request_mode": shared_agent.get("limited_request_mode", False),
                "request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
                "created_at": shared_agent.get("createdAt", ""),
                "updated_at": shared_agent.get("updatedAt", ""),
                "shared": shared_agent.get("shared_publicly", False),
                "shared_token": shared_agent.get("shared_token", ""),
                "shared_metadata": shared_agent.get("shared_metadata", {}),
            }

            # Replace tool ids with tool names; ids with no matching tool
            # document are dropped.
            if data["tools"]:
                data["tools"] = [
                    tool_doc.get("name", "")
                    for tool_ref in data["tools"]
                    if (tool_doc := user_tools_collection.find_one({"_id": ObjectId(tool_ref)}))
                ]

            viewer_token = getattr(request, "decoded_token", None)
            if viewer_token:
                viewer_id = viewer_token.get("sub")
                if viewer_id != shared_agent.get("user"):
                    ensure_user_doc(viewer_id)
                    users_collection.update_one(
                        {"user_id": viewer_id},
                        {"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
                    )
            return make_response(jsonify(data), 200)
        except Exception as err:
            current_app.logger.error(f"Error retrieving shared agent: {err}")
            return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agents")
class SharedAgents(Resource):
    """Endpoint that lists the agents recorded as shared with the caller."""

    @api.doc(description="Get shared agents explicitly shared with the user")
    def get(self):
        """Return the still-public agents in the caller's ``shared_with_me`` list.

        Ids in the list that no longer match a publicly shared agent are
        pulled from the user's preferences as a side effect.

        Returns:
            401 without a decoded token, 200 with a JSON list of agents on
            success, or a bare ``{"success": False}`` with status 400 on
            any exception.
        """
        try:
            token_payload = request.decoded_token
            if not token_payload:
                return make_response(jsonify({"success": False}), 401)
            user_id = token_payload.get("sub")

            user_doc = ensure_user_doc(user_id)
            shared_with_ids = user_doc.get("agent_preferences", {}).get(
                "shared_with_me", []
            )
            shared_object_ids = [ObjectId(shared_id) for shared_id in shared_with_ids]

            shared_agents = list(
                agents_collection.find(
                    {"_id": {"$in": shared_object_ids}, "shared_publicly": True}
                )
            )

            # Anything remembered but no longer returned by the query is stale.
            found_ids_set = {str(found["_id"]) for found in shared_agents}
            stale_ids = [
                shared_id for shared_id in shared_with_ids if shared_id not in found_ids_set
            ]
            if stale_ids:
                users_collection.update_one(
                    {"user_id": user_id},
                    {"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
                )
            pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))

            list_shared_agents = []
            for agent in shared_agents:
                agent_id = str(agent["_id"])
                list_shared_agents.append(
                    {
                        "id": agent_id,
                        "name": agent.get("name", ""),
                        "description": agent.get("description", ""),
                        "image": (
                            generate_image_url(agent["image"]) if agent.get("image") else ""
                        ),
                        "tools": agent.get("tools", []),
                        "tool_details": resolve_tool_details(agent.get("tools", [])),
                        "agent_type": agent.get("agent_type", ""),
                        "status": agent.get("status", ""),
                        "json_schema": agent.get("json_schema"),
                        "limited_token_mode": agent.get("limited_token_mode", False),
                        "token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
                        "limited_request_mode": agent.get("limited_request_mode", False),
                        "request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
                        "created_at": agent.get("createdAt", ""),
                        "updated_at": agent.get("updatedAt", ""),
                        "pinned": agent_id in pinned_ids,
                        "shared": agent.get("shared_publicly", False),
                        "shared_token": agent.get("shared_token", ""),
                        "shared_metadata": agent.get("shared_metadata", {}),
                    }
                )

            return make_response(jsonify(list_shared_agents), 200)
        except Exception as err:
            current_app.logger.error(f"Error retrieving shared agents: {err}")
            return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/share_agent")
class ShareAgent(Resource):
    """Endpoint that turns public sharing of one of the caller's agents on or off."""

    @api.expect(
        api.model(
            "ShareAgentModel",
            {
                "id": fields.String(required=True, description="ID of the agent"),
                "shared": fields.Boolean(
                    required=True, description="Share or unshare the agent"
                ),
                "username": fields.String(
                    required=False, description="Name of the user"
                ),
            },
        )
    )
    @api.doc(description="Share or unshare an agent")
    def put(self):
        """Share or unshare an agent owned by the authenticated user.

        Sharing generates a fresh ``shared_token`` and records
        ``shared_metadata`` (``shared_by`` from the request's ``username``
        and the current UTC time). Unsharing clears ``shared_token`` and
        removes ``shared_metadata``.

        Returns:
            401 without a decoded token; 400 when the body, ``id`` or
            ``shared`` is missing, when ``id`` is not a valid ObjectId, or
            when the database update fails; 404 when the agent is not
            found for this user; 200 with ``shared_token`` (``None`` when
            unsharing) on success.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")

        data = request.get_json()
        if not data:
            return make_response(
                jsonify({"success": False, "message": "Missing JSON body"}), 400
            )
        agent_id = data.get("id")
        shared = data.get("shared")
        username = data.get("username", "")

        if not agent_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )
        if shared is None:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Shared parameter is required and must be true or false",
                    }
                ),
                400,
            )

        # Stays None when unsharing, so the response needs no special case.
        shared_token = None
        try:
            try:
                agent_oid = ObjectId(agent_id)
            except Exception:
                return make_response(
                    jsonify({"success": False, "message": "Invalid agent ID"}), 400
                )
            agent = agents_collection.find_one({"_id": agent_oid, "user": user})
            if not agent:
                return make_response(
                    jsonify({"success": False, "message": "Agent not found"}), 404
                )
            if shared:
                shared_metadata = {
                    "shared_by": username,
                    "shared_at": datetime.datetime.now(datetime.timezone.utc),
                }
                shared_token = secrets.token_urlsafe(32)
                agents_collection.update_one(
                    {"_id": agent_oid, "user": user},
                    {
                        "$set": {
                            "shared_publicly": shared,
                            "shared_metadata": shared_metadata,
                            "shared_token": shared_token,
                        }
                    },
                )
            else:
                # $set and $unset must be in ONE update document: the third
                # positional parameter of update_one is `upsert`, so a
                # separate {"$unset": ...} dict there is never applied as
                # an update.
                agents_collection.update_one(
                    {"_id": agent_oid, "user": user},
                    {
                        "$set": {"shared_publicly": shared, "shared_token": None},
                        "$unset": {"shared_metadata": ""},
                    },
                )
        except Exception as err:
            current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
            return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
        return make_response(
            jsonify({"success": True, "shared_token": shared_token}), 200
        )
|
||||
@@ -1,119 +0,0 @@
|
||||
"""Agent management webhook handlers."""
|
||||
|
||||
import secrets
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import agents_collection, require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
# Namespace holding the agent webhook routes; all of them are served under /api.
agents_webhooks_ns = Namespace(
    "agents",
    path="/api",
    description="Agent management operations",
)
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/agent_webhook")
class AgentWebhook(Resource):
    # Authenticated endpoint that hands out the incoming-webhook URL of an agent.

    @api.doc(
        params={"id": "ID of the agent"},
        description="Generate webhook URL for the agent",
    )
    def get(self):
        # Looks the agent up by the ``id`` query argument and the caller's
        # ``sub`` claim. The stored ``incoming_webhook_token`` is reused when
        # present; otherwise a new one is generated and persisted.
        #
        # Responses: 401 without a decoded token; 400 when ``id`` is missing or
        # any step inside the try block raises; 404 when no agent matches;
        # 200 with ``webhook_url`` on success.
        claims = request.decoded_token
        if not claims:
            return make_response(jsonify({"success": False}), 401)

        owner = claims.get("sub")
        agent_id = request.args.get("id")
        if not agent_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )

        try:
            # An id that ObjectId rejects raises here and is reported as a 400
            # by the handler below.
            selector = {"_id": ObjectId(agent_id), "user": owner}
            agent = agents_collection.find_one(selector)
            if not agent:
                return make_response(
                    jsonify({"success": False, "message": "Agent not found"}), 404
                )

            token = agent.get("incoming_webhook_token")
            if not token:
                token = secrets.token_urlsafe(32)
                agents_collection.update_one(
                    selector,
                    {"$set": {"incoming_webhook_token": token}},
                )

            api_root = settings.API_URL.rstrip("/")
            webhook_url = f"{api_root}/api/webhooks/agents/{token}"
        except Exception as err:
            current_app.logger.error(
                f"Error generating webhook URL: {err}", exc_info=True
            )
            return make_response(
                jsonify({"success": False, "message": "Error generating webhook URL"}),
                400,
            )

        return make_response(
            jsonify({"success": True, "webhook_url": webhook_url}), 200
        )
|
||||
|
||||
|
||||
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
class AgentWebhookListener(Resource):
    # Public listener that turns an incoming webhook call into a background task.
    #
    # ``require_agent`` is applied to every HTTP method. The handlers accept
    # ``agent`` and ``agent_id_str`` next to the URL's ``webhook_token``;
    # presumably the decorator resolves them from the token (confirm against
    # ``require_agent``).
    method_decorators = [require_agent]

    def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
        """Queue ``process_agent_webhook`` for the agent and build the response.

        Args:
            agent_id_str: Agent id forwarded to the task as ``agent_id``.
            payload: Data forwarded to the task as ``payload``. An empty payload
                is logged as a warning but still enqueued.
            source_method: HTTP verb used by the caller; only used in log lines.

        Returns:
            A 200 response carrying ``task_id``, or a 500 response when
            enqueuing raises.
        """
        if not payload:
            current_app.logger.warning(
                f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
            )
        current_app.logger.info(
            f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
        )

        try:
            task = process_agent_webhook.delay(
                agent_id=agent_id_str,
                payload=payload,
            )
            current_app.logger.info(
                f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
            )
            return make_response(jsonify({"success": True, "task_id": task.id}), 200)
        except Exception as err:
            current_app.logger.error(
                f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
                exc_info=True,
            )
            return make_response(
                jsonify({"success": False, "message": "Error processing webhook"}), 500
            )

    @api.doc(
        description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
    )
    def post(self, webhook_token, agent, agent_id_str):
        # ``silent=True`` makes a malformed or non-JSON body come back as None
        # instead of raising inside Flask, so such requests reach the explicit
        # 400 response below rather than a framework-generated error.
        payload = request.get_json(silent=True)
        if payload is None:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Invalid or missing JSON data in request body",
                    }
                ),
                400,
            )
        return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")

    @api.doc(
        description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
    )
    def get(self, webhook_token, agent, agent_id_str):
        # Query parameters are flattened to one value per key and used as the
        # task payload.
        payload = request.args.to_dict(flat=True)
        return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Analytics module."""
|
||||
|
||||
from .routes import analytics_ns
|
||||
|
||||
__all__ = ["analytics_ns"]
|
||||
@@ -1,540 +0,0 @@
|
||||
"""Analytics and reporting routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
conversations_collection,
|
||||
generate_date_range,
|
||||
generate_hourly_range,
|
||||
generate_minute_range,
|
||||
token_usage_collection,
|
||||
user_logs_collection,
|
||||
)
|
||||
|
||||
# Namespace holding the analytics routes; all of them are served under /api.
analytics_ns = Namespace(
    "analytics",
    path="/api",
    description="Analytics and reporting operations",
)
|
||||
|
||||
|
||||
@analytics_ns.route("/get_message_analytics")
class GetMessageAnalytics(Resource):
    # Message counts per time bucket for the authenticated user.

    get_message_analytics_model = api.model(
        "GetMessageAnalyticsModel",
        {
            "api_key_id": fields.String(required=False, description="API Key ID"),
            "filter_option": fields.String(
                required=False,
                description="Filter option for analytics",
                default="last_30_days",
                enum=[
                    "last_hour",
                    "last_24_hour",
                    "last_7_days",
                    "last_15_days",
                    "last_30_days",
                ],
            ),
        },
    )

    @api.expect(get_message_analytics_model)
    @api.doc(description="Get message analytics based on filter option")
    def post(self):
        # Counts the caller's conversation queries whose ``timestamp`` falls in
        # the requested window, bucketed by minute, hour, or day.
        #
        # Body fields: ``api_key_id`` (optional agent id; its ``key`` narrows
        # the match) and ``filter_option`` (defaults to "last_30_days").
        #
        # Responses: 401 without a decoded token; 400 when the agent lookup or
        # the aggregation raises, or the filter option is unknown; 200 with a
        # ``messages`` mapping of bucket label -> count (empty buckets are 0).
        claims = request.decoded_token
        if not claims:
            return make_response(jsonify({"success": False}), 401)
        user = claims.get("sub")
        body = request.get_json()
        api_key_id = body.get("api_key_id")
        filter_option = body.get("filter_option", "last_30_days")

        api_key = None
        try:
            if api_key_id:
                # A missing agent makes the subscript raise, which lands in the
                # handler below.
                api_key = agents_collection.find_one(
                    {"_id": ObjectId(api_key_id), "user": user}
                )["key"]
        except Exception as err:
            current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)

        end_date = datetime.datetime.now(datetime.timezone.utc)

        if filter_option == "last_hour":
            start_date = end_date - datetime.timedelta(hours=1)
            group_format = "%Y-%m-%d %H:%M:00"
            build_intervals = generate_minute_range
        elif filter_option == "last_24_hour":
            start_date = end_date - datetime.timedelta(hours=24)
            group_format = "%Y-%m-%d %H:00"
            build_intervals = generate_hourly_range
        else:
            # Day windows include today, hence one day less than the label.
            day_offsets = (
                ("last_7_days", 6),
                ("last_15_days", 14),
                ("last_30_days", 29),
            )
            filter_days = next(
                (days for label, days in day_offsets if label == filter_option),
                None,
            )
            if filter_days is None:
                return make_response(
                    jsonify({"success": False, "message": "Invalid option"}), 400
                )
            # Widen the window to whole days: midnight of the first day up to
            # the last microsecond of today.
            start_date = (end_date - datetime.timedelta(days=filter_days)).replace(
                hour=0, minute=0, second=0, microsecond=0
            )
            end_date = end_date.replace(
                hour=23, minute=59, second=59, microsecond=999999
            )
            group_format = "%Y-%m-%d"
            build_intervals = generate_date_range

        try:
            criteria = {"user": user}
            if api_key:
                criteria["api_key"] = api_key

            bucket_label = {
                "$dateToString": {
                    "format": group_format,
                    "date": "$queries.timestamp",
                }
            }
            pipeline = [
                {"$match": criteria},
                {"$unwind": "$queries"},
                {
                    "$match": {
                        "queries.timestamp": {"$gte": start_date, "$lte": end_date}
                    }
                },
                {"$group": {"_id": bucket_label, "count": {"$sum": 1}}},
                {"$sort": {"_id": 1}},
            ]
            rows = conversations_collection.aggregate(pipeline)

            # Start from a zero for every bucket so gaps show up explicitly.
            daily_messages = dict.fromkeys(build_intervals(start_date, end_date), 0)
            for row in rows:
                daily_messages[row["_id"]] = row["count"]
        except Exception as err:
            current_app.logger.error(
                f"Error getting message analytics: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        return make_response(
            jsonify({"success": True, "messages": daily_messages}), 200
        )
|
||||
|
||||
|
||||
@analytics_ns.route("/get_token_analytics")
class GetTokenAnalytics(Resource):
    # Token usage totals per time bucket for the authenticated user.

    get_token_analytics_model = api.model(
        "GetTokenAnalyticsModel",
        {
            "api_key_id": fields.String(required=False, description="API Key ID"),
            "filter_option": fields.String(
                required=False,
                description="Filter option for analytics",
                default="last_30_days",
                enum=[
                    "last_hour",
                    "last_24_hour",
                    "last_7_days",
                    "last_15_days",
                    "last_30_days",
                ],
            ),
        },
    )

    @api.expect(get_token_analytics_model)
    @api.doc(description="Get token analytics data")
    def post(self):
        # Sums ``prompt_tokens + generated_tokens`` of the caller's token usage
        # records in the requested window, bucketed by minute, hour, or day.
        #
        # Body fields: ``api_key_id`` (optional agent id; its ``key`` narrows
        # the match) and ``filter_option`` (defaults to "last_30_days").
        #
        # Responses: 401 without a decoded token; 400 when the agent lookup or
        # the aggregation raises, or the filter option is unknown; 200 with a
        # ``token_usage`` mapping of bucket label -> total (empty buckets are 0).
        claims = request.decoded_token
        if not claims:
            return make_response(jsonify({"success": False}), 401)
        user = claims.get("sub")
        body = request.get_json()
        api_key_id = body.get("api_key_id")
        filter_option = body.get("filter_option", "last_30_days")

        api_key = None
        try:
            if api_key_id:
                # A missing agent makes the subscript raise, which lands in the
                # handler below.
                api_key = agents_collection.find_one(
                    {"_id": ObjectId(api_key_id), "user": user}
                )["key"]
        except Exception as err:
            current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)

        end_date = datetime.datetime.now(datetime.timezone.utc)

        # ``bucket`` is the key under which the formatted date sits inside the
        # group ``_id`` document; it differs per granularity.
        if filter_option == "last_hour":
            bucket = "minute"
            start_date = end_date - datetime.timedelta(hours=1)
            group_format = "%Y-%m-%d %H:%M:00"
            build_intervals = generate_minute_range
        elif filter_option == "last_24_hour":
            bucket = "hour"
            start_date = end_date - datetime.timedelta(hours=24)
            group_format = "%Y-%m-%d %H:00"
            build_intervals = generate_hourly_range
        else:
            # Day windows include today, hence one day less than the label.
            day_offsets = (
                ("last_7_days", 6),
                ("last_15_days", 14),
                ("last_30_days", 29),
            )
            filter_days = next(
                (days for label, days in day_offsets if label == filter_option),
                None,
            )
            if filter_days is None:
                return make_response(
                    jsonify({"success": False, "message": "Invalid option"}), 400
                )
            bucket = "day"
            # Widen the window to whole days: midnight of the first day up to
            # the last microsecond of today.
            start_date = (end_date - datetime.timedelta(days=filter_days)).replace(
                hour=0, minute=0, second=0, microsecond=0
            )
            end_date = end_date.replace(
                hour=23, minute=59, second=59, microsecond=999999
            )
            group_format = "%Y-%m-%d"
            build_intervals = generate_date_range

        group_stage = {
            "$group": {
                "_id": {
                    bucket: {
                        "$dateToString": {
                            "format": group_format,
                            "date": "$timestamp",
                        }
                    }
                },
                "total_tokens": {
                    "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
                },
            }
        }

        try:
            criteria = {
                "user_id": user,
                "timestamp": {"$gte": start_date, "$lte": end_date},
            }
            if api_key:
                criteria["api_key"] = api_key

            usage_rows = token_usage_collection.aggregate(
                [
                    {"$match": criteria},
                    group_stage,
                    {"$sort": {"_id": 1}},
                ]
            )

            # Start from a zero for every bucket so gaps show up explicitly.
            daily_token_usage = dict.fromkeys(
                build_intervals(start_date, end_date), 0
            )
            for row in usage_rows:
                daily_token_usage[row["_id"][bucket]] = row["total_tokens"]
        except Exception as err:
            current_app.logger.error(
                f"Error getting token analytics: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        return make_response(
            jsonify({"success": True, "token_usage": daily_token_usage}), 200
        )
|
||||
|
||||
|
||||
@analytics_ns.route("/get_feedback_analytics")
class GetFeedbackAnalytics(Resource):
    # Positive/negative feedback counts per time bucket for the authenticated user.

    get_feedback_analytics_model = api.model(
        "GetFeedbackAnalyticsModel",
        {
            "api_key_id": fields.String(required=False, description="API Key ID"),
            "filter_option": fields.String(
                required=False,
                description="Filter option for analytics",
                default="last_30_days",
                enum=[
                    "last_hour",
                    "last_24_hour",
                    "last_7_days",
                    "last_15_days",
                    "last_30_days",
                ],
            ),
        },
    )

    @api.expect(get_feedback_analytics_model)
    @api.doc(description="Get feedback analytics data")
    def post(self):
        # Counts "LIKE" / "DISLIKE" feedback on the caller's conversation
        # queries whose ``feedback_timestamp`` falls in the requested window,
        # bucketed by minute, hour, or day.
        #
        # Body fields: ``api_key_id`` (optional agent id; its ``key`` narrows
        # the match) and ``filter_option`` (defaults to "last_30_days").
        #
        # Responses: 401 without a decoded token; 400 when the agent lookup or
        # the aggregation raises, or the filter option is unknown; 200 with a
        # ``feedback`` mapping of bucket label -> {"positive", "negative"}.
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")

        data = request.get_json()
        if not isinstance(data, dict):
            # Both fields are optional, so a ``null`` or non-object body simply
            # means "use the defaults" instead of crashing on ``.get``.
            data = {}
        api_key_id = data.get("api_key_id")
        filter_option = data.get("filter_option", "last_30_days")

        try:
            api_key = (
                agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
                    "key"
                ]
                if api_key_id
                else None
            )
        except Exception as err:
            current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)

        end_date = datetime.datetime.now(datetime.timezone.utc)

        if filter_option == "last_hour":
            start_date = end_date - datetime.timedelta(hours=1)
            group_format = "%Y-%m-%d %H:%M:00"
        elif filter_option == "last_24_hour":
            start_date = end_date - datetime.timedelta(hours=24)
            group_format = "%Y-%m-%d %H:00"
        else:
            if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
                # Day windows include today, hence one day less than the label.
                filter_days = (
                    6
                    if filter_option == "last_7_days"
                    else (14 if filter_option == "last_15_days" else 29)
                )
            else:
                return make_response(
                    jsonify({"success": False, "message": "Invalid option"}), 400
                )
            start_date = end_date - datetime.timedelta(days=filter_days)
            start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
            end_date = end_date.replace(
                hour=23, minute=59, second=59, microsecond=999999
            )
            group_format = "%Y-%m-%d"

        # Only the format differs between granularities.
        date_field = {
            "$dateToString": {
                "format": group_format,
                "date": "$queries.feedback_timestamp",
            }
        }

        try:
            in_window = {"$gte": start_date, "$lte": end_date}
            match_stage = {
                "$match": {
                    # Restrict to the caller's own conversations, as the
                    # message analytics endpoint does; without this every
                    # user's feedback would be aggregated.
                    "user": user,
                    "queries.feedback_timestamp": in_window,
                    "queries.feedback": {"$exists": True},
                }
            }
            if api_key:
                match_stage["$match"]["api_key"] = api_key
            pipeline = [
                # Coarse pre-filter on whole conversations.
                match_stage,
                {"$unwind": "$queries"},
                # Exact filter on individual queries: a matched conversation
                # can also contain feedback given outside the window.
                {
                    "$match": {
                        "queries.feedback": {"$exists": True},
                        "queries.feedback_timestamp": in_window,
                    }
                },
                {
                    "$group": {
                        "_id": {"time": date_field, "feedback": "$queries.feedback"},
                        "count": {"$sum": 1},
                    }
                },
                {
                    "$group": {
                        "_id": "$_id.time",
                        "positive": {
                            "$sum": {
                                "$cond": [
                                    {"$eq": ["$_id.feedback", "LIKE"]},
                                    "$count",
                                    0,
                                ]
                            }
                        },
                        "negative": {
                            "$sum": {
                                "$cond": [
                                    {"$eq": ["$_id.feedback", "DISLIKE"]},
                                    "$count",
                                    0,
                                ]
                            }
                        },
                    }
                },
                {"$sort": {"_id": 1}},
            ]

            feedback_data = conversations_collection.aggregate(pipeline)

            if filter_option == "last_hour":
                intervals = generate_minute_range(start_date, end_date)
            elif filter_option == "last_24_hour":
                intervals = generate_hourly_range(start_date, end_date)
            else:
                intervals = generate_date_range(start_date, end_date)
            # Start from zeros for every bucket so gaps show up explicitly.
            daily_feedback = {
                interval: {"positive": 0, "negative": 0} for interval in intervals
            }

            for entry in feedback_data:
                daily_feedback[entry["_id"]] = {
                    "positive": entry["positive"],
                    "negative": entry["negative"],
                }
        except Exception as err:
            current_app.logger.error(
                f"Error getting feedback analytics: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        return make_response(
            jsonify({"success": True, "feedback": daily_feedback}), 200
        )
|
||||
|
||||
|
||||
@analytics_ns.route("/get_user_logs")
class GetUserLogs(Resource):
    # Paginated access to the authenticated user's log entries.

    get_user_logs_model = api.model(
        "GetUserLogsModel",
        {
            "page": fields.Integer(
                required=False,
                description="Page number for pagination",
                default=1,
            ),
            "api_key_id": fields.String(required=False, description="API Key ID"),
            "page_size": fields.Integer(
                required=False,
                description="Number of logs per page",
                default=10,
            ),
        },
    )

    @api.expect(get_user_logs_model)
    @api.doc(description="Get user logs with pagination")
    def post(self):
        # Returns one page of log entries, newest first.
        #
        # Body fields: ``page`` (1-based, default 1), ``page_size`` (default
        # 10) and ``api_key_id`` (optional agent id; when given, logs are
        # selected by that agent's ``key`` instead of by user).
        #
        # Responses: 401 without a decoded token; 400 for invalid pagination
        # values or when the agent lookup raises (including an agent that is
        # not owned by the caller); 200 with ``logs``, ``page``, ``page_size``
        # and ``has_more``.
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")

        data = request.get_json()
        if not isinstance(data, dict):
            # All fields are optional, so a ``null`` or non-object body simply
            # means "use the defaults" instead of crashing on ``.get``.
            data = {}

        try:
            page = int(data.get("page", 1))
            page_size = int(data.get("page_size", 10))
        except (TypeError, ValueError):
            page = page_size = 0
        if page < 1 or page_size < 1:
            # Non-numeric values used to raise, and values below 1 produced a
            # negative skip or a zero-length page with ``has_more`` set.
            return make_response(
                jsonify(
                    {"success": False, "message": "Invalid pagination parameters"}
                ),
                400,
            )
        api_key_id = data.get("api_key_id")
        skip = (page - 1) * page_size

        try:
            # The lookup is scoped to the caller so that logs of another
            # user's agent cannot be requested by guessing its id. A missing
            # agent makes the subscript raise and yields the 400 below.
            api_key = (
                agents_collection.find_one(
                    {"_id": ObjectId(api_key_id), "user": user}
                )["key"]
                if api_key_id
                else None
            )
        except Exception as err:
            current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)

        query = {"user": user}
        if api_key:
            query = {"api_key": api_key}

        # Fetch one extra row to learn whether another page exists.
        items_cursor = (
            user_logs_collection.find(query)
            .sort("timestamp", -1)
            .skip(skip)
            .limit(page_size + 1)
        )
        items = list(items_cursor)

        results = [
            {
                "id": str(item.get("_id")),
                "action": item.get("action"),
                "level": item.get("level"),
                "user": item.get("user"),
                "question": item.get("question"),
                "sources": item.get("sources"),
                "retriever_params": item.get("retriever_params"),
                "timestamp": item.get("timestamp"),
            }
            for item in items[:page_size]
        ]

        has_more = len(items) > page_size

        return make_response(
            jsonify(
                {
                    "success": True,
                    "logs": results,
                    "page": page,
                    "page_size": page_size,
                    "has_more": has_more,
                }
            ),
            200,
        )
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Attachments module."""
|
||||
|
||||
from .routes import attachments_ns
|
||||
|
||||
__all__ = ["attachments_ns"]
|
||||
@@ -1,670 +0,0 @@
|
||||
"""File attachments and media routes."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.stt.constants import (
|
||||
SUPPORTED_AUDIO_EXTENSIONS,
|
||||
SUPPORTED_AUDIO_MIME_TYPES,
|
||||
)
|
||||
from application.stt.upload_limits import (
|
||||
AudioFileTooLargeError,
|
||||
build_stt_file_size_limit_message,
|
||||
enforce_audio_file_size_limit,
|
||||
is_audio_filename,
|
||||
)
|
||||
from application.stt.live_session import (
|
||||
apply_live_stt_hypothesis,
|
||||
create_live_stt_session,
|
||||
delete_live_stt_session,
|
||||
finalize_live_stt_session,
|
||||
get_live_stt_transcript_text,
|
||||
load_live_stt_session,
|
||||
save_live_stt_session,
|
||||
)
|
||||
from application.stt.stt_creator import STTCreator
|
||||
from application.tts.tts_creator import TTSCreator
|
||||
from application.utils import safe_filename
|
||||
|
||||
|
||||
# Namespace holding the attachment and media routes; all served under /api.
attachments_ns = Namespace(
    "attachments",
    path="/api",
    description="File attachments and media operations",
)
|
||||
|
||||
|
||||
def _resolve_authenticated_user():
    """Identify the caller from a decoded token or an ``api_key`` parameter.

    A decoded token on the request wins over an API key. The key is read from
    the form data first, then from the query string.

    Returns:
        The ``safe_filename``-processed user id (the token's ``sub`` claim, or
        the ``user`` of the agent whose ``key`` equals the supplied API key);
        a 401 response object when the API key matches no agent; ``None`` when
        the request carries neither a token nor an API key.
    """
    claims = getattr(request, "decoded_token", None)
    supplied_key = request.form.get("api_key") or request.args.get("api_key")

    if claims:
        return safe_filename(claims.get("sub"))
    if not supplied_key:
        return None

    from application.api.user.base import agents_collection

    owning_agent = agents_collection.find_one({"key": supplied_key})
    if owning_agent:
        return safe_filename(owning_agent.get("user"))
    return make_response(
        jsonify({"success": False, "message": "Invalid API key"}), 401
    )
|
||||
|
||||
|
||||
def _get_uploaded_file_size(file) -> int:
|
||||
try:
|
||||
current_position = file.stream.tell()
|
||||
file.stream.seek(0, os.SEEK_END)
|
||||
size_bytes = file.stream.tell()
|
||||
file.stream.seek(current_position)
|
||||
return size_bytes
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _is_supported_audio_mimetype(mimetype: str) -> bool:
|
||||
if not mimetype:
|
||||
return True
|
||||
normalized = mimetype.split(";")[0].strip().lower()
|
||||
return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
|
||||
|
||||
|
||||
def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
    """Apply the audio size limit to an upload when it is an audio file.

    Files whose name is not recognised by ``is_audio_filename`` are ignored,
    and so are uploads whose size comes back as ``0`` (which is also what
    ``_get_uploaded_file_size`` reports when it cannot measure the stream).
    Whatever ``enforce_audio_file_size_limit`` raises is propagated.
    """
    if is_audio_filename(filename):
        measured = _get_uploaded_file_size(file)
        if measured:
            enforce_audio_file_size_limit(measured)
|
||||
|
||||
|
||||
def _get_store_attachment_user_error(exc: Exception) -> str:
    """Translate an upload failure into the message shown to the client.

    Size-limit violations get the dedicated limit message; every other
    exception is reported with a generic text so internals are not exposed.
    """
    too_large = isinstance(exc, AudioFileTooLargeError)
    return build_stt_file_size_limit_message() if too_large else "Failed to process file"
|
||||
|
||||
|
||||
def _require_live_stt_redis():
    """Return the Redis client, or a 503 response when none is available.

    Callers tell the two outcomes apart by checking the result for a
    ``status_code`` attribute.
    """
    client = get_redis_instance()
    if not client:
        return make_response(
            jsonify({"success": False, "message": "Live transcription is unavailable"}),
            503,
        )
    return client
|
||||
|
||||
|
||||
def _parse_bool_form_value(value: str | None) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
    # Accepts one or more uploaded files, saves each through ``storage`` and
    # queues a ``store_attachment`` task per saved file.

    @api.expect(
        api.model(
            "AttachmentModel",
            {
                "file": fields.Raw(required=True, description="File(s) to upload"),
                "api_key": fields.String(
                    required=False, description="API key (optional)"
                ),
            },
        )
    )
    @api.doc(
        description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
    )
    def post(self):
        # Responses:
        #   401  invalid API key, or no credentials at all
        #   400  no file part / only empty filenames / every file failed /
        #        an unexpected error outside the per-file handling
        #   413  every file failed and all failures were the size limit
        #   200  single upload -> ``task_id``; otherwise ``tasks`` (plus
        #        ``errors`` for the files that failed)
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user

        uploads = request.files.getlist("file")
        if not uploads:
            lone_upload = request.files.get("file")
            if lone_upload:
                uploads = [lone_upload]

        if not uploads or not any(f.filename != "" for f in uploads):
            return make_response(
                jsonify({"status": "error", "message": "Missing file(s)"}),
                400,
            )

        # The file check above intentionally runs before this one, so an
        # anonymous request without files is answered with 400.
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}), 401
            )
        user = auth_user

        try:
            from application.api.user.tasks import store_attachment
            from application.api.user.base import storage

            queued = []
            failures = []
            upload_count = len(uploads)

            for idx, file in enumerate(uploads):
                # A failure of one file must not abort the remaining ones.
                try:
                    attachment_id = ObjectId()
                    original_filename = safe_filename(os.path.basename(file.filename))
                    _enforce_uploaded_audio_size_limit(file, original_filename)
                    relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"

                    metadata = storage.save_file(file, relative_path)
                    task = store_attachment.delay(
                        {
                            "filename": original_filename,
                            "attachment_id": str(attachment_id),
                            "path": relative_path,
                            "metadata": metadata,
                        },
                        user,
                    )
                    queued.append(
                        {
                            "task_id": task.id,
                            "filename": original_filename,
                            "attachment_id": str(attachment_id),
                            "upload_index": idx,
                        }
                    )
                except Exception as file_err:
                    current_app.logger.error(
                        f"Error processing file {idx} ({file.filename}): {file_err}",
                        exc_info=True,
                    )
                    failures.append(
                        {
                            "upload_index": idx,
                            "filename": file.filename,
                            "error": _get_store_attachment_user_error(file_err),
                        }
                    )

            if not queued:
                only_size_failures = failures and all(
                    failure.get("error") == build_stt_file_size_limit_message()
                    for failure in failures
                )
                if only_size_failures:
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": build_stt_file_size_limit_message(),
                                "errors": failures,
                            }
                        ),
                        413,
                    )
                return make_response(
                    jsonify({"status": "error", "message": "No valid files to upload"}),
                    400,
                )

            if upload_count == 1 and len(queued) == 1:
                # Single-file uploads keep the flat ``task_id`` response shape.
                current_app.logger.info("Returning single task_id response")
                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "task_id": queued[0]["task_id"],
                            "message": "File uploaded successfully. Processing started.",
                        }
                    ),
                    200,
                )

            response_data = {
                "success": True,
                "tasks": queued,
                "message": f"{len(queued)} file(s) uploaded successfully. Processing started.",
            }
            if failures:
                response_data["errors"] = failures
                response_data["message"] += f" {len(failures)} file(s) failed."
            return make_response(jsonify(response_data), 200)
        except Exception as err:
            current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
            return make_response(
                jsonify({"success": False, "error": "Failed to store attachment"}), 400
            )
|
||||
|
||||
|
||||
@attachments_ns.route("/stt")
class SpeechToText(Resource):
    # Transcribes a single uploaded audio file with the configured STT provider.

    @api.expect(
        api.model(
            "SpeechToTextModel",
            {
                "file": fields.Raw(required=True, description="Audio file"),
                "language": fields.String(
                    required=False, description="Optional transcription language hint"
                ),
            },
        )
    )
    @api.doc(description="Transcribe an uploaded audio file")
    def post(self):
        # Responses:
        #   401  invalid API key, or no credentials at all
        #   400  missing file, unsupported extension or MIME type, or a
        #        failure while saving/transcribing
        #   413  the upload exceeds the audio size limit
        #   200  ``success`` plus the keys of the provider's transcript result
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}),
                401,
            )

        file = request.files.get("file")
        if not file or file.filename == "":
            return make_response(
                jsonify({"success": False, "message": "Missing file"}),
                400,
            )

        filename = safe_filename(os.path.basename(file.filename))
        suffix = Path(filename).suffix.lower()
        if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio format"}),
                400,
            )

        if not _is_supported_audio_mimetype(file.mimetype or ""):
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio MIME type"}),
                400,
            )

        try:
            _enforce_uploaded_audio_size_limit(file, filename)
        except AudioFileTooLargeError:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": build_stt_file_size_limit_message(),
                    }
                ),
                413,
            )

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                # Record the path before writing: the file is created with
                # delete=False, so if saving fails the ``finally`` clause below
                # is the only thing that removes it.
                temp_path = Path(temp_file.name)
                file.save(temp_file.name)

            stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
            transcript = stt_instance.transcribe(
                temp_path,
                language=request.form.get("language") or settings.STT_LANGUAGE,
                timestamps=settings.STT_ENABLE_TIMESTAMPS,
                diarize=settings.STT_ENABLE_DIARIZATION,
            )
            return make_response(jsonify({"success": True, **transcript}), 200)
        except Exception as err:
            current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
            return make_response(
                jsonify({"success": False, "message": "Failed to transcribe audio"}),
                400,
            )
        finally:
            if temp_path and temp_path.exists():
                temp_path.unlink()
|
||||
|
||||
|
||||
@attachments_ns.route("/stt/live/start")
class LiveSpeechToTextStart(Resource):
    # Opens a live transcription session and stores it through Redis.

    @api.doc(description="Start a live speech-to-text session")
    def post(self):
        # Responses:
        #   401  invalid API key, or no credentials at all
        #   503  no Redis client available
        #   200  the new ``session_id``, its ``language`` and empty text fields
        auth_user = _resolve_authenticated_user()
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}),
                401,
            )

        redis_client = _require_live_stt_redis()
        if hasattr(redis_client, "status_code"):
            return redis_client

        # The body is optional; anything unparsable counts as "no options".
        options = request.get_json(silent=True) or {}
        requested_language = options.get("language") or settings.STT_LANGUAGE
        session_state = create_live_stt_session(
            user=auth_user,
            language=requested_language,
        )
        save_live_stt_session(redis_client, session_state)

        # A new session has no transcript yet, so every text field is empty.
        blank_text_fields = (
            "committed_text",
            "mutable_text",
            "previous_hypothesis",
            "latest_hypothesis",
            "finalized_text",
            "pending_text",
            "transcript_text",
        )
        response_body = {
            "success": True,
            "session_id": session_state["session_id"],
            "language": session_state.get("language"),
        }
        for field_name in blank_text_fields:
            response_body[field_name] = ""

        return make_response(jsonify(response_body), 200)
|
||||
|
||||
|
||||
# Endpoint that transcribes one uploaded audio chunk for an existing live
# speech-to-text session and folds the result into the stored session state.
@attachments_ns.route("/stt/live/chunk")
class LiveSpeechToTextChunk(Resource):
    @api.expect(
        api.model(
            "LiveSpeechToTextChunkModel",
            {
                "session_id": fields.String(
                    required=True, description="Live transcription session ID"
                ),
                "chunk_index": fields.Integer(
                    required=True, description="Sequential chunk index"
                ),
                "is_silence": fields.Boolean(
                    required=False,
                    description="Whether the latest capture window was mostly silence",
                ),
                "file": fields.Raw(required=True, description="Audio chunk"),
            },
        )
    )
    @api.doc(description="Transcribe a chunk for a live speech-to-text session")
    def post(self):
        # Validation order is significant: it decides which error a client
        # sees first. Responses: 401 unauthenticated, 400 bad input, 404
        # unknown session, 403 session owned by another user, 413 oversized
        # upload, 409 chunk rejected by the session logic, 200 success.
        auth_user = _resolve_authenticated_user()
        # Helpers signal failure by returning a response object.
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return make_response(
                jsonify({"success": False, "message": "Authentication required"}),
                401,
            )

        redis_client = _require_live_stt_redis()
        if hasattr(redis_client, "status_code"):
            return redis_client

        session_id = request.form.get("session_id", "").strip()
        if not session_id:
            return make_response(
                jsonify({"success": False, "message": "Missing session_id"}),
                400,
            )

        session_state = load_live_stt_session(redis_client, session_id)
        if not session_state:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Live transcription session not found",
                    }
                ),
                404,
            )

        # Only the user recorded on the session may append chunks to it.
        if safe_filename(str(session_state.get("user", ""))) != auth_user:
            return make_response(
                jsonify({"success": False, "message": "Forbidden"}),
                403,
            )

        chunk_index_raw = request.form.get("chunk_index", "").strip()
        if chunk_index_raw == "":
            return make_response(
                jsonify({"success": False, "message": "Missing chunk_index"}),
                400,
            )

        try:
            chunk_index = int(chunk_index_raw)
        except ValueError:
            return make_response(
                jsonify({"success": False, "message": "Invalid chunk_index"}),
                400,
            )
        is_silence = _parse_bool_form_value(request.form.get("is_silence"))

        file = request.files.get("file")
        if not file or file.filename == "":
            return make_response(
                jsonify({"success": False, "message": "Missing file"}),
                400,
            )

        filename = safe_filename(os.path.basename(file.filename))
        suffix = Path(filename).suffix.lower()
        if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio format"}),
                400,
            )

        if not _is_supported_audio_mimetype(file.mimetype or ""):
            return make_response(
                jsonify({"success": False, "message": "Unsupported audio MIME type"}),
                400,
            )

        try:
            _enforce_uploaded_audio_size_limit(file, filename)
        except AudioFileTooLargeError:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": build_stt_file_size_limit_message(),
                    }
                ),
                413,
            )

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                # Record the path BEFORE saving: the file is created with
                # delete=False, so if file.save() raises, the ``finally``
                # below must still know which file to remove.
                temp_path = Path(temp_file.name)
                file.save(temp_file.name)

            session_language = session_state.get("language") or settings.STT_LANGUAGE
            stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
            transcript = stt_instance.transcribe(
                temp_path,
                language=session_language,
                timestamps=False,
                diarize=False,
            )
            # Adopt the language reported by the provider when the session
            # did not have one yet.
            if not session_state.get("language") and transcript.get("language"):
                session_state["language"] = transcript["language"]

            try:
                apply_live_stt_hypothesis(
                    session_state,
                    str(transcript.get("text", "")),
                    chunk_index,
                    is_silence=is_silence,
                )
            except ValueError:
                current_app.logger.warning(
                    "Invalid live transcription chunk",
                    exc_info=True,
                )
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "message": "Invalid live transcription chunk",
                        }
                    ),
                    409,
                )
            save_live_stt_session(redis_client, session_state)

            return make_response(
                jsonify(
                    {
                        "success": True,
                        "session_id": session_id,
                        "chunk_index": chunk_index,
                        "chunk_text": transcript.get("text", ""),
                        "is_silence": is_silence,
                        "language": session_state.get("language"),
                        "committed_text": session_state.get("committed_text", ""),
                        "mutable_text": session_state.get("mutable_text", ""),
                        "previous_hypothesis": session_state.get(
                            "previous_hypothesis", ""
                        ),
                        "latest_hypothesis": session_state.get(
                            "latest_hypothesis", ""
                        ),
                        # finalized_text / pending_text repeat committed_text /
                        # mutable_text under different keys.
                        "finalized_text": session_state.get("committed_text", ""),
                        "pending_text": session_state.get("mutable_text", ""),
                        "transcript_text": get_live_stt_transcript_text(session_state),
                    }
                ),
                200,
            )
        except Exception as err:
            current_app.logger.error(
                f"Error transcribing live audio chunk: {err}", exc_info=True
            )
            return make_response(
                jsonify({"success": False, "message": "Failed to transcribe audio"}),
                400,
            )
        finally:
            # Always remove the temporary audio file, on success or failure.
            if temp_path and temp_path.exists():
                temp_path.unlink()
|
||||
|
||||
|
||||
# Endpoint that closes a live speech-to-text session: it finalizes the
# transcript, deletes the stored session and returns the final text.
@attachments_ns.route("/stt/live/finish")
class LiveSpeechToTextFinish(Resource):
    @api.doc(description="Finish a live speech-to-text session")
    def post(self):
        def fail(message, status):
            # Uniform JSON error envelope used by every rejection below.
            return make_response(
                jsonify({"success": False, "message": message}), status
            )

        auth_user = _resolve_authenticated_user()
        # Helpers signal failure by returning a response object.
        if hasattr(auth_user, "status_code"):
            return auth_user
        if not auth_user:
            return fail("Authentication required", 401)

        redis_client = _require_live_stt_redis()
        if hasattr(redis_client, "status_code"):
            return redis_client

        # A missing or non-JSON body is treated as an empty payload.
        body = request.get_json(silent=True) or {}
        session_id = str(body.get("session_id", "")).strip()
        if not session_id:
            return fail("Missing session_id", 400)

        session_state = load_live_stt_session(redis_client, session_id)
        if not session_state:
            return fail("Live transcription session not found", 404)

        # Only the user recorded on the session may finish it.
        session_owner = safe_filename(str(session_state.get("user", "")))
        if session_owner != auth_user:
            return fail("Forbidden", 403)

        final_text = finalize_live_stt_session(session_state)
        delete_live_stt_session(redis_client, session_id)

        result = {
            "success": True,
            "session_id": session_id,
            "language": session_state.get("language"),
            "text": final_text,
        }
        return make_response(jsonify(result), 200)
|
||||
|
||||
|
||||
# Endpoint that streams a stored image back to the client with a matching
# Content-Type and a one-day cache lifetime.
@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
    @api.doc(description="Serve an image from storage")
    def get(self, image_path):
        try:
            import mimetypes

            from application.api.user.base import storage

            file_obj = storage.get_file(image_path)
            try:
                payload = file_obj.read()
            finally:
                # Release the underlying handle when the storage backend
                # returns a closable object (assumed file-like; confirm per
                # backend).
                close = getattr(file_obj, "close", None)
                if callable(close):
                    close()

            # Prefer the registered MIME type so extensions such as ".svg"
            # map to "image/svg+xml" instead of an invalid "image/svg".
            content_type, _ = mimetypes.guess_type(image_path)
            if not content_type or not content_type.startswith("image/"):
                # Fall back to the previous extension-derived type.
                extension = image_path.split(".")[-1].lower()
                content_type = (
                    "image/jpeg" if extension == "jpg" else f"image/{extension}"
                )

            response = make_response(payload)
            response.headers.set("Content-Type", content_type)
            response.headers.set("Cache-Control", "max-age=86400")

            return response
        except FileNotFoundError:
            return make_response(
                jsonify({"success": False, "message": "Image not found"}), 404
            )
        except Exception as e:
            # Include the traceback, as the other handlers in this file do.
            current_app.logger.error(f"Error serving image: {e}", exc_info=True)
            return make_response(
                jsonify({"success": False, "message": "Error retrieving image"}), 500
            )
|
||||
|
||||
|
||||
# Endpoint that synthesizes speech for the supplied text using the configured
# TTS provider and returns the audio as base64 plus the detected language.
@attachments_ns.route("/tts")
class TextToSpeech(Resource):
    tts_model = api.model(
        "TextToSpeechModel",
        {
            "text": fields.String(
                required=True, description="Text to be synthesized as audio"
            ),
        },
    )

    @api.expect(tts_model)
    @api.doc(description="Synthesize audio speech from text")
    def post(self):
        # Parse defensively: a missing, non-JSON or non-object body must give
        # a clean 400 rather than an unhandled KeyError/TypeError.
        data = request.get_json(silent=True) or {}
        text = data.get("text") if isinstance(data, dict) else None
        if not isinstance(text, str) or not text.strip():
            return make_response(
                jsonify({"success": False, "message": "Missing text"}),
                400,
            )
        try:
            tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
            audio_base64, detected_language = tts_instance.text_to_speech(text)
            return make_response(
                jsonify(
                    {
                        "success": True,
                        "audio_base64": audio_base64,
                        "lang": detected_language,
                    }
                ),
                200,
            )
        except Exception as err:
            current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
|
||||
@@ -1,254 +0,0 @@
|
||||
"""
|
||||
Shared utilities, database connections, and helper functions for user API routes.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, Response
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
# Storage backend instance shared by the user API routes.
storage = StorageCreator.get_storage()

# MongoDB client and the application database selected by settings.MONGO_DB_NAME.
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]

# Collection handles imported by the individual route modules.
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
feedback_collection = db["feedback"]
agents_collection = db["agents"]
agent_folders_collection = db["agent_folders"]
token_usage_collection = db["token_usage"]
shared_conversations_collections = db["shared_conversations"]
users_collection = db["users"]
user_logs_collection = db["user_logs"]
user_tools_collection = db["user_tools"]
attachments_collection = db["attachments"]
workflow_runs_collection = db["workflow_runs"]
workflows_collection = db["workflows"]
workflow_nodes_collection = db["workflow_nodes"]
workflow_edges_collection = db["workflow_edges"]

# Create the indexes at import time. Any failure is only printed, so an index
# problem does not prevent this module from being imported; note that the
# first failing call also skips the remaining ones.
try:
    agents_collection.create_index(
        [("shared", 1)],
        name="shared_index",
        background=True,
    )
    # Enforces one document per user_id.
    users_collection.create_index("user_id", unique=True)
    workflows_collection.create_index(
        [("user", 1)], name="workflow_user_index", background=True
    )
    workflow_nodes_collection.create_index(
        [("workflow_id", 1)], name="node_workflow_index", background=True
    )
    workflow_nodes_collection.create_index(
        [("workflow_id", 1), ("graph_version", 1)],
        name="node_workflow_graph_version_index",
        background=True,
    )
    workflow_edges_collection.create_index(
        [("workflow_id", 1)], name="edge_workflow_index", background=True
    )
    workflow_edges_collection.create_index(
        [("workflow_id", 1), ("graph_version", 1)],
        name="edge_workflow_graph_version_index",
        background=True,
    )
except Exception as e:
    print("Error creating indexes:", e)
# Directory three levels above this file.
current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
|
||||
|
||||
|
||||
def generate_minute_range(start_date, end_date):
    """Build a zero-filled mapping with one key per minute of the interval.

    Keys are formatted as ``"%Y-%m-%d %H:%M:00"`` and step one minute at a
    time from ``start_date``; the number of keys is the count of whole minutes
    between the two datetimes plus one. An ``end_date`` earlier than
    ``start_date`` yields an empty dict.
    """
    whole_minutes = int((end_date - start_date).total_seconds() // 60)
    buckets = {}
    for offset in range(whole_minutes + 1):
        moment = start_date + datetime.timedelta(minutes=offset)
        buckets[moment.strftime("%Y-%m-%d %H:%M:00")] = 0
    return buckets
|
||||
|
||||
|
||||
def generate_hourly_range(start_date, end_date):
    """Build a zero-filled mapping with one key per hour of the interval.

    Keys are formatted as ``"%Y-%m-%d %H:00"`` and step one hour at a time
    from ``start_date``; the number of keys is the count of whole hours
    between the two datetimes plus one. An ``end_date`` earlier than
    ``start_date`` yields an empty dict.
    """
    whole_hours = int((end_date - start_date).total_seconds() // 3600)
    buckets = {}
    for offset in range(whole_hours + 1):
        moment = start_date + datetime.timedelta(hours=offset)
        buckets[moment.strftime("%Y-%m-%d %H:00")] = 0
    return buckets
|
||||
|
||||
|
||||
def generate_date_range(start_date, end_date):
    """Build a zero-filled mapping with one key per day of the interval.

    Keys are formatted as ``"%Y-%m-%d"`` and step one day at a time from
    ``start_date``; the number of keys is the ``days`` component of the
    difference plus one. An ``end_date`` earlier than ``start_date`` yields
    an empty dict.
    """
    day_span = (end_date - start_date).days
    buckets = {}
    for offset in range(day_span + 1):
        day = start_date + datetime.timedelta(days=offset)
        buckets[day.strftime("%Y-%m-%d")] = 0
    return buckets
|
||||
|
||||
|
||||
def ensure_user_doc(user_id):
    """
    Fetch the user's document, creating it and back-filling preferences as needed.

    The document is upserted with empty ``pinned`` and ``shared_with_me``
    lists under ``agent_preferences``. If an existing document lacks either
    key, the missing key is set to an empty list and the document is re-read.

    Args:
        user_id: The user ID to ensure

    Returns:
        The user document
    """
    lookup = {"user_id": user_id}
    initial_prefs = {"pinned": [], "shared_with_me": []}

    # $setOnInsert only applies when the upsert actually creates the document.
    user_doc = users_collection.find_one_and_update(
        lookup,
        {"$setOnInsert": {"agent_preferences": initial_prefs}},
        upsert=True,
        return_document=ReturnDocument.AFTER,
    )

    current_prefs = user_doc.get("agent_preferences", {})
    backfill = {
        f"agent_preferences.{key}": []
        for key in ("pinned", "shared_with_me")
        if key not in current_prefs
    }
    if not backfill:
        return user_doc

    users_collection.update_one({"user_id": user_id}, {"$set": backfill})
    return users_collection.find_one({"user_id": user_id})
|
||||
|
||||
|
||||
def resolve_tool_details(tool_ids):
    """
    Look up the given tool IDs and describe each tool that is found.

    IDs that cannot be converted to ``ObjectId`` are skipped. No query is
    issued when none of the IDs are valid.

    Args:
        tool_ids: List of tool IDs

    Returns:
        List of dicts with ``id``, ``name`` and ``display_name``; the display
        name is the first non-empty of ``customName``, ``displayName``, ``name``.
    """
    object_ids = []
    for raw_id in tool_ids:
        try:
            object_ids.append(ObjectId(raw_id))
        except Exception:
            continue

    if not object_ids:
        return []

    details = []
    for tool in user_tools_collection.find({"_id": {"$in": object_ids}}):
        label = (
            tool.get("customName")
            or tool.get("displayName")
            or tool.get("name", "")
        )
        details.append(
            {
                "id": str(tool["_id"]),
                "name": tool.get("name", ""),
                "display_name": label,
            }
        )
    return details
|
||||
|
||||
|
||||
def get_vector_store(source_id):
    """
    Create the vector store for a source using the configured backend.

    The backend is chosen by ``settings.VECTOR_STORE`` and the embeddings key
    is read from the ``EMBEDDINGS_KEY`` environment variable.

    Args:
        source_id (str): source id of the document

    Returns:
        Vector store instance
    """
    embeddings_key = os.getenv("EMBEDDINGS_KEY")
    return VectorCreator.create_vectorstore(
        settings.VECTOR_STORE,
        source_id=source_id,
        embeddings_key=embeddings_key,
    )
|
||||
|
||||
|
||||
def handle_image_upload(
    request, existing_url: str, user: str, storage, base_path: str = "attachments/"
) -> Tuple[str, Optional[Response]]:
    """
    Store an uploaded ``image`` file, if the request carries one.

    When the request has no ``image`` part, or the part has an empty
    filename, nothing is uploaded and ``existing_url`` is returned unchanged.

    Args:
        request: Flask request object
        existing_url: Existing image URL (fallback)
        user: User ID
        storage: Storage instance
        base_path: Base path for upload

    Returns:
        Tuple of (image_url, error_response). On a failed upload the first
        element is ``None`` and the second is a 400 response.
    """
    if "image" not in request.files:
        return existing_url, None

    uploaded = request.files["image"]
    if uploaded.filename == "":
        return existing_url, None

    # A random prefix keeps uploads with the same name from colliding.
    filename = secure_filename(uploaded.filename)
    upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
    try:
        storage.save_file(uploaded, upload_path, storage_class="STANDARD")
    except Exception as e:
        current_app.logger.error(f"Error uploading image: {e}")
        failure = jsonify({"success": False, "message": "Image upload failed"})
        return None, make_response(failure, 400)
    return upload_path, None
|
||||
|
||||
|
||||
def require_agent(func):
    """
    Decorator that resolves the ``webhook_token`` keyword argument to an agent.

    A missing token yields a 400 response and an unknown token a 404. On
    success the wrapped function receives two extra keyword arguments:
    ``agent`` (the matched document, ``_id`` only) and ``agent_id_str``.

    Args:
        func: Function to decorate

    Returns:
        Wrapped function
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        webhook_token = kwargs.get("webhook_token")
        if not webhook_token:
            missing = jsonify({"success": False, "message": "Webhook token missing"})
            return make_response(missing, 400)

        # Only the _id is projected; nothing else about the agent is needed here.
        agent = agents_collection.find_one(
            {"incoming_webhook_token": webhook_token}, {"_id": 1}
        )
        if agent:
            kwargs["agent"] = agent
            kwargs["agent_id_str"] = str(agent["_id"])
            return func(*args, **kwargs)

        current_app.logger.warning(
            f"Webhook attempt with invalid token: {webhook_token}"
        )
        unknown = jsonify({"success": False, "message": "Agent not found"})
        return make_response(unknown, 404)

    return wrapper
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Conversation management module."""
|
||||
|
||||
from .routes import conversations_ns
|
||||
|
||||
__all__ = ["conversations_ns"]
|
||||
@@ -1,280 +0,0 @@
|
||||
"""Conversation management routes."""
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
# flask-restx namespace for the conversation endpoints, created with path "/api".
conversations_ns = Namespace(
    "conversations", description="Conversation management operations", path="/api"
)
|
||||
|
||||
|
||||
# Endpoint that deletes one conversation owned by the authenticated user.
@conversations_ns.route("/delete_conversation")
class DeleteConversation(Resource):
    @api.doc(
        description="Deletes a conversation by ID",
        params={"id": "The ID of the conversation to delete"},
    )
    def post(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        conversation_id = request.args.get("id")
        if not conversation_id:
            missing_id = jsonify({"success": False, "message": "ID is required"})
            return make_response(missing_id, 400)

        try:
            # Filtering on "user" as well keeps callers from deleting other
            # users' conversations.
            selector = {"_id": ObjectId(conversation_id), "user": token["sub"]}
            conversations_collection.delete_one(selector)
        except Exception as err:
            current_app.logger.error(
                f"Error deleting conversation: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
# Endpoint that deletes every conversation belonging to the authenticated user.
@conversations_ns.route("/delete_all_conversations")
class DeleteAllConversations(Resource):
    @api.doc(
        description="Deletes all conversations for a specific user",
    )
    def get(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        owner = token.get("sub")
        try:
            conversations_collection.delete_many({"user": owner})
        except Exception as err:
            current_app.logger.error(
                f"Error deleting all conversations: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
# Endpoint that lists the caller's 30 most recent conversations.
@conversations_ns.route("/get_conversations")
class GetConversations(Resource):
    @api.doc(
        description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
    )
    def get(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        try:
            # Keep conversations that have no api_key field, or that have an
            # agent_id field, and that belong to the caller.
            selector = {
                "$or": [
                    {"api_key": {"$exists": False}},
                    {"agent_id": {"$exists": True}},
                ],
                "user": token.get("sub"),
            }
            cursor = conversations_collection.find(selector).sort("date", -1).limit(30)

            summaries = []
            for conversation in cursor:
                summaries.append(
                    {
                        "id": str(conversation["_id"]),
                        "name": conversation["name"],
                        "agent_id": conversation.get("agent_id"),
                        "is_shared_usage": conversation.get("is_shared_usage", False),
                        "shared_token": conversation.get("shared_token"),
                    }
                )
        except Exception as err:
            current_app.logger.error(
                f"Error retrieving conversations: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify(summaries), 200)
|
||||
|
||||
|
||||
# Endpoint that returns one of the caller's conversations, with each query's
# attachment IDs expanded into {id, fileName} entries.
@conversations_ns.route("/get_single_conversation")
class GetSingleConversation(Resource):
    @api.doc(
        description="Retrieve a single conversation by ID",
        params={"id": "The conversation ID"},
    )
    def get(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        conversation_id = request.args.get("id")
        if not conversation_id:
            missing_id = jsonify({"success": False, "message": "ID is required"})
            return make_response(missing_id, 400)

        try:
            selector = {"_id": ObjectId(conversation_id), "user": token.get("sub")}
            conversation = conversations_collection.find_one(selector)
            if not conversation:
                return make_response(jsonify({"status": "not found"}), 404)

            queries = conversation["queries"]
            for query in queries:
                attachment_ids = query.get("attachments")
                if not attachment_ids:
                    continue

                resolved = []
                for attachment_id in attachment_ids:
                    # A bad or missing attachment is logged and skipped so the
                    # rest of the conversation can still be returned.
                    try:
                        attachment = attachments_collection.find_one(
                            {"_id": ObjectId(attachment_id)}
                        )
                        if not attachment:
                            continue
                        resolved.append(
                            {
                                "id": str(attachment["_id"]),
                                "fileName": attachment.get(
                                    "filename", "Unknown file"
                                ),
                            }
                        )
                    except Exception as e:
                        current_app.logger.error(
                            f"Error retrieving attachment {attachment_id}: {e}",
                            exc_info=True,
                        )
                query["attachments"] = resolved
        except Exception as err:
            current_app.logger.error(
                f"Error retrieving conversation: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        payload = {
            "queries": queries,
            "agent_id": conversation.get("agent_id"),
            "is_shared_usage": conversation.get("is_shared_usage", False),
            "shared_token": conversation.get("shared_token"),
        }
        return make_response(jsonify(payload), 200)
|
||||
|
||||
|
||||
# Endpoint that renames one of the caller's conversations.
@conversations_ns.route("/update_conversation_name")
class UpdateConversationName(Resource):
    @api.expect(
        api.model(
            "UpdateConversationModel",
            {
                "id": fields.String(required=True, description="Conversation ID"),
                "name": fields.String(
                    required=True, description="New name of the conversation"
                ),
            },
        )
    )
    @api.doc(
        description="Updates the name of a conversation",
    )
    def post(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        data = request.get_json()
        # A truthy result from check_required_fields is returned to the client as-is.
        problem = check_required_fields(data, ["id", "name"])
        if problem:
            return problem

        try:
            selector = {"_id": ObjectId(data["id"]), "user": token.get("sub")}
            change = {"$set": {"name": data["name"]}}
            conversations_collection.update_one(selector, change)
        except Exception as err:
            current_app.logger.error(
                f"Error updating conversation name: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
# Endpoint that stores, or clears, the caller's feedback on one query of a
# conversation they own.
@conversations_ns.route("/feedback")
class SubmitFeedback(Resource):
    @api.expect(
        api.model(
            "FeedbackModel",
            {
                "question": fields.String(
                    required=False, description="The user question"
                ),
                "answer": fields.String(required=False, description="The AI answer"),
                "feedback": fields.String(required=True, description="User feedback"),
                "question_index": fields.Integer(
                    required=True,
                    description="The question number in that particular conversation",
                ),
                "conversation_id": fields.String(
                    required=True, description="id of the particular conversation"
                ),
                "api_key": fields.String(description="Optional API key"),
            },
        )
    )
    @api.doc(
        description="Submit feedback for a conversation",
    )
    def post(self):
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        data = request.get_json()
        required_fields = ["feedback", "conversation_id", "question_index"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields

        # question_index becomes part of a MongoDB field path below, so it
        # must be a plain non-negative integer. A raw client value such as
        # "0.response" would otherwise address arbitrary nested fields.
        raw_index = data["question_index"]
        question_index = None
        if isinstance(raw_index, (int, str)) and not isinstance(raw_index, bool):
            try:
                question_index = int(raw_index)
            except ValueError:
                question_index = None
        if question_index is None or question_index < 0:
            return make_response(
                jsonify({"success": False, "message": "Invalid question_index"}),
                400,
            )

        query_path = f"queries.{question_index}"
        feedback_key = f"{query_path}.feedback"
        timestamp_key = f"{query_path}.feedback_timestamp"
        try:
            if data["feedback"] is None:
                # Remove feedback and feedback_timestamp if feedback is null
                update = {"$unset": {feedback_key: "", timestamp_key: ""}}
            else:
                # Set feedback and feedback_timestamp if feedback has a value
                update = {
                    "$set": {
                        feedback_key: data["feedback"],
                        timestamp_key: datetime.datetime.now(datetime.timezone.utc),
                    }
                }
            conversations_collection.update_one(
                {
                    "_id": ObjectId(data["conversation_id"]),
                    "user": decoded_token.get("sub"),
                    # Only touch a query that actually exists at this index.
                    query_path: {"$exists": True},
                },
                update,
            )
        except Exception as err:
            current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
@@ -1,3 +0,0 @@
|
||||
from .routes import models_ns
|
||||
|
||||
__all__ = ["models_ns"]
|
||||
@@ -1,25 +0,0 @@
|
||||
from flask import current_app, jsonify, make_response
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
# flask-restx namespace for the model-listing endpoint, created with path "/api".
models_ns = Namespace("models", description="Available models", path="/api")
|
||||
|
||||
|
||||
# Endpoint that reports the enabled models, the default model id and a count.
@models_ns.route("/models")
class ModelsListResource(Resource):
    def get(self):
        """Get list of available models with their capabilities."""
        try:
            registry = ModelRegistry.get_instance()
            enabled = registry.get_enabled_models()

            serialized = []
            for model in enabled:
                serialized.append(model.to_dict())
            payload = {
                "models": serialized,
                "default_model_id": registry.default_model_id,
                "count": len(enabled),
            }
        except Exception as err:
            current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
        return make_response(jsonify(payload), 200)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Prompts module."""
|
||||
|
||||
from .routes import prompts_ns
|
||||
|
||||
__all__ = ["prompts_ns"]
|
||||
@@ -1,191 +0,0 @@
|
||||
"""Prompt management routes."""
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
# flask-restx namespace for the prompt endpoints, created with path "/api".
prompts_ns = Namespace(
    "prompts", description="Prompt management operations", path="/api"
)
|
||||
|
||||
|
||||
# Endpoint that stores a new prompt for the authenticated user and returns its id.
@prompts_ns.route("/create_prompt")
class CreatePrompt(Resource):
    create_prompt_model = api.model(
        "CreatePromptModel",
        {
            "content": fields.String(
                required=True, description="Content of the prompt"
            ),
            "name": fields.String(required=True, description="Name of the prompt"),
        },
    )

    @api.expect(create_prompt_model)
    @api.doc(description="Create a new prompt")
    def post(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        data = request.get_json()
        # A truthy result from check_required_fields is returned to the client as-is.
        problem = check_required_fields(data, ["content", "name"])
        if problem:
            return problem

        owner = token.get("sub")
        try:
            document = {
                "name": data["name"],
                "content": data["content"],
                "user": owner,
            }
            inserted = prompts_collection.insert_one(document)
            new_id = str(inserted.inserted_id)
        except Exception as err:
            current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
# Endpoint that lists the three built-in public prompts followed by the
# caller's own private prompts.
@prompts_ns.route("/get_prompts")
class GetPrompts(Resource):
    @api.doc(description="Get all prompts for the user")
    def get(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        owner = token.get("sub")
        try:
            stored_prompts = prompts_collection.find({"user": owner})

            # Built-in prompts always come first and use their name as id.
            listing = [
                {"id": builtin, "name": builtin, "type": "public"}
                for builtin in ("default", "creative", "strict")
            ]
            listing.extend(
                {
                    "id": str(prompt["_id"]),
                    "name": prompt["name"],
                    "type": "private",
                }
                for prompt in stored_prompts
            )
        except Exception as err:
            current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify(listing), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/get_single_prompt")
class GetSinglePrompt(Resource):
    # Built-in prompt IDs mapped to their template files, which live in the
    # "prompts" directory under `current_dir`.
    _BUILTIN_PROMPT_FILES = {
        "default": "chat_combine_default.txt",
        "creative": "chat_combine_creative.txt",
        "strict": "chat_combine_strict.txt",
    }

    @api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
    def get(self):
        # Returns {"content": ...} for a built-in prompt (read from disk) or for
        # a stored prompt that belongs to the caller.
        #   401 - no decoded token on the request
        #   400 - missing "id", malformed ObjectId, or any lookup/read failure
        #   404 - no stored prompt with that ID for this user
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        prompt_id = request.args.get("id")
        if not prompt_id:
            return make_response(
                jsonify({"success": False, "message": "ID is required"}), 400
            )
        try:
            builtin_file = self._BUILTIN_PROMPT_FILES.get(prompt_id)
            if builtin_file is not None:
                with open(
                    os.path.join(current_dir, "prompts", builtin_file), "r"
                ) as f:
                    return make_response(jsonify({"content": f.read()}), 200)
            prompt = prompts_collection.find_one(
                {"_id": ObjectId(prompt_id), "user": user}
            )
        except Exception as err:
            current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        # find_one yields None when nothing matches; subscripting None used to
        # raise outside the try block and surface as an unhandled server error.
        if prompt is None:
            return make_response(
                jsonify({"success": False, "message": "Prompt not found"}), 404
            )
        return make_response(jsonify({"content": prompt["content"]}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/delete_prompt")
class DeletePrompt(Resource):
    # Request-body schema shown in the generated API documentation.
    delete_prompt_model = api.model(
        "DeletePromptModel",
        {"id": fields.String(required=True, description="Prompt ID to delete")},
    )

    @api.expect(delete_prompt_model)
    @api.doc(description="Delete a prompt by ID")
    def post(self):
        # Deletes the prompt that matches both the given ID and the caller.
        # The response is {"success": True} whether or not a document matched.
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        owner = token.get("sub")
        payload = request.get_json()
        # A truthy result from the helper is sent back unchanged as the response.
        validation_error = check_required_fields(payload, ["id"])
        if validation_error:
            return validation_error

        try:
            selector = {"_id": ObjectId(payload["id"]), "user": owner}
            prompts_collection.delete_one(selector)
        except Exception as err:
            current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@prompts_ns.route("/update_prompt")
class UpdatePrompt(Resource):
    # Request-body schema shown in the generated API documentation.
    update_prompt_model = api.model(
        "UpdatePromptModel",
        {
            "id": fields.String(required=True, description="Prompt ID to update"),
            "name": fields.String(required=True, description="New name of the prompt"),
            "content": fields.String(
                required=True, description="New content of the prompt"
            ),
        },
    )

    @api.expect(update_prompt_model)
    @api.doc(description="Update an existing prompt")
    def post(self):
        # Overwrites name and content of the caller's prompt with the given ID.
        # The response is {"success": True} whether or not a document matched.
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        owner = token.get("sub")
        payload = request.get_json()
        # A truthy result from the helper is sent back unchanged as the response.
        validation_error = check_required_fields(payload, ["id", "name", "content"])
        if validation_error:
            return validation_error

        try:
            selector = {"_id": ObjectId(payload["id"]), "user": owner}
            changes = {"$set": {"name": payload["name"], "content": payload["content"]}}
            prompts_collection.update_one(selector, changes)
        except Exception as err:
            current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +0,0 @@
|
||||
"""Sharing module."""
|
||||
|
||||
from .routes import sharing_ns
|
||||
|
||||
__all__ = ["sharing_ns"]
|
||||
@@ -1,304 +0,0 @@
|
||||
"""Conversation sharing routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, inputs, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
attachments_collection,
|
||||
conversations_collection,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.utils import check_required_fields
|
||||
|
||||
# Namespace that groups the conversation-sharing endpoints; mounted under /api.
sharing_ns = Namespace("sharing", path="/api", description="Conversation sharing operations")
|
||||
|
||||
|
||||
@sharing_ns.route("/share")
class ShareConversation(Resource):
    # Request-body schema shown in the generated API documentation.
    share_conversation_model = api.model(
        "ShareConversationModel",
        {
            "conversation_id": fields.String(
                required=True, description="Conversation ID"
            ),
            "user": fields.String(description="User ID (optional)"),
            "prompt_id": fields.String(description="Prompt ID (optional)"),
            "chunks": fields.Integer(description="Chunks count (optional)"),
        },
    )

    @api.expect(share_conversation_model)
    @api.doc(description="Share a conversation")
    def post(self):
        # Creates (or reuses) a share record for a conversation and answers with
        # its UUID: 200 when an identical share already exists, 201 when a new
        # one was inserted. With ?isPromptable=true an agent/API-key document is
        # looked up or created as well, and its key is stored on the share.
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")

        data = request.get_json()
        # A truthy result from the helper is sent back unchanged as the response.
        missing_fields = check_required_fields(data, ["conversation_id"])
        if missing_fields:
            return missing_fields

        is_promptable = request.args.get("isPromptable", type=inputs.boolean)
        if is_promptable is None:
            return make_response(
                jsonify({"success": False, "message": "isPromptable is required"}), 400
            )
        conversation_id = data["conversation_id"]

        def identifier_response(binary_uuid, status):
            # Common success payload: the share UUID rendered as a string.
            return make_response(
                jsonify({"success": True, "identifier": str(binary_uuid.as_uuid())}),
                status,
            )

        try:
            # NOTE(review): this lookup filters on _id only, not on the
            # requesting user - confirm that is intended.
            conversation = conversations_collection.find_one(
                {"_id": ObjectId(conversation_id)}
            )
            if conversation is None:
                return make_response(
                    jsonify(
                        {
                            "status": "error",
                            "message": "Conversation does not exist",
                        }
                    ),
                    404,
                )

            current_n_queries = len(conversation["queries"])
            explicit_binary = Binary.from_uuid(
                uuid.uuid4(), UuidRepresentation.STANDARD
            )

            # Fields used both as the "already shared?" filter and as the body
            # of a newly inserted share record.
            share_fields = {
                "conversation_id": ObjectId(conversation_id),
                "isPromptable": is_promptable,
                "first_n_queries": current_n_queries,
                "user": user,
            }

            if is_promptable:
                prompt_id = data.get("prompt_id", "default")
                chunks = data.get("chunks", "2")
                name = conversation["name"] + "(shared)"

                agent_fields = {
                    "prompt_id": prompt_id,
                    "chunks": chunks,
                    "user": user,
                }
                if "source" in data and ObjectId.is_valid(data["source"]):
                    agent_fields["source"] = DBRef("sources", ObjectId(data["source"]))
                if "retriever" in data:
                    agent_fields["retriever"] = data["retriever"]

                existing_agent = agents_collection.find_one(agent_fields)
                if not existing_agent:
                    # No matching agent: create one with a fresh key, then
                    # record the share without looking for a previous one.
                    api_uuid = str(uuid.uuid4())
                    agent_fields["key"] = api_uuid
                    agent_fields["name"] = name
                    agents_collection.insert_one(agent_fields)
                    shared_conversations_collections.insert_one(
                        {"uuid": explicit_binary, **share_fields, "api_key": api_uuid}
                    )
                    return identifier_response(explicit_binary, 201)

                # Matching agent found: its key becomes part of the share record.
                share_fields["api_key"] = existing_agent["key"]

            pre_existing = shared_conversations_collections.find_one(share_fields)
            if pre_existing is not None:
                return identifier_response(pre_existing["uuid"], 200)

            shared_conversations_collections.insert_one(
                {"uuid": explicit_binary, **share_fields}
            )
            return identifier_response(explicit_binary, 201)
        except Exception as err:
            current_app.logger.error(
                f"Error sharing conversation: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sharing_ns.route("/shared_conversation/<string:identifier>")
class GetPubliclySharedConversations(Resource):
    @api.doc(description="Get publicly shared conversations by identifier")
    def get(self, identifier: str):
        # Resolves a share UUID to its conversation and returns the first
        # `first_n_queries` queries, the title and the creation timestamp.
        # No decoded-token check is performed in this handler.
        not_found = {
            "success": False,
            "error": "might have broken url or the conversation does not exist",
        }
        try:
            query_uuid = Binary.from_uuid(
                uuid.UUID(identifier), UuidRepresentation.STANDARD
            )
            shared = shared_conversations_collections.find_one({"uuid": query_uuid})
            if not (shared and "conversation_id" in shared):
                return make_response(jsonify(not_found), 404)

            # conversation_id may be a DBRef (legacy), an ObjectId, a dict form
            # of a DBRef, or a plain string; normalise before querying.
            conversation_id = shared["conversation_id"]
            if isinstance(conversation_id, DBRef):
                conversation_id = conversation_id.id
            elif isinstance(conversation_id, dict):
                if "$id" in conversation_id:
                    raw_id = conversation_id["$id"]
                    # "$id" is either {"$oid": "..."} or the ID value itself.
                    if isinstance(raw_id, dict) and "$oid" in raw_id:
                        raw_id = raw_id["$oid"]
                    conversation_id = ObjectId(raw_id)
                elif "_id" in conversation_id:
                    conversation_id = ObjectId(conversation_id["_id"])
            elif isinstance(conversation_id, str):
                conversation_id = ObjectId(conversation_id)

            conversation = conversations_collection.find_one(
                {"_id": conversation_id}
            )
            if conversation is None:
                return make_response(jsonify(not_found), 404)

            conversation_queries = conversation["queries"][
                : (shared["first_n_queries"])
            ]

            # Swap each query's attachment IDs for {id, fileName} summaries.
            for query in conversation_queries:
                if not ("attachments" in query and query["attachments"]):
                    continue
                attachment_details = []
                for attachment_id in query["attachments"]:
                    try:
                        attachment = attachments_collection.find_one(
                            {"_id": ObjectId(attachment_id)}
                        )
                        if attachment:
                            attachment_details.append(
                                {
                                    "id": str(attachment["_id"]),
                                    "fileName": attachment.get(
                                        "filename", "Unknown file"
                                    ),
                                }
                            )
                    except Exception as e:
                        # One bad attachment is logged and skipped rather than
                        # failing the whole response.
                        current_app.logger.error(
                            f"Error retrieving attachment {attachment_id}: {e}",
                            exc_info=True,
                        )
                query["attachments"] = attachment_details

            res = {
                "success": True,
                "queries": conversation_queries,
                "title": conversation["name"],
                "timestamp": conversation["_id"].generation_time.isoformat(),
            }
            if shared["isPromptable"] and "api_key" in shared:
                res["api_key"] = shared["api_key"]
            return make_response(jsonify(res), 200)
        except Exception as err:
            current_app.logger.error(
                f"Error getting shared conversation: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Sources module."""
|
||||
|
||||
from .chunks import sources_chunks_ns
|
||||
from .routes import sources_ns
|
||||
from .upload import sources_upload_ns
|
||||
|
||||
__all__ = ["sources_ns", "sources_chunks_ns", "sources_upload_ns"]
|
||||
@@ -1,283 +0,0 @@
|
||||
"""Source document management chunk management."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import get_vector_store, sources_collection
|
||||
from application.utils import check_required_fields, num_tokens_from_string
|
||||
|
||||
# Namespace that groups the chunk-level source endpoints; mounted under /api.
sources_chunks_ns = Namespace(
    "sources", path="/api", description="Source document management operations"
)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/get_chunks")
class GetChunks(Resource):
    @staticmethod
    def _chunk_matches(chunk, path, search_term):
        # True when the chunk passes the optional path and search filters.
        # Missing or null metadata/text/title are treated as empty values.
        metadata = chunk.get("metadata") or {}
        if path:
            # File uploads carry the path in "source" and crawlers in
            # "file_path"; a suffix match on either one is enough.
            locations = (metadata.get("source"), metadata.get("file_path"))
            if not any(isinstance(loc, str) and loc.endswith(path) for loc in locations):
                return False
        if search_term:
            haystacks = (chunk.get("text"), metadata.get("title"))
            if not any(search_term in str(text or "").lower() for text in haystacks):
                return False
        return True

    @api.doc(
        description="Retrieves chunks from a document, optionally filtered by file path and search term",
        params={
            "id": "The document ID",
            "page": "Page number for pagination",
            "per_page": "Number of chunks per page",
            "path": "Optional: Filter chunks by relative file path",
            "search": "Optional: Search term to filter chunks by title or content",
        },
    )
    def get(self):
        # Returns one page of a document's chunks after optional filtering.
        #   401 - no decoded token on the request
        #   400 - invalid document ID, non-integer page/per_page, per_page < 1
        #   404 - document missing or not owned by the caller
        #   500 - failure while reading from the vector store
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        doc_id = request.args.get("id")
        path = request.args.get("path")
        search_term = request.args.get("search", "").strip().lower()

        # Parse pagination defensively: a bad value used to raise here,
        # outside any error handling.
        try:
            page = int(request.args.get("page", 1))
            per_page = int(request.args.get("per_page", 10))
        except (TypeError, ValueError):
            return make_response(
                jsonify({"error": "page and per_page must be integers"}), 400
            )
        if per_page < 1:
            return make_response(jsonify({"error": "per_page must be at least 1"}), 400)
        # Pages below 1 would turn into negative slice offsets.
        page = max(1, page)

        if not ObjectId.is_valid(doc_id):
            return make_response(jsonify({"error": "Invalid doc_id"}), 400)
        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
        if not doc:
            return make_response(
                jsonify({"error": "Document not found or access denied"}), 404
            )
        try:
            store = get_vector_store(doc_id)
            chunks = [
                chunk
                for chunk in store.get_chunks()
                if self._chunk_matches(chunk, path, search_term)
            ]

            total_chunks = len(chunks)
            start = (page - 1) * per_page
            paginated_chunks = chunks[start : start + per_page]

            return make_response(
                jsonify(
                    {
                        "page": page,
                        "per_page": per_page,
                        "total": total_chunks,
                        "chunks": paginated_chunks,
                        "path": path if path else None,
                        "search": search_term if search_term else None,
                    }
                ),
                200,
            )
        except Exception as e:
            current_app.logger.error(f"Error getting chunks: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/add_chunk")
class AddChunk(Resource):
    @api.expect(
        api.model(
            "AddChunkModel",
            {
                "id": fields.String(required=True, description="Document ID"),
                "text": fields.String(required=True, description="Text of the chunk"),
                "metadata": fields.Raw(
                    required=False,
                    description="Metadata associated with the chunk",
                ),
            },
        )
    )
    @api.doc(
        description="Adds a new chunk to the document",
    )
    def post(self):
        # Adds a chunk (text + metadata + computed token_count) to a document
        # owned by the caller.
        #   401 - no decoded token on the request
        #   400 - invalid document ID or metadata that is not a JSON object
        #   404 - document missing or not owned by the caller
        #   500 - tokenisation or vector-store failure
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        required_fields = ["id", "text"]
        # A truthy result from the helper is sent back unchanged as the response.
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
        doc_id = data.get("id")
        text = data.get("text")

        # An explicit null counts as "no metadata"; any other non-object value
        # is rejected instead of crashing on the item assignment below.
        metadata = data.get("metadata") or {}
        if not isinstance(metadata, dict):
            return make_response(jsonify({"error": "metadata must be an object"}), 400)

        if not ObjectId.is_valid(doc_id):
            return make_response(jsonify({"error": "Invalid doc_id"}), 400)
        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
        if not doc:
            return make_response(
                jsonify({"error": "Document not found or access denied"}), 404
            )
        try:
            # Tokenise only after the request passed validation, and build a
            # new dict so the parsed request payload is left untouched.
            metadata = {**metadata, "token_count": num_tokens_from_string(text)}
            store = get_vector_store(doc_id)
            chunk_id = store.add_chunk(text, metadata)
            return make_response(
                jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
                201,
            )
        except Exception as e:
            current_app.logger.error(f"Error adding chunk: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/delete_chunk")
class DeleteChunk(Resource):
    @api.doc(
        description="Deletes a specific chunk from the document.",
        params={"id": "The document ID", "chunk_id": "The ID of the chunk to delete"},
    )
    def delete(self):
        # Deletes one chunk from a document owned by the caller.
        #   401 - no decoded token on the request
        #   400 - invalid document ID or missing chunk_id
        #   404 - document not found / not owned, or the store reports no deletion
        #   500 - vector-store failure
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        doc_id = request.args.get("id")
        chunk_id = request.args.get("chunk_id")

        if not ObjectId.is_valid(doc_id):
            return make_response(jsonify({"error": "Invalid doc_id"}), 400)
        # Reject a missing chunk ID up front instead of handing None to the store.
        if not chunk_id:
            return make_response(jsonify({"error": "chunk_id is required"}), 400)
        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
        if not doc:
            return make_response(
                jsonify({"error": "Document not found or access denied"}), 404
            )
        try:
            store = get_vector_store(doc_id)
            if store.delete_chunk(chunk_id):
                return make_response(
                    jsonify({"message": "Chunk deleted successfully"}), 200
                )
            return make_response(
                jsonify({"message": "Chunk not found or could not be deleted"}),
                404,
            )
        except Exception as e:
            current_app.logger.error(f"Error deleting chunk: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/update_chunk")
class UpdateChunk(Resource):
    @api.expect(
        api.model(
            "UpdateChunkModel",
            {
                "id": fields.String(required=True, description="Document ID"),
                "chunk_id": fields.String(
                    required=True, description="Chunk ID to update"
                ),
                "text": fields.String(
                    required=False, description="New text of the chunk"
                ),
                "metadata": fields.Raw(
                    required=False,
                    description="Updated metadata associated with the chunk",
                ),
            },
        )
    )
    @api.doc(
        description="Updates an existing chunk in the document.",
    )
    def put(self):
        # Replaces a chunk: the new version is added first and the old one is
        # deleted afterwards, so the response carries a new chunk_id.
        #   401 - no decoded token on the request
        #   400 - invalid document ID or metadata that is not a JSON object
        #   404 - document missing / not owned, or chunk not found
        #   500 - vector-store failure
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        required_fields = ["id", "chunk_id"]
        # A truthy result from the helper is sent back unchanged as the response.
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
        doc_id = data.get("id")
        chunk_id = data.get("chunk_id")
        text = data.get("text")
        metadata = data.get("metadata")

        # Non-object metadata used to crash on item assignment before any
        # error handling was in place.
        if metadata is not None and not isinstance(metadata, dict):
            return make_response(jsonify({"error": "metadata must be an object"}), 400)
        if not ObjectId.is_valid(doc_id):
            return make_response(jsonify({"error": "Invalid doc_id"}), 400)
        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
        if not doc:
            return make_response(
                jsonify({"error": "Document not found or access denied"}), 404
            )
        try:
            store = get_vector_store(doc_id)
            existing_chunk = next(
                (c for c in store.get_chunks() if c.get("doc_id") == chunk_id), None
            )
            if not existing_chunk:
                return make_response(jsonify({"error": "Chunk not found"}), 404)

            new_text = text if text is not None else existing_chunk["text"]

            # Start from a copy of the stored metadata and overlay the caller's.
            new_metadata = dict(existing_chunk.get("metadata") or {})
            if metadata:
                new_metadata.update(metadata)
            # The token count is derived from the text, so it is recomputed
            # exactly once and only when the text changes.
            if text is not None:
                new_metadata["token_count"] = num_tokens_from_string(new_text)

            try:
                new_chunk_id = store.add_chunk(new_text, new_metadata)

                deleted = store.delete_chunk(chunk_id)
                if not deleted:
                    # Both versions now exist; the update is still reported as
                    # successful, matching the previous behaviour.
                    current_app.logger.warning(
                        f"Failed to delete old chunk {chunk_id}, but new chunk {new_chunk_id} was created"
                    )
                return make_response(
                    jsonify(
                        {
                            "message": "Chunk updated successfully",
                            "chunk_id": new_chunk_id,
                            "original_chunk_id": chunk_id,
                        }
                    ),
                    200,
                )
            except Exception as add_error:
                current_app.logger.error(
                    f"Failed to add updated chunk: {add_error}", exc_info=True
                )
                return make_response(
                    jsonify({"error": "Failed to update chunk - addition failed"}), 500
                )
        except Exception as e:
            current_app.logger.error(f"Error updating chunk: {e}", exc_info=True)
            return make_response(jsonify({"success": False}), 500)
|
||||
@@ -1,409 +0,0 @@
|
||||
"""Source document management routes."""
|
||||
|
||||
import json
|
||||
import math
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.core.settings import settings
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
|
||||
# Namespace that groups the source-document endpoints; mounted under /api.
sources_ns = Namespace(
    "sources", path="/api", description="Source document management operations"
)
|
||||
|
||||
|
||||
def _get_provider_from_remote_data(remote_data):
|
||||
if not remote_data:
|
||||
return None
|
||||
if isinstance(remote_data, dict):
|
||||
return remote_data.get("provider")
|
||||
if isinstance(remote_data, str):
|
||||
try:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
except Exception:
|
||||
return None
|
||||
if isinstance(remote_data_obj, dict):
|
||||
return remote_data_obj.get("provider")
|
||||
return None
|
||||
|
||||
|
||||
@sources_ns.route("/sources")
class CombinedJson(Resource):
    @api.doc(description="Provide JSON file with combined available indexes")
    def get(self):
        # Returns a fixed "Default" entry followed by the caller's sources,
        # newest first (sorted on "date" descending).
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")

        data = [
            {
                "name": "Default",
                "date": "default",
                "model": settings.EMBEDDINGS_NAME,
                "location": "remote",
                "tokens": "",
                "retriever": "classic",
            }
        ]

        try:
            cursor = sources_collection.find({"user": user}).sort("date", -1)
            # NOTE(review): this endpoint emits "is_nested" while
            # /sources/paginated emits "isNested" - confirm which key clients
            # rely on before unifying.
            data.extend(
                {
                    "id": str(index["_id"]),
                    "name": index.get("name"),
                    "date": index.get("date"),
                    "model": settings.EMBEDDINGS_NAME,
                    "location": "local",
                    "tokens": index.get("tokens", ""),
                    "retriever": index.get("retriever", "classic"),
                    "syncFrequency": index.get("sync_frequency", ""),
                    "provider": _get_provider_from_remote_data(
                        index.get("remote_data")
                    ),
                    "is_nested": bool(index.get("directory_structure")),
                    # Sources stored without a type are reported as "file".
                    "type": index.get("type", "file"),
                }
                for index in cursor
            )
        except Exception as err:
            current_app.logger.error(f"Error retrieving sources: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify(data), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sources/paginated")
class PaginatedSources(Resource):
    @api.doc(description="Get document with pagination, sorting and filtering")
    def get(self):
        # Returns one page of the caller's sources with totals.
        # Query parameters: sort (default "date"), order ("asc"; anything else
        # means descending), page (default 1, clamped to the valid range),
        # rows (default 10, must be >= 1), search (case-insensitive substring
        # of the source name).
        #   401 - no decoded token on the request
        #   400 - non-integer page/rows, rows < 1, or a database failure
        import re  # function-scope import; only needed to escape the search term

        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")

        sort_field = request.args.get("sort", "date")
        sort_order = 1 if request.args.get("order", "desc") == "asc" else -1
        search_term = request.args.get("search", "").strip()

        # Parse pagination defensively: a bad value used to raise here, and
        # rows=0 caused a division by zero further down.
        try:
            page = int(request.args.get("page", 1))
            rows_per_page = int(request.args.get("rows", 10))
        except (TypeError, ValueError):
            return make_response(
                jsonify({"success": False, "message": "page and rows must be integers"}),
                400,
            )
        if rows_per_page < 1:
            return make_response(
                jsonify({"success": False, "message": "rows must be at least 1"}), 400
            )

        query = {"user": user}
        if search_term:
            # Escape the term so it is matched literally; raw user input as a
            # pattern can be an invalid or deliberately expensive regex.
            query["name"] = {
                "$regex": re.escape(search_term),
                "$options": "i",  # case-insensitive
            }

        try:
            total_documents = sources_collection.count_documents(query)
            total_pages = max(1, math.ceil(total_documents / rows_per_page))
            # Keep the requested page inside [1, total_pages].
            page = min(max(1, page), total_pages)
            skip = (page - 1) * rows_per_page

            documents = (
                sources_collection.find(query)
                .sort(sort_field, sort_order)
                .skip(skip)
                .limit(rows_per_page)
            )

            paginated_docs = [
                {
                    "id": str(doc["_id"]),
                    "name": doc.get("name", ""),
                    "date": doc.get("date", ""),
                    "model": settings.EMBEDDINGS_NAME,
                    "location": "local",
                    "tokens": doc.get("tokens", ""),
                    "retriever": doc.get("retriever", "classic"),
                    "syncFrequency": doc.get("sync_frequency", ""),
                    "provider": _get_provider_from_remote_data(doc.get("remote_data")),
                    "isNested": bool(doc.get("directory_structure")),
                    "type": doc.get("type", "file"),
                }
                for doc in documents
            ]
            response = {
                "total": total_documents,
                "totalPages": total_pages,
                "currentPage": page,
                "paginated": paginated_docs,
            }
            return make_response(jsonify(response), 200)
        except Exception as err:
            current_app.logger.error(
                f"Error retrieving paginated sources: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_by_ids")
class DeleteByIds(Resource):
    @api.doc(
        description="Deletes documents from the vector store by IDs",
        params={"path": "Comma-separated list of IDs"},
    )
    def get(self):
        # Deletes index entries named by the "path" query parameter.
        #   401 - no decoded token on the request
        #   400 - missing parameter, falsy result, or any failure
        # The token check mirrors the other endpoints in this module; without
        # it this destructive route was reachable with no token at all.
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        ids = request.args.get("path")
        if not ids:
            return make_response(
                jsonify({"success": False, "message": "Missing required fields"}), 400
            )
        try:
            # NOTE(review): every other use of sources_collection in this module
            # goes through find/find_one/count_documents/delete_one; confirm
            # that `delete_index` exists on this object, otherwise this call
            # always lands in the except branch. The deletion is also not
            # scoped to the requesting user - confirm intended.
            result = sources_collection.delete_index(ids=ids)
            if result:
                return make_response(jsonify({"success": True}), 200)
        except Exception as err:
            current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_old")
class DeleteOldIndexes(Resource):
    @api.doc(
        description="Deletes old indexes and associated files",
        params={"source_id": "The source ID to delete"},
    )
    def get(self):
        # Removes a source owned by the caller: its vector index, its stored
        # file(s), and finally the source document itself.
        #   401 - no decoded token on the request
        #   400 - missing/malformed source_id or a failure during cleanup
        #   404 - source missing or not owned by the caller
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        source_id = request.args.get("source_id")
        if not source_id:
            return make_response(
                jsonify({"success": False, "message": "Missing required fields"}), 400
            )
        # A malformed ID used to raise from ObjectId() before any error
        # handling; validate it the way the chunk endpoints do.
        if not ObjectId.is_valid(source_id):
            return make_response(
                jsonify({"success": False, "message": "Invalid source_id"}), 400
            )
        doc = sources_collection.find_one(
            {"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
        )
        if not doc:
            return make_response(jsonify({"status": "not found"}), 404)
        storage = StorageCreator.get_storage()

        try:
            # Delete the vector index.
            if settings.VECTOR_STORE == "faiss":
                index_path = f"indexes/{str(doc['_id'])}"
                for index_file in ("index.faiss", "index.pkl"):
                    if storage.file_exists(f"{index_path}/{index_file}"):
                        storage.delete_file(f"{index_path}/{index_file}")
            else:
                vectorstore = VectorCreator.create_vectorstore(
                    settings.VECTOR_STORE, source_id=str(doc["_id"])
                )
                vectorstore.delete_index()

            # Delete the uploaded file, or every file the storage lists under
            # the path when it is a directory.
            if "file_path" in doc and doc["file_path"]:
                file_path = doc["file_path"]
                if storage.is_directory(file_path):
                    for stored_file in storage.list_files(file_path):
                        storage.delete_file(stored_file)
                else:
                    storage.delete_file(file_path)
        except FileNotFoundError:
            # Files that are already gone are not an error; carry on and
            # remove the database record.
            pass
        except Exception as err:
            current_app.logger.error(
                f"Error deleting files and indexes: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        sources_collection.delete_one({"_id": ObjectId(source_id)})
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/combine")
class RedirectToSources(Resource):
    """Legacy alias kept so clients still calling /api/combine keep working."""

    @api.doc(
        description="Redirects /api/combine to /api/sources for backward compatibility"
    )
    def get(self):
        """Answer with a permanent (301) redirect to /api/sources."""
        target = "/api/sources"
        return redirect(target, code=301)
|
||||
|
||||
|
||||
@sources_ns.route("/manage_sync")
class ManageSync(Resource):
    """Endpoint for changing the sync frequency stored on a source."""

    manage_sync_model = api.model(
        "ManageSyncModel",
        {
            "source_id": fields.String(required=True, description="Source ID"),
            "sync_frequency": fields.String(
                required=True,
                description="Sync frequency (never, daily, weekly, monthly)",
            ),
        },
    )

    @api.expect(manage_sync_model)
    @api.doc(description="Manage sync frequency for sources")
    def post(self):
        """Set ``sync_frequency`` on the caller's source.

        Expects a JSON body with ``source_id`` and ``sync_frequency``.
        Returns 401 without a token, 400 for an invalid frequency or a failed
        update, and 200 otherwise.
        """
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        owner = token.get("sub")
        payload = request.get_json() or {}
        missing = check_required_fields(payload, ["source_id", "sync_frequency"])
        if missing:
            return missing

        source_id = payload["source_id"]
        frequency = payload["sync_frequency"]
        if frequency not in ("never", "daily", "weekly", "monthly"):
            error_body = {"success": False, "message": "Invalid frequency"}
            return make_response(jsonify(error_body), 400)

        try:
            # The owner is part of the filter so only the caller's own source
            # can be modified.
            sources_collection.update_one(
                {"_id": ObjectId(source_id), "user": owner},
                {"$set": {"sync_frequency": frequency}},
            )
        except Exception as err:
            current_app.logger.error(
                f"Error updating sync frequency: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/sync_source")
class SyncSource(Resource):
    """Endpoint that queues an immediate sync of a remote-backed source."""

    sync_source_model = api.model(
        "SyncSourceModel",
        {"source_id": fields.String(required=True, description="Source ID")},
    )

    @api.expect(sync_source_model)
    @api.doc(description="Trigger an immediate sync for a source")
    def post(self):
        """Queue a sync task for the caller's source.

        Expects a JSON body with ``source_id``.

        Returns:
            401 without a token; 400 for a missing/invalid ID, a connector
            source, a source without ``remote_data``, or a failure to queue
            the task; 404 when the source is not found for this user; 200 with
            the queued ``task_id`` on success.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        # Fall back to an empty dict so a null JSON body is reported as missing
        # fields instead of breaking the required-field check.
        data = request.get_json() or {}
        required_fields = ["source_id"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
        source_id = data["source_id"]
        if not ObjectId.is_valid(source_id):
            return make_response(
                jsonify({"success": False, "message": "Invalid source ID"}), 400
            )
        doc = sources_collection.find_one(
            {"_id": ObjectId(source_id), "user": user}
        )
        if not doc:
            return make_response(
                jsonify({"success": False, "message": "Source not found"}), 404
            )
        # A stored type of None would otherwise crash .startswith below.
        source_type = doc.get("type") or ""
        if source_type.startswith("connector"):
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "Connector sources must be synced via /api/connectors/sync",
                    }
                ),
                400,
            )
        source_data = doc.get("remote_data")
        if not source_data:
            return make_response(
                jsonify({"success": False, "message": "Source is not syncable"}), 400
            )
        try:
            task = sync_source.delay(
                source_data=source_data,
                job_name=doc.get("name", ""),
                user=user,
                loader=source_type,
                sync_frequency=doc.get("sync_frequency", "never"),
                retriever=doc.get("retriever", "classic"),
                doc_id=source_id,
            )
        except Exception as err:
            current_app.logger.error(
                f"Error starting sync for source {source_id}: {err}",
                exc_info=True,
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
@sources_ns.route("/directory_structure")
class DirectoryStructure(Resource):
    """Endpoint exposing the stored directory tree of a source document."""

    @api.doc(
        description="Get the directory structure for a document",
        params={"id": "The document ID"},
    )
    def get(self):
        """Return directory structure, base path and provider for a document.

        The document is looked up by the ``id`` query parameter and must
        belong to the authenticated user. ``provider`` is read from the
        document's ``remote_data`` when that field is a non-empty JSON string;
        a parsing problem is logged as a warning and leaves it as ``None``.
        """
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)

        owner = token.get("sub")
        doc_id = request.args.get("id")
        if not doc_id:
            return make_response(jsonify({"error": "Document ID is required"}), 400)
        if not ObjectId.is_valid(doc_id):
            return make_response(jsonify({"error": "Invalid document ID"}), 400)

        try:
            source_doc = sources_collection.find_one(
                {"_id": ObjectId(doc_id), "user": owner}
            )
            if not source_doc:
                not_found = {"error": "Document not found or access denied"}
                return make_response(jsonify(not_found), 404)

            provider = None
            raw_remote = source_doc.get("remote_data")
            try:
                if isinstance(raw_remote, str) and raw_remote:
                    provider = json.loads(raw_remote).get("provider")
            except Exception as e:
                # Provider is optional metadata; a bad payload must not fail
                # the whole request.
                current_app.logger.warning(
                    f"Failed to parse remote_data for doc {doc_id}: {e}"
                )

            body = {
                "success": True,
                "directory_structure": source_doc.get("directory_structure", {}),
                "base_path": source_doc.get("file_path", ""),
                "provider": provider,
            }
            return make_response(jsonify(body), 200)
        except Exception as e:
            current_app.logger.error(
                f"Error retrieving directory structure: {e}", exc_info=True
            )
            failure = {"success": False, "error": "Failed to retrieve directory structure"}
            return make_response(jsonify(failure), 500)
|
||||
@@ -1,655 +0,0 @@
|
||||
"""Source document management upload functionality."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.stt.upload_limits import (
|
||||
AudioFileTooLargeError,
|
||||
build_stt_file_size_limit_message,
|
||||
enforce_audio_file_size_limit,
|
||||
is_audio_filename,
|
||||
)
|
||||
from application.utils import check_required_fields, safe_filename
|
||||
|
||||
|
||||
sources_upload_ns = Namespace(
|
||||
"sources", description="Source document management operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
    """Apply the audio upload size limit to a file already written to disk.

    Files whose *filename* is not recognised by ``is_audio_filename`` are left
    unchecked; otherwise the on-disk size of *file_path* is passed to
    ``enforce_audio_file_size_limit``.
    """
    if is_audio_filename(filename):
        size_in_bytes = os.path.getsize(file_path)
        enforce_audio_file_size_limit(size_in_bytes)
|
||||
|
||||
|
||||
@sources_upload_ns.route("/upload")
class UploadFile(Resource):
    """Endpoint that stores uploaded files and queues them for ingestion."""

    @api.expect(
        api.model(
            "UploadModel",
            {
                "user": fields.String(required=True, description="User ID"),
                "name": fields.String(required=True, description="Job name"),
                "file": fields.Raw(required=True, description="File(s) to upload"),
            },
        )
    )
    @api.doc(
        description="Uploads a file to be vectorized and indexed",
    )
    def post(self):
        """Save the uploaded file(s) to storage and start an ingest task.

        Real ``.zip`` archives are extracted and their members stored
        individually; if extraction fails the archive itself is stored.
        Office/EPUB formats are zip containers but are stored as-is.

        Returns:
            401 without a token, 400 for missing fields/files or an unexpected
            error, 413 when an audio file (uploaded directly or found inside
            an archive) exceeds the size limit, 200 with ``task_id`` on
            success.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        data = request.form
        files = request.files.getlist("file")
        required_fields = ["user", "name"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields or not files or all(file.filename == "" for file in files):
            return make_response(
                jsonify(
                    {
                        "status": "error",
                        "message": "Missing required fields or files",
                    }
                ),
                400,
            )
        user = decoded_token.get("sub")
        job_name = request.form["name"]

        # Create safe versions for filesystem operations
        safe_user = safe_filename(user)
        dir_name = safe_filename(job_name)
        base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
        file_name_map = {}

        try:
            storage = StorageCreator.get_storage()

            for file in files:
                original_filename = os.path.basename(file.filename)
                if not original_filename:
                    # Empty form parts carry no file; saving them would try to
                    # write onto the temp directory path itself.
                    continue
                safe_file = safe_filename(original_filename)
                file_name_map[safe_file] = original_filename

                with tempfile.TemporaryDirectory() as temp_dir:
                    temp_file_path = os.path.join(temp_dir, safe_file)
                    file.save(temp_file_path)
                    _enforce_audio_path_size_limit(temp_file_path, safe_file)

                    # Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
                    # which are technically zip archives but should be processed as-is
                    is_office_format = safe_file.lower().endswith(
                        (".docx", ".xlsx", ".pptx", ".odt", ".ods", ".odp", ".epub")
                    )
                    if zipfile.is_zipfile(temp_file_path) and not is_office_format:
                        try:
                            with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
                                zip_ref.extractall(path=temp_dir)

                            # Walk through extracted files and upload them
                            for root, _, extracted_names in os.walk(temp_dir):
                                for extracted_file in extracted_names:
                                    extracted_path = os.path.join(root, extracted_file)
                                    if extracted_path == temp_file_path:
                                        # Skip the archive we just unpacked.
                                        continue
                                    rel_path = os.path.relpath(extracted_path, temp_dir)
                                    storage_path = f"{base_path}/{rel_path}"
                                    _enforce_audio_path_size_limit(
                                        extracted_path, extracted_file
                                    )

                                    with open(extracted_path, "rb") as f:
                                        storage.save_file(f, storage_path)
                        except AudioFileTooLargeError:
                            # A size-limit violation is not an extraction
                            # failure: let it reach the 413 handler below
                            # instead of being swallowed by the fallback.
                            raise
                        except Exception as e:
                            current_app.logger.error(
                                f"Error extracting zip: {e}", exc_info=True
                            )
                            # If zip extraction fails, save the original zip file
                            file_path = f"{base_path}/{safe_file}"
                            with open(temp_file_path, "rb") as f:
                                storage.save_file(f, file_path)
                    else:
                        # For non-zip files, save directly
                        file_path = f"{base_path}/{safe_file}"
                        with open(temp_file_path, "rb") as f:
                            storage.save_file(f, file_path)

            if not file_name_map:
                # Every part had an unusable filename; nothing was stored.
                return make_response(
                    jsonify(
                        {
                            "status": "error",
                            "message": "Missing required fields or files",
                        }
                    ),
                    400,
                )

            task = ingest.delay(
                settings.UPLOAD_FOLDER,
                list(SUPPORTED_SOURCE_EXTENSIONS),
                job_name,
                user,
                file_path=base_path,
                filename=dir_name,
                file_name_map=file_name_map,
            )
        except AudioFileTooLargeError:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": build_stt_file_size_limit_message(),
                    }
                ),
                413,
            )
        except Exception as err:
            current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
def _normalize_id_list(raw_ids):
    """Coerce a connector ID field into a list.

    A comma-separated string is split and stripped (empty items dropped), a
    list is returned unchanged, and any other value becomes an empty list.
    """
    if isinstance(raw_ids, str):
        return [part.strip() for part in raw_ids.split(",") if part.strip()]
    if isinstance(raw_ids, list):
        return raw_ids
    return []


@sources_upload_ns.route("/remote")
class UploadRemote(Resource):
    """Endpoint that queues ingestion of a remote or connector-backed source."""

    @api.expect(
        api.model(
            "RemoteUploadModel",
            {
                "user": fields.String(required=True, description="User ID"),
                "source": fields.String(
                    required=True, description="Source of the data"
                ),
                "name": fields.String(required=True, description="Job name"),
                "data": fields.String(required=True, description="Data to process"),
                "repo_url": fields.String(description="GitHub repository URL"),
            },
        )
    )
    @api.doc(
        description="Uploads remote source for vectorization",
    )
    def post(self):
        """Start an ingest task for the remote source described in the form.

        The ``data`` form field is a JSON document whose interpretation
        depends on ``source``: a repo URL for ``github``, a URL for
        ``crawler``/``url``, the whole config for ``reddit``/``s3``, and a
        session token plus file/folder IDs for supported connectors.

        Returns:
            401 without a token, 400 for missing fields, a connector config
            without ``session_token``, or any processing error; 200 with the
            queued ``task_id`` on success.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        data = request.form
        required_fields = ["user", "source", "name", "data"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
        try:
            config = json.loads(data["data"])
            source = data["source"]
            source_data = None

            if source == "github":
                source_data = config.get("repo_url")
            elif source in ["crawler", "url"]:
                source_data = config.get("url")
            elif source in ["reddit", "s3"]:
                source_data = config
            elif source in ConnectorCreator.get_supported_connectors():
                session_token = config.get("session_token")
                if not session_token:
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "error": f"Missing session_token in {data['source']} configuration",
                            }
                        ),
                        400,
                    )
                # Accept both list and comma-separated string forms of the IDs
                file_ids = _normalize_id_list(config.get("file_ids", []))
                folder_ids = _normalize_id_list(config.get("folder_ids", []))
                config["file_ids"] = file_ids
                config["folder_ids"] = folder_ids

                task = ingest_connector_task.delay(
                    job_name=data["name"],
                    user=decoded_token.get("sub"),
                    source_type=source,
                    session_token=session_token,
                    file_ids=file_ids,
                    folder_ids=folder_ids,
                    recursive=config.get("recursive", False),
                    retriever=config.get("retriever", "classic"),
                )
                return make_response(
                    jsonify({"success": True, "task_id": task.id}), 200
                )
            task = ingest_remote.delay(
                source_data=source_data,
                job_name=data["name"],
                user=decoded_token.get("sub"),
                loader=source,
            )
        except Exception as err:
            current_app.logger.error(
                f"Error uploading remote source: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||
|
||||
|
||||
def _is_unsafe_relative_path(path):
    """Return True when *path* is absolute or contains a ``..`` sequence."""
    return path.startswith("/") or ".." in path


@sources_upload_ns.route("/manage_source_files")
class ManageSourceFiles(Resource):
    """Endpoint for adding/removing files or directories of an existing source."""

    @api.expect(
        api.model(
            "ManageSourceFilesModel",
            {
                "source_id": fields.String(
                    required=True, description="Source ID to modify"
                ),
                "operation": fields.String(
                    required=True,
                    description="Operation: 'add', 'remove', or 'remove_directory'",
                ),
                "file_paths": fields.List(
                    fields.String,
                    required=False,
                    description="File paths to remove (for remove operation)",
                ),
                "directory_path": fields.String(
                    required=False,
                    description="Directory path to remove (for remove_directory operation)",
                ),
                "file": fields.Raw(
                    required=False, description="Files to add (for add operation)"
                ),
                "parent_dir": fields.String(
                    required=False,
                    description="Parent directory path relative to source root",
                ),
            },
        )
    )
    @api.doc(
        description="Add files, remove files, or remove directories from an existing source",
    )
    def post(self):
        """Apply an ``add``, ``remove`` or ``remove_directory`` operation.

        The source must belong to the authenticated user. After a successful
        storage change the source's ``file_name_map`` is updated and a
        re-ingestion task is queued.

        Returns:
            401 without a token; 400 for missing/invalid parameters (including
            client paths that are absolute or contain ``..``); 404 when the
            source or directory is not found; 500 on a database/storage
            failure; 200 with operation details and ``reingest_task_id`` on
            success.
        """
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(
                jsonify({"success": False, "message": "Unauthorized"}), 401
            )
        user = decoded_token.get("sub")
        source_id = request.form.get("source_id")
        operation = request.form.get("operation")

        if not source_id or not operation:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "source_id and operation are required",
                    }
                ),
                400,
            )
        if operation not in ["add", "remove", "remove_directory"]:
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "message": "operation must be 'add', 'remove', or 'remove_directory'",
                    }
                ),
                400,
            )
        if not ObjectId.is_valid(source_id):
            return make_response(
                jsonify({"success": False, "message": "Invalid source ID format"}), 400
            )
        try:
            source = sources_collection.find_one(
                {"_id": ObjectId(source_id), "user": user}
            )
            if not source:
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "message": "Source not found or access denied",
                        }
                    ),
                    404,
                )
        except Exception as err:
            current_app.logger.error(f"Error finding source: {err}", exc_info=True)
            return make_response(
                jsonify({"success": False, "message": "Database error"}), 500
            )
        try:
            # Imported here rather than at module top, as in the original code.
            from application.api.user.tasks import reingest_source_task

            storage = StorageCreator.get_storage()
            source_file_path = source.get("file_path", "")
            parent_dir = request.form.get("parent_dir", "")

            # file_name_map may be stored as a dict or as a JSON string;
            # anything unusable is treated as an empty map.
            file_name_map = source.get("file_name_map") or {}
            if isinstance(file_name_map, str):
                try:
                    file_name_map = json.loads(file_name_map)
                except Exception:
                    file_name_map = {}
            if not isinstance(file_name_map, dict):
                file_name_map = {}

            if parent_dir and _is_unsafe_relative_path(parent_dir):
                return make_response(
                    jsonify(
                        {"success": False, "message": "Invalid parent directory path"}
                    ),
                    400,
                )

            if operation == "add":
                files = request.files.getlist("file")
                if not files or all(file.filename == "" for file in files):
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": "No files provided for add operation",
                            }
                        ),
                        400,
                    )
                added_files = []
                map_updated = False

                target_dir = source_file_path
                if parent_dir:
                    target_dir = f"{source_file_path}/{parent_dir}"
                for file in files:
                    if not file.filename:
                        continue
                    original_filename = os.path.basename(file.filename)
                    safe_filename_str = safe_filename(original_filename)
                    file_path = f"{target_dir}/{safe_filename_str}"

                    # Save file to storage
                    storage.save_file(file, file_path)
                    added_files.append(safe_filename_str)
                    if original_filename:
                        relative_key = (
                            f"{parent_dir}/{safe_filename_str}"
                            if parent_dir
                            else safe_filename_str
                        )
                        file_name_map[relative_key] = original_filename
                        map_updated = True

                if map_updated:
                    sources_collection.update_one(
                        {"_id": ObjectId(source_id)},
                        {"$set": {"file_name_map": file_name_map}},
                    )

                # Trigger re-ingestion pipeline
                task = reingest_source_task.delay(source_id=source_id, user=user)

                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "message": f"Added {len(added_files)} files",
                            "added_files": added_files,
                            "parent_dir": parent_dir,
                            "reingest_task_id": task.id,
                        }
                    ),
                    200,
                )

            if operation == "remove":
                file_paths_str = request.form.get("file_paths")
                if not file_paths_str:
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": "file_paths required for remove operation",
                            }
                        ),
                        400,
                    )
                try:
                    file_paths = json.loads(file_paths_str)
                except Exception:
                    file_paths = None
                # The payload must decode to a list of strings; anything else
                # (a bare string, a number, nested objects) is rejected.
                if not isinstance(file_paths, list) or not all(
                    isinstance(path, str) for path in file_paths
                ):
                    return make_response(
                        jsonify(
                            {"success": False, "message": "Invalid file_paths format"}
                        ),
                        400,
                    )
                # Validate file paths (prevent path traversal): these values
                # are joined onto the source root and then deleted.
                if any(_is_unsafe_relative_path(path) for path in file_paths):
                    current_app.logger.warning(
                        f"Invalid file path attempted for removal. "
                        f"User: {user}, Source ID: {source_id}, File paths: {file_paths}"
                    )
                    return make_response(
                        jsonify({"success": False, "message": "Invalid file path"}),
                        400,
                    )

                # Remove files from storage and directory structure
                removed_files = []
                map_updated = False
                for file_path in file_paths:
                    full_path = f"{source_file_path}/{file_path}"

                    # Remove from storage
                    if storage.file_exists(full_path):
                        storage.delete_file(full_path)
                        removed_files.append(file_path)
                        if file_path in file_name_map:
                            file_name_map.pop(file_path, None)
                            map_updated = True

                if map_updated:
                    sources_collection.update_one(
                        {"_id": ObjectId(source_id)},
                        {"$set": {"file_name_map": file_name_map}},
                    )

                # Trigger re-ingestion pipeline
                task = reingest_source_task.delay(source_id=source_id, user=user)

                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "message": f"Removed {len(removed_files)} files",
                            "removed_files": removed_files,
                            "reingest_task_id": task.id,
                        }
                    ),
                    200,
                )

            # Only "remove_directory" remains: operation was validated above.
            directory_path = request.form.get("directory_path")
            if not directory_path:
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "message": "directory_path required for remove_directory operation",
                        }
                    ),
                    400,
                )

            # Validate directory path (prevent path traversal)
            if _is_unsafe_relative_path(directory_path):
                current_app.logger.warning(
                    f"Invalid directory path attempted for removal. "
                    f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}"
                )
                return make_response(
                    jsonify({"success": False, "message": "Invalid directory path"}),
                    400,
                )
            full_directory_path = f"{source_file_path}/{directory_path}"

            if not storage.is_directory(full_directory_path):
                current_app.logger.warning(
                    f"Directory not found or is not a directory for removal. "
                    f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
                    f"Full path: {full_directory_path}"
                )
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "message": "Directory not found or is not a directory",
                        }
                    ),
                    404,
                )

            if not storage.remove_directory(full_directory_path):
                current_app.logger.error(
                    f"Failed to remove directory from storage. "
                    f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
                    f"Full path: {full_directory_path}"
                )
                return make_response(
                    jsonify(
                        {"success": False, "message": "Failed to remove directory"}
                    ),
                    500,
                )
            current_app.logger.info(
                f"Successfully removed directory. "
                f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
                f"Full path: {full_directory_path}"
            )

            if file_name_map:
                # Drop every map entry that lived at or under the directory.
                prefix = f"{directory_path.rstrip('/')}/"
                keys_to_remove = [
                    key
                    for key in file_name_map
                    if key == directory_path or key.startswith(prefix)
                ]
                if keys_to_remove:
                    for key in keys_to_remove:
                        file_name_map.pop(key, None)
                    sources_collection.update_one(
                        {"_id": ObjectId(source_id)},
                        {"$set": {"file_name_map": file_name_map}},
                    )

            # Trigger re-ingestion pipeline
            task = reingest_source_task.delay(source_id=source_id, user=user)

            return make_response(
                jsonify(
                    {
                        "success": True,
                        "message": f"Successfully removed directory: {directory_path}",
                        "removed_directory": directory_path,
                        "reingest_task_id": task.id,
                    }
                ),
                200,
            )
        except Exception as err:
            # Re-read the request parameters so the log carries the context of
            # whichever operation failed.
            error_context = f"operation={operation}, user={user}, source_id={source_id}"
            if operation == "remove_directory":
                directory_path = request.form.get("directory_path", "")
                error_context += f", directory_path={directory_path}"
            elif operation == "remove":
                file_paths_str = request.form.get("file_paths", "")
                error_context += f", file_paths={file_paths_str}"
            elif operation == "add":
                parent_dir = request.form.get("parent_dir", "")
                error_context += f", parent_dir={parent_dir}"
            current_app.logger.error(
                f"Error managing source files: {err} ({error_context})", exc_info=True
            )
            return make_response(
                jsonify({"success": False, "message": "Operation failed"}), 500
            )
|
||||
|
||||
|
||||
@sources_upload_ns.route("/task_status")
class TaskStatus(Resource):
    """Endpoint reporting the state of a Celery task."""

    task_status_model = api.model(
        "TaskStatusModel",
        {"task_id": fields.String(required=True, description="Task ID")},
    )

    @api.expect(task_status_model)
    @api.doc(description="Get celery job status")
    def get(self):
        """Return the status and result metadata of the task ``task_id``.

        Returns:
            400 when ``task_id`` is missing or an unexpected error occurs,
            503 when the task is pending and no worker answers a ping,
            200 with ``status`` and ``result`` otherwise.
        """
        task_id = request.args.get("task_id")
        if not task_id:
            return make_response(
                jsonify({"success": False, "message": "Task ID is required"}), 400
            )
        try:
            from application.celery_init import celery

            task = celery.AsyncResult(task_id)
            task_meta = task.info
            # Use the application logger instead of print() so the message
            # respects the configured log level and handlers.
            current_app.logger.debug("Task status: %s", task.status)

            if task.status == "PENDING":
                # A pending task with no responding workers will never start;
                # report that as a service outage.
                inspect = celery.control.inspect()
                active_workers = inspect.ping()
                if not active_workers:
                    raise ConnectionError("Service unavailable")

            if not isinstance(
                task_meta, (dict, list, str, int, float, bool, type(None))
            ):
                task_meta = str(task_meta)  # Convert to a string representation
        except ConnectionError as err:
            current_app.logger.error(f"Connection error getting task status: {err}")
            return make_response(
                jsonify({"success": False, "message": "Service unavailable"}), 503
            )
        except Exception as err:
            current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"status": task.status, "result": task_meta}), 200)
|
||||
@@ -5,28 +5,14 @@ from application.worker import (
|
||||
agent_webhook_worker,
|
||||
attachment_worker,
|
||||
ingest_worker,
|
||||
mcp_oauth,
|
||||
mcp_oauth_status,
|
||||
remote_worker,
|
||||
sync,
|
||||
sync_worker,
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def ingest(
|
||||
self, directory, formats, job_name, user, file_path, filename, file_name_map=None
|
||||
):
|
||||
resp = ingest_worker(
|
||||
self,
|
||||
directory,
|
||||
formats,
|
||||
job_name,
|
||||
file_path,
|
||||
filename,
|
||||
user,
|
||||
file_name_map=file_name_map,
|
||||
)
|
||||
def ingest(self, directory, formats, job_name, user, file_path, filename):
|
||||
resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
|
||||
return resp
|
||||
|
||||
|
||||
@@ -39,7 +25,6 @@ def ingest_remote(self, source_data, job_name, user, loader):
|
||||
@celery.task(bind=True)
def reingest_source_task(self, source_id, user):
    """Celery task that hands re-ingestion of a source to its worker function."""
    # Imported at call time, as in the original implementation.
    from application.worker import reingest_source_worker

    return reingest_source_worker(self, source_id, user)
|
||||
|
||||
@@ -50,30 +35,6 @@ def schedule_syncs(self, frequency):
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
def sync_source(
    self,
    source_data,
    job_name,
    user,
    loader,
    sync_frequency,
    retriever,
    doc_id,
):
    """Celery task that forwards a single-source sync request to ``sync``.

    All arguments are passed through positionally, in the same order, after
    the bound task instance.
    """
    sync_args = (
        source_data,
        job_name,
        user,
        loader,
        sync_frequency,
        retriever,
        doc_id,
    )
    return sync(self, *sync_args)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def store_attachment(self, file_info, user):
|
||||
resp = attachment_worker(self, file_info, user)
|
||||
@@ -99,10 +60,9 @@ def ingest_connector_task(
|
||||
retriever="classic",
|
||||
operation_mode="upload",
|
||||
doc_id=None,
|
||||
sync_frequency="never",
|
||||
sync_frequency="never"
|
||||
):
|
||||
from application.worker import ingest_connector
|
||||
|
||||
resp = ingest_connector(
|
||||
self,
|
||||
job_name,
|
||||
@@ -115,7 +75,7 @@ def ingest_connector_task(
|
||||
retriever=retriever,
|
||||
operation_mode=operation_mode,
|
||||
doc_id=doc_id,
|
||||
sync_frequency=sync_frequency,
|
||||
sync_frequency=sync_frequency
|
||||
)
|
||||
return resp
|
||||
|
||||
@@ -134,15 +94,3 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
timedelta(days=30),
|
||||
schedule_syncs.s("monthly"),
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
def mcp_oauth_task(self, config, user):
    """Celery task wrapper that delegates to ``mcp_oauth``."""
    return mcp_oauth(self, config, user)
|
||||
|
||||
|
||||
@celery.task(bind=True)
def mcp_oauth_status_task(self, task_id):
    """Celery task wrapper that delegates to ``mcp_oauth_status``."""
    return mcp_oauth_status(self, task_id)
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Tools module."""
|
||||
|
||||
from .mcp import tools_mcp_ns
|
||||
from .routes import tools_ns
|
||||
|
||||
__all__ = ["tools_ns", "tools_mcp_ns"]
|
||||
@@ -1,462 +0,0 @@
|
||||
"""Tool management MCP server integration."""
|
||||
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.api.user.tools.routes import transform_actions
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields
|
||||
|
||||
# Flask-RESTX namespace that the MCP server endpoints below are registered on.
tools_mcp_ns = Namespace("tools", path="/api", description="Tool management operations")

# Mongo handles; ``_connector_sessions`` is read by the auth-status endpoint.
_mongo = MongoDB.get_client()
_db = _mongo[settings.MONGO_DB_NAME]
_connector_sessions = _db["connector_sessions"]
|
||||
|
||||
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
|
||||
|
||||
|
||||
def _sanitize_mcp_transport(config):
|
||||
"""Normalise and validate the transport_type field.
|
||||
|
||||
Strips ``command`` / ``args`` keys that are only valid for local STDIO
|
||||
transports and returns the cleaned transport type string.
|
||||
"""
|
||||
transport_type = (config.get("transport_type") or "auto").lower()
|
||||
if transport_type not in _ALLOWED_TRANSPORTS:
|
||||
raise ValueError(f"Unsupported transport_type: {transport_type}")
|
||||
config.pop("command", None)
|
||||
config.pop("args", None)
|
||||
config["transport_type"] = transport_type
|
||||
return transport_type
|
||||
|
||||
|
||||
def _extract_auth_credentials(config):
|
||||
"""Build an ``auth_credentials`` dict from the raw MCP config."""
|
||||
auth_credentials = {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
|
||||
if auth_type == "api_key":
|
||||
if config.get("api_key"):
|
||||
auth_credentials["api_key"] = config["api_key"]
|
||||
if config.get("api_key_header"):
|
||||
auth_credentials["api_key_header"] = config["api_key_header"]
|
||||
elif auth_type == "bearer":
|
||||
if config.get("bearer_token"):
|
||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
||||
elif auth_type == "basic":
|
||||
if config.get("username"):
|
||||
auth_credentials["username"] = config["username"]
|
||||
if config.get("password"):
|
||||
auth_credentials["password"] = config["password"]
|
||||
|
||||
return auth_credentials
|
||||
|
||||
|
||||
# Endpoint that tries out an MCP server configuration without saving it.
@tools_mcp_ns.route("/mcp_server/test")
class TestMCPServerConfig(Resource):
    @api.expect(
        api.model(
            "MCPServerTestModel",
            {
                "config": fields.Raw(
                    required=True, description="MCP server configuration to test"
                ),
            },
        )
    )
    @api.doc(description="Test MCP server connection with provided configuration")
    def post(self):
        # 401 without a decoded token, 400 for an unsupported transport type,
        # 500 on unexpected errors; otherwise the connection-test result is
        # returned with status 200.
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")
        payload = request.get_json()

        missing = check_required_fields(payload, ["config"])
        if missing:
            return missing

        try:
            config = payload["config"]
            try:
                _sanitize_mcp_transport(config)
            except ValueError:
                return make_response(
                    jsonify({"success": False, "error": "Unsupported transport_type"}),
                    400,
                )

            credentials = _extract_auth_credentials(config)
            test_config = config.copy()
            test_config["auth_credentials"] = credentials

            result = MCPTool(config=test_config, user_id=user).test_connection()

            if result.get("requires_oauth"):
                return make_response(jsonify(result), 200)

            if not result.get("success") and "message" in result:
                # Keep the detailed failure in the log only; the client gets
                # a generic message.
                current_app.logger.error(
                    f"MCP connection test failed: {result.get('message')}"
                )
                result["message"] = "Connection test failed"

            return make_response(jsonify(result), 200)
        except Exception as e:
            current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
            return make_response(
                jsonify({"success": False, "error": "Connection test failed"}),
                500,
            )
|
||||
|
||||
|
||||
# Endpoint that creates or updates a user's MCP server tool. The server's
# actions are discovered first (from the finished OAuth task, or by querying
# the server), credentials are stored encrypted, and plain-text secrets are
# stripped from the persisted config.
@tools_mcp_ns.route("/mcp_server/save")
class MCPServerSave(Resource):
    @api.expect(
        api.model(
            "MCPServerSaveModel",
            {
                "id": fields.String(
                    required=False, description="Tool ID for updates (optional)"
                ),
                "displayName": fields.String(
                    required=True, description="Display name for the MCP server"
                ),
                "config": fields.Raw(
                    required=True, description="MCP server configuration"
                ),
                "status": fields.Boolean(
                    required=False, default=True, description="Tool status"
                ),
            },
        )
    )
    @api.doc(description="Create or update MCP server with automatic tool discovery")
    def post(self):
        # Responses: 401 without a decoded token; 400 for an unsupported
        # transport, an incomplete OAuth flow, or missing credentials; 404
        # when updating a tool that does not match id/user; 500 on unexpected
        # errors; 200 with the tool id and discovered tool count on success.
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()

        required_fields = ["displayName", "config"]
        missing_fields = check_required_fields(data, required_fields)
        if missing_fields:
            return missing_fields
        try:
            config = data["config"]
            try:
                _sanitize_mcp_transport(config)
            except ValueError:
                return make_response(
                    jsonify({"success": False, "error": "Unsupported transport_type"}),
                    400,
                )

            auth_credentials = _extract_auth_credentials(config)
            auth_type = config.get("auth_type", "none")
            mcp_config = config.copy()
            mcp_config["auth_credentials"] = auth_credentials

            if auth_type == "oauth":
                if not config.get("oauth_task_id"):
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "error": "Connection not authorized. Please complete the OAuth authorization first.",
                            }
                        ),
                        400,
                    )
                redis_client = get_redis_instance()
                manager = MCPOAuthManager(redis_client)
                result = manager.get_oauth_status(config["oauth_task_id"])
                if result.get("status") != "completed":
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "error": "OAuth failed or not completed. Please try authorizing again.",
                            }
                        ),
                        400,
                    )
                # The completed OAuth task already carries the tool list.
                actions_metadata = result.get("tools", [])
            elif auth_type == "none" or auth_credentials:
                mcp_tool = MCPTool(config=mcp_config, user_id=user)
                mcp_tool.discover_tools()
                actions_metadata = mcp_tool.get_actions_metadata()
            else:
                # The selected auth type needs credentials and none were sent.
                # This is a client error: answer 400 with the explanation
                # instead of raising into the generic 500 handler, as the
                # OAuth validation branches above already do.
                return make_response(
                    jsonify(
                        {
                            "success": False,
                            "error": "No valid credentials provided for the selected authentication type",
                        }
                    ),
                    400,
                )
            storage_config = config.copy()

            tool_id = data.get("id")
            existing_encrypted = None
            if tool_id:
                existing_doc = user_tools_collection.find_one(
                    {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
                )
                if existing_doc:
                    existing_encrypted = existing_doc.get("config", {}).get(
                        "encrypted_credentials"
                    )

            if auth_credentials:
                if existing_encrypted:
                    # Overlay the newly supplied secrets on the stored ones.
                    existing_secrets = decrypt_credentials(existing_encrypted, user)
                    existing_secrets.update(auth_credentials)
                    auth_credentials = existing_secrets
                storage_config["encrypted_credentials"] = encrypt_credentials(
                    auth_credentials, user
                )
            elif existing_encrypted:
                storage_config["encrypted_credentials"] = existing_encrypted

            # Never persist plain-text secrets (or the redirect URI).
            for field in [
                "api_key",
                "bearer_token",
                "username",
                "password",
                "api_key_header",
                "redirect_uri",
            ]:
                storage_config.pop(field, None)
            transformed_actions = transform_actions(actions_metadata)
            tool_data = {
                "name": "mcp_tool",
                "displayName": data["displayName"],
                "customName": data["displayName"],
                "description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
                "config": storage_config,
                "actions": transformed_actions,
                "status": data.get("status", True),
                "user": user,
            }

            if tool_id:
                result = user_tools_collection.update_one(
                    {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
                    {"$set": {k: v for k, v in tool_data.items() if k != "user"}},
                )
                if result.matched_count == 0:
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "error": "Tool not found or access denied",
                            }
                        ),
                        404,
                    )
                response_data = {
                    "success": True,
                    "id": tool_id,
                    "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
                    "tools_count": len(transformed_actions),
                }
            else:
                result = user_tools_collection.insert_one(tool_data)
                tool_id = str(result.inserted_id)
                response_data = {
                    "success": True,
                    "id": tool_id,
                    "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
                    "tools_count": len(transformed_actions),
                }
            return make_response(jsonify(response_data), 200)
        except Exception as e:
            current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
            return make_response(
                jsonify({"success": False, "error": "Failed to save MCP server"}),
                500,
            )
|
||||
|
||||
|
||||
# OAuth redirect target for MCP servers. Every outcome is reported by
# redirecting the browser to /api/connectors/callback-status.
@tools_mcp_ns.route("/mcp_server/callback")
class MCPOAuthCallback(Resource):
    @api.expect(
        api.model(
            "MCPServerCallbackModel",
            {
                "code": fields.String(required=True, description="Authorization code"),
                "state": fields.String(required=True, description="State parameter"),
                "error": fields.String(
                    required=False, description="Error message (if any)"
                ),
            },
        )
    )
    @api.doc(
        description="Handle OAuth callback by providing the authorization code and state"
    )
    def get(self):
        query = request.args
        code = query.get("code")
        state = query.get("state")
        error = query.get("error")

        if error:
            # The provider-supplied error text is URL-encoded before it is
            # placed in the redirect target.
            failure = urlencode(
                {
                    "status": "error",
                    "message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
                    "provider": "mcp_tool",
                }
            )
            return redirect(f"/api/connectors/callback-status?{failure}")

        if not (code and state):
            return redirect(
                "/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
            )

        try:
            redis_client = get_redis_instance()
            if not redis_client:
                return redirect(
                    "/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
                )

            handled = MCPOAuthManager(redis_client).handle_oauth_callback(
                state, code, error
            )
            if not handled:
                return redirect(
                    "/api/connectors/callback-status?status=error&message=OAuth+callback+failed.&provider=mcp_tool"
                )
            return redirect(
                "/api/connectors/callback-status?status=success&message=Authorization+code+received+successfully.+You+can+close+this+window.&provider=mcp_tool"
            )
        except Exception as e:
            current_app.logger.error(
                f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
            )
            return redirect(
                "/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
            )
|
||||
|
||||
|
||||
# Polling endpoint: reports the OAuth status stored in Redis for a task id.
@tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
class MCPOAuthStatus(Resource):
    def get(self, task_id):
        try:
            stored = get_redis_instance().get(f"mcp_oauth_status:{task_id}")

            if not stored:
                # Nothing recorded under this task id yet.
                return make_response(
                    jsonify(
                        {
                            "success": True,
                            "task_id": task_id,
                            "status": "pending",
                            "message": "Waiting for OAuth to start...",
                        }
                    ),
                    200,
                )

            status = json.loads(stored)
            if "tools" in status and isinstance(status["tools"], list):
                # Reduce each tool entry to its name and description.
                trimmed_tools = []
                for tool in status["tools"]:
                    trimmed_tools.append(
                        {
                            "name": tool.get("name", "unknown"),
                            "description": tool.get("description", ""),
                        }
                    )
                status["tools"] = trimmed_tools
            return make_response(
                jsonify({"success": True, "task_id": task_id, **status})
            )
        except Exception as e:
            current_app.logger.error(
                f"Error getting OAuth status for task {task_id}: {str(e)}",
                exc_info=True,
            )
            return make_response(
                jsonify(
                    {
                        "success": False,
                        "error": "Failed to get OAuth status",
                        "task_id": task_id,
                    }
                ),
                500,
            )
|
||||
|
||||
|
||||
# Batch endpoint: maps each of the user's MCP tools to "configured",
# "connected" or "needs_auth" using database lookups only.
@tools_mcp_ns.route("/mcp_server/auth_status")
class MCPAuthStatus(Resource):
    @api.doc(
        description="Batch check auth status for all MCP tools. "
        "Lightweight DB-only check — no network calls to MCP servers."
    )
    def get(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")
        try:
            mcp_tools = list(
                user_tools_collection.find(
                    {"user": user, "name": "mcp_tool"},
                    {"_id": 1, "config": 1},
                )
            )
            if not mcp_tools:
                return make_response(jsonify({"success": True, "statuses": {}}), 200)

            statuses = {}
            # OAuth tools with a server URL: tool id -> "scheme://netloc".
            oauth_server_urls = {}
            for tool in mcp_tools:
                tool_id = str(tool["_id"])
                config = tool.get("config", {})
                if config.get("auth_type", "none") != "oauth":
                    statuses[tool_id] = "configured"
                    continue
                server_url = config.get("server_url", "")
                if not server_url:
                    statuses[tool_id] = "needs_auth"
                    continue
                parsed = urlparse(server_url)
                oauth_server_urls[tool_id] = f"{parsed.scheme}://{parsed.netloc}"

            if oauth_server_urls:
                unique_urls = list(set(oauth_server_urls.values()))
                sessions = list(
                    _connector_sessions.find(
                        {"user_id": user, "server_url": {"$in": unique_urls}},
                        {"server_url": 1, "tokens": 1},
                    )
                )
                # A server counts as connected when its session holds a
                # non-empty access token.
                url_has_tokens = {}
                for session in sessions:
                    url_has_tokens[session["server_url"]] = bool(
                        session.get("tokens", {}).get("access_token")
                    )
                for tool_id, base_url in oauth_server_urls.items():
                    statuses[tool_id] = (
                        "connected" if url_has_tokens.get(base_url) else "needs_auth"
                    )

            return make_response(jsonify({"success": True, "statuses": statuses}), 200)
        except Exception as e:
            current_app.logger.error(
                "Error checking MCP auth status: %s", e, exc_info=True
            )
            return make_response(
                jsonify({"success": False, "error": "Failed to check auth status"}),
                500,
            )
|
||||
@@ -1,700 +0,0 @@
|
||||
"""Tool management routes."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.spec_parser import parse_spec
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
# Module-wide tool registry, built once from an empty configuration.
tool_config = dict()
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
def _encrypt_secret_fields(config, config_requirements, user_id):
|
||||
secret_keys = [
|
||||
key for key, spec in config_requirements.items()
|
||||
if spec.get("secret") and key in config and config[key]
|
||||
]
|
||||
if not secret_keys:
|
||||
return config
|
||||
|
||||
storage_config = config.copy()
|
||||
secret_values = {k: config[k] for k in secret_keys}
|
||||
storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id)
|
||||
for key in secret_keys:
|
||||
storage_config.pop(key, None)
|
||||
return storage_config
|
||||
|
||||
|
||||
def _validate_config(config, config_requirements, has_existing_secrets=False):
|
||||
errors = {}
|
||||
for key, spec in config_requirements.items():
|
||||
depends_on = spec.get("depends_on")
|
||||
if depends_on:
|
||||
if not all(config.get(dk) == dv for dk, dv in depends_on.items()):
|
||||
continue
|
||||
if spec.get("required") and not config.get(key):
|
||||
if has_existing_secrets and spec.get("secret"):
|
||||
continue
|
||||
errors[key] = f"{spec.get('label', key)} is required"
|
||||
value = config.get(key)
|
||||
if value is not None and value != "":
|
||||
if spec.get("type") == "number":
|
||||
try:
|
||||
num = float(value)
|
||||
if key == "timeout" and (num < 1 or num > 300):
|
||||
errors[key] = "Timeout must be between 1 and 300"
|
||||
except (ValueError, TypeError):
|
||||
errors[key] = f"{spec.get('label', key)} must be a number"
|
||||
if spec.get("enum") and value not in spec["enum"]:
|
||||
errors[key] = f"Invalid value for {spec.get('label', key)}"
|
||||
return errors
|
||||
|
||||
|
||||
def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id):
|
||||
"""Merge incoming config with existing encrypted secrets and re-encrypt.
|
||||
|
||||
For updates, the client may omit unchanged secret values. This helper
|
||||
decrypts any previously stored secrets, overlays whatever the client *did*
|
||||
send, strips plain-text secrets from the stored config, and re-encrypts
|
||||
the merged result.
|
||||
|
||||
Returns the final ``config`` dict ready for persistence.
|
||||
"""
|
||||
secret_keys = [
|
||||
key for key, spec in config_requirements.items()
|
||||
if spec.get("secret")
|
||||
]
|
||||
|
||||
if not secret_keys:
|
||||
return new_config
|
||||
|
||||
existing_secrets = {}
|
||||
if "encrypted_credentials" in existing_config:
|
||||
existing_secrets = decrypt_credentials(
|
||||
existing_config["encrypted_credentials"], user_id
|
||||
)
|
||||
|
||||
merged_secrets = existing_secrets.copy()
|
||||
for key in secret_keys:
|
||||
if key in new_config and new_config[key]:
|
||||
merged_secrets[key] = new_config[key]
|
||||
|
||||
# Start from existing non-secret values, then overlay incoming non-secrets
|
||||
storage_config = {
|
||||
k: v for k, v in existing_config.items()
|
||||
if k not in secret_keys and k != "encrypted_credentials"
|
||||
}
|
||||
storage_config.update(
|
||||
{k: v for k, v in new_config.items() if k not in secret_keys}
|
||||
)
|
||||
|
||||
if merged_secrets:
|
||||
storage_config["encrypted_credentials"] = encrypt_credentials(
|
||||
merged_secrets, user_id
|
||||
)
|
||||
else:
|
||||
storage_config.pop("encrypted_credentials", None)
|
||||
|
||||
storage_config.pop("has_encrypted_credentials", None)
|
||||
return storage_config
|
||||
|
||||
|
||||
def transform_actions(actions_metadata):
    """Set default flags on action metadata for storage.

    Marks each action as active and sets ``filled_by_llm`` / ``value`` on
    every parameter property. The action dicts are updated in place and the
    same objects are returned in a new list. Used by both the generic
    create_tool and MCP save routes.

    Args:
        actions_metadata: Iterable of action dicts. ``parameters`` may be
            missing or ``None``, and so may its ``properties``; such actions
            are only marked active.

    Returns:
        List containing the same (now updated) action dicts, in order.
    """
    transformed = []
    for action in actions_metadata:
        action["active"] = True
        # ``or {}`` tolerates an explicit null for ``parameters`` or
        # ``properties``, which previously raised AttributeError.
        parameters = action.get("parameters") or {}
        properties = parameters.get("properties") or {}
        for param_details in properties.values():
            param_details["filled_by_llm"] = True
            param_details["value"] = ""
        transformed.append(action)
    return transformed
|
||||
|
||||
|
||||
# Flask-RESTX namespace that the generic tool endpoints below are registered on.
tools_ns = Namespace("tools", path="/api", description="Tool management operations")
|
||||
|
||||
|
||||
# Endpoint listing every tool registered in ``tool_manager``. The display
# name and description come from the tool's docstring, followed by its
# config requirements and action metadata.
@tools_ns.route("/available_tools")
class AvailableTools(Resource):
    @api.doc(description="Get available tools for a user")
    def get(self):
        # Returns 200 with {"success": True, "data": [...]}, or 400 with
        # {"success": False} when building the list fails.
        try:
            tools_metadata = []
            for tool_name, tool_instance in tool_manager.tools.items():
                # ``__doc__`` is None for a tool without a docstring; treat
                # that as empty text so one undocumented tool cannot fail the
                # whole listing.
                doc = (tool_instance.__doc__ or "").strip()
                lines = doc.split("\n", 1)
                # First docstring line is the display name; fall back to the
                # registry key when it is blank.
                name = lines[0].strip() or tool_name
                description = lines[1].strip() if len(lines) > 1 else ""
                config_req = tool_instance.get_config_requirements()
                actions = tool_instance.get_actions_metadata()
                tools_metadata.append(
                    {
                        "name": tool_name,
                        "displayName": name,
                        "description": description,
                        "configRequirements": config_req,
                        "actions": actions,
                    }
                )
        except Exception as err:
            current_app.logger.error(
                f"Error getting available tools: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True, "data": tools_metadata}), 200)
|
||||
|
||||
|
||||
# Endpoint returning the tools stored for the authenticated user. When the
# tool's requirements declare secrets, the encrypted credential blob is
# replaced by a ``has_encrypted_credentials`` marker.
@tools_ns.route("/get_tools")
class GetTools(Resource):
    @api.doc(description="Get tools created by a user")
    def get(self):
        try:
            token = request.decoded_token
            if not token:
                return make_response(jsonify({"success": False}), 401)
            user = token.get("sub")

            user_tools = []
            for tool in user_tools_collection.find({"user": user}):
                entry = dict(tool)
                # Expose the Mongo ``_id`` as a string ``id``.
                entry["id"] = str(tool["_id"])
                entry.pop("_id", None)

                config_req = entry.get("configRequirements", {})
                if not config_req:
                    # Nothing stored on the document: ask the registered tool.
                    registered = tool_manager.tools.get(entry.get("name"))
                    if registered:
                        config_req = registered.get_config_requirements()
                        entry["configRequirements"] = config_req

                has_secrets = bool(config_req) and any(
                    spec.get("secret") for spec in config_req.values()
                )
                if has_secrets and "encrypted_credentials" in entry.get("config", {}):
                    entry["config"]["has_encrypted_credentials"] = True
                    entry["config"].pop("encrypted_credentials", None)

                user_tools.append(entry)
        except Exception as err:
            current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True, "tools": user_tools}), 200)
|
||||
|
||||
|
||||
# Endpoint that stores a new instance of a registered tool for the user,
# after validating the config and encrypting its secret fields.
@tools_ns.route("/create_tool")
class CreateTool(Resource):
    @api.expect(
        api.model(
            "CreateToolModel",
            {
                "name": fields.String(required=True, description="Name of the tool"),
                "displayName": fields.String(
                    required=True, description="Display name for the tool"
                ),
                "description": fields.String(
                    required=True, description="Tool description"
                ),
                "config": fields.Raw(
                    required=True, description="Configuration of the tool"
                ),
                "customName": fields.String(
                    required=False, description="Custom name for the tool"
                ),
                "status": fields.Boolean(
                    required=True, description="Status of the tool"
                ),
            },
        )
    )
    @api.doc(description="Create a new tool")
    def post(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")
        data = request.get_json()

        missing = check_required_fields(
            data, ["name", "displayName", "description", "config", "status"]
        )
        if missing:
            return missing

        # Step 1: resolve the registered tool and prepare its actions.
        try:
            tool_instance = tool_manager.tools.get(data["name"])
            if not tool_instance:
                return make_response(
                    jsonify({"success": False, "message": "Tool not found"}), 404
                )
            transformed_actions = transform_actions(
                tool_instance.get_actions_metadata()
            )
        except Exception as err:
            current_app.logger.error(
                f"Error getting tool actions: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        # Step 2: validate the config, encrypt secrets and insert the document.
        try:
            config_requirements = tool_instance.get_config_requirements()
            if config_requirements:
                validation_errors = _validate_config(
                    data["config"], config_requirements
                )
                if validation_errors:
                    failure = {
                        "success": False,
                        "message": "Validation failed",
                        "errors": validation_errors,
                    }
                    return make_response(jsonify(failure), 400)
            storage_config = _encrypt_secret_fields(
                data["config"], config_requirements, user
            )
            new_tool = {
                "user": user,
                "name": data["name"],
                "displayName": data["displayName"],
                "description": data["description"],
                "customName": data.get("customName", ""),
                "actions": transformed_actions,
                "config": storage_config,
                "configRequirements": config_requirements,
                "status": data["status"],
            }
            inserted = user_tools_collection.insert_one(new_tool)
            new_id = str(inserted.inserted_id)
        except Exception as err:
            current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"id": new_id}), 200)
|
||||
|
||||
|
||||
# Endpoint that updates any subset of a stored tool's fields. A supplied
# config is validated and merged with the stored encrypted secrets.
@tools_ns.route("/update_tool")
class UpdateTool(Resource):
    @api.expect(
        api.model(
            "UpdateToolModel",
            {
                "id": fields.String(required=True, description="Tool ID"),
                "name": fields.String(description="Name of the tool"),
                "displayName": fields.String(description="Display name for the tool"),
                "customName": fields.String(description="Custom name for the tool"),
                "description": fields.String(description="Tool description"),
                "config": fields.Raw(description="Configuration of the tool"),
                "actions": fields.List(
                    fields.Raw, description="Actions the tool can perform"
                ),
                "status": fields.Boolean(description="Status of the tool"),
            },
        )
    )
    @api.doc(description="Update a tool by ID")
    def post(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")
        data = request.get_json()

        missing = check_required_fields(data, ["id"])
        if missing:
            return missing

        try:
            # Fields that are copied into the update as-is when present.
            update_data = {}
            for field_name in (
                "name",
                "displayName",
                "customName",
                "description",
                "actions",
            ):
                if field_name in data:
                    update_data[field_name] = data[field_name]

            if "config" in data:
                incoming_config = data["config"]
                if "actions" in incoming_config:
                    # Reject the request on the first invalid action name.
                    for action_name in list(incoming_config["actions"].keys()):
                        if validate_function_name(action_name):
                            continue
                        return make_response(
                            jsonify(
                                {
                                    "success": False,
                                    "message": f"Invalid function name '{action_name}'. Function names must match pattern '^[a-zA-Z0-9_-]+$'.",
                                    "param": "tools[].function.name",
                                }
                            ),
                            400,
                        )

                tool_doc = user_tools_collection.find_one(
                    {"_id": ObjectId(data["id"]), "user": user}
                )
                if not tool_doc:
                    return make_response(
                        jsonify({"success": False, "message": "Tool not found"}),
                        404,
                    )

                tool_name = tool_doc.get("name", data.get("name"))
                tool_instance = tool_manager.tools.get(tool_name)
                if tool_instance:
                    config_requirements = tool_instance.get_config_requirements()
                else:
                    config_requirements = {}
                existing_config = tool_doc.get("config", {})

                if config_requirements:
                    validation_errors = _validate_config(
                        incoming_config,
                        config_requirements,
                        has_existing_secrets="encrypted_credentials" in existing_config,
                    )
                    if validation_errors:
                        return make_response(
                            jsonify(
                                {
                                    "success": False,
                                    "message": "Validation failed",
                                    "errors": validation_errors,
                                }
                            ),
                            400,
                        )

                update_data["config"] = _merge_secrets_on_update(
                    incoming_config, existing_config, config_requirements, user
                )

            if "status" in data:
                update_data["status"] = data["status"]

            user_tools_collection.update_one(
                {"_id": ObjectId(data["id"]), "user": user},
                {"$set": update_data},
            )
        except Exception as err:
            current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
# Endpoint that replaces a stored tool's config after validation, keeping
# the stored encrypted secrets the client did not resend.
@tools_ns.route("/update_tool_config")
class UpdateToolConfig(Resource):
    @api.expect(
        api.model(
            "UpdateToolConfigModel",
            {
                "id": fields.String(required=True, description="Tool ID"),
                "config": fields.Raw(
                    required=True, description="Configuration of the tool"
                ),
            },
        )
    )
    @api.doc(description="Update the configuration of a tool")
    def post(self):
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")
        data = request.get_json()

        missing = check_required_fields(data, ["id", "config"])
        if missing:
            return missing

        try:
            tool_doc = user_tools_collection.find_one(
                {"_id": ObjectId(data["id"]), "user": user}
            )
            if not tool_doc:
                return make_response(jsonify({"success": False}), 404)

            tool_instance = tool_manager.tools.get(tool_doc.get("name"))
            if tool_instance:
                config_requirements = tool_instance.get_config_requirements()
            else:
                config_requirements = {}
            existing_config = tool_doc.get("config", {})

            if config_requirements:
                validation_errors = _validate_config(
                    data["config"],
                    config_requirements,
                    has_existing_secrets="encrypted_credentials" in existing_config,
                )
                if validation_errors:
                    return make_response(
                        jsonify(
                            {
                                "success": False,
                                "message": "Validation failed",
                                "errors": validation_errors,
                            }
                        ),
                        400,
                    )

            final_config = _merge_secrets_on_update(
                data["config"], existing_config, config_requirements, user
            )
            user_tools_collection.update_one(
                {"_id": ObjectId(data["id"]), "user": user},
                {"$set": {"config": final_config}},
            )
        except Exception as err:
            current_app.logger.error(
                f"Error updating tool config: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_actions")
class UpdateToolActions(Resource):
    # Endpoint that overwrites the "actions" list of a tool owned by the
    # authenticated user.

    @api.expect(
        api.model(
            "UpdateToolActionsModel",
            {
                "id": fields.String(required=True, description="Tool ID"),
                "actions": fields.List(
                    fields.Raw,
                    required=True,
                    description="Actions the tool can perform",
                ),
            },
        )
    )
    @api.doc(description="Update the actions of a tool")
    def post(self):
        # Responses: 401 without a decoded token, the check_required_fields
        # response when "id"/"actions" are missing, 404 when no tool matches
        # the id for this user, 400 on any exception, 200 on success.
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        missing_fields = check_required_fields(data, ["id", "actions"])
        if missing_fields:
            return missing_fields
        try:
            result = user_tools_collection.update_one(
                {"_id": ObjectId(data["id"]), "user": user},
                {"$set": {"actions": data["actions"]}},
            )
            # update_one does not raise when nothing matches; without this
            # check the endpoint would claim success for an unknown tool or
            # a tool owned by someone else (DeleteTool already reports 404).
            if result.matched_count == 0:
                return make_response(
                    jsonify({"success": False, "message": "Tool not found"}), 404
                )
        except Exception as err:
            current_app.logger.error(
                f"Error updating tool actions: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/update_tool_status")
class UpdateToolStatus(Resource):
    # Endpoint that sets the "status" flag of a tool owned by the
    # authenticated user.

    @api.expect(
        api.model(
            "UpdateToolStatusModel",
            {
                "id": fields.String(required=True, description="Tool ID"),
                "status": fields.Boolean(
                    required=True, description="Status of the tool"
                ),
            },
        )
    )
    @api.doc(description="Update the status of a tool")
    def post(self):
        # Responses: 401 without a decoded token, the check_required_fields
        # response when "id"/"status" are missing, 404 when no tool matches
        # the id for this user, 400 on any exception, 200 on success.
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        user = decoded_token.get("sub")
        data = request.get_json()
        missing_fields = check_required_fields(data, ["id", "status"])
        if missing_fields:
            return missing_fields
        try:
            result = user_tools_collection.update_one(
                {"_id": ObjectId(data["id"]), "user": user},
                {"$set": {"status": data["status"]}},
            )
            # update_one does not raise when nothing matches; report the
            # miss instead of returning success (mirrors DeleteTool).
            if result.matched_count == 0:
                return make_response(
                    jsonify({"success": False, "message": "Tool not found"}), 404
                )
        except Exception as err:
            current_app.logger.error(
                f"Error updating tool status: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/delete_tool")
class DeleteTool(Resource):
    # Endpoint that removes a single tool owned by the authenticated user.

    @api.expect(
        api.model(
            "DeleteToolModel",
            {"id": fields.String(required=True, description="Tool ID")},
        )
    )
    @api.doc(description="Delete a tool by ID")
    def post(self):
        # Responses: 401 without a decoded token, the check_required_fields
        # response when "id" is missing, 404 when nothing was deleted,
        # 400 on any exception, 200 on success.
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user = token.get("sub")
        data = request.get_json()
        missing_fields = check_required_fields(data, ["id"])
        if missing_fields:
            return missing_fields
        try:
            selector = {"_id": ObjectId(data["id"]), "user": user}
            outcome = user_tools_collection.delete_one(selector)
            if outcome.deleted_count == 0:
                not_found = {"success": False, "message": "Tool not found"}
                return make_response(jsonify(not_found), 404)
        except Exception as err:
            current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
        return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@tools_ns.route("/parse_spec")
class ParseSpec(Resource):
    # Endpoint that turns an uploaded or inline API specification into the
    # metadata/actions pair returned by parse_spec().

    @api.doc(
        description="Parse an API specification (OpenAPI 3.x or Swagger 2.0) and return actions"
    )
    def post(self):
        # The spec may arrive either as a multipart "file" upload or as a
        # JSON body with a "spec_content" string.
        decoded_token = request.decoded_token
        if not decoded_token:
            return make_response(jsonify({"success": False}), 401)
        if "file" in request.files:
            file = request.files["file"]
            if not file.filename:
                return make_response(
                    jsonify({"success": False, "message": "No file selected"}), 400
                )
            try:
                spec_content = file.read().decode("utf-8")
            except UnicodeDecodeError:
                return make_response(
                    jsonify({"success": False, "message": "Invalid file encoding"}), 400
                )
        elif request.is_json:
            data = request.get_json()
            # A JSON body that is not an object (e.g. a list) has no
            # "spec_content"; treat it as empty instead of crashing on .get.
            spec_content = data.get("spec_content", "") if isinstance(data, dict) else ""
        else:
            return make_response(
                jsonify({"success": False, "message": "No spec provided"}), 400
            )
        if not spec_content:
            return make_response(
                jsonify({"success": False, "message": "Empty spec content"}), 400
            )
        # Non-string values (numbers, objects, ...) previously raised
        # AttributeError on .strip() and surfaced as an unhandled 500.
        if not isinstance(spec_content, str):
            return make_response(
                jsonify({"success": False, "message": "Spec content must be a string"}),
                400,
            )
        if not spec_content.strip():
            return make_response(
                jsonify({"success": False, "message": "Empty spec content"}), 400
            )
        try:
            metadata, actions = parse_spec(spec_content)
            return make_response(
                jsonify(
                    {
                        "success": True,
                        "metadata": metadata,
                        "actions": actions,
                    }
                ),
                200,
            )
        except ValueError as e:
            # parse_spec signals an invalid document with ValueError; the
            # detail is logged but not echoed to the client.
            current_app.logger.error(f"Spec validation error: {e}")
            return make_response(jsonify({"success": False, "error": "Invalid specification format"}), 400)
        except Exception as err:
            current_app.logger.error(f"Error parsing spec: {err}", exc_info=True)
            return make_response(jsonify({"success": False, "error": "Failed to parse specification"}), 500)
|
||||
|
||||
|
||||
@tools_ns.route("/artifact/<artifact_id>")
class GetArtifact(Resource):
    # Endpoint that resolves an artifact id to either a note or a todo list.

    @api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.")
    def get(self, artifact_id: str):
        # Lookup order: the "notes" collection first, then "todos". A todo
        # hit returns every todo of the same user and tool, not just the
        # matched document.
        token = request.decoded_token
        if not token:
            return make_response(jsonify({"success": False}), 401)
        user_id = token.get("sub")

        try:
            obj_id = ObjectId(artifact_id)
        except Exception:
            return make_response(
                jsonify({"success": False, "message": "Invalid artifact ID"}), 400
            )

        from application.core.mongo_db import MongoDB
        from application.core.settings import settings

        db = MongoDB.get_client()[settings.MONGO_DB_NAME]

        def _iso(doc, key):
            # ISO-8601 string for a timestamp field, or None when absent/falsy.
            stamp = doc.get(key)
            return stamp.isoformat() if stamp else None

        note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
        if note_doc:
            content = note_doc.get("note", "")
            if content:
                line_count = len(content.split("\n"))
            else:
                line_count = 0
            note_artifact = {
                "artifact_type": "note",
                "data": {
                    "content": content,
                    "line_count": line_count,
                    "updated_at": _iso(note_doc, "updated_at"),
                },
            }
            return make_response(
                jsonify({"success": True, "artifact": note_artifact}), 200
            )

        todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
        if todo_doc:
            siblings = list(
                db["todos"].find(
                    {"user_id": user_id, "tool_id": todo_doc.get("tool_id")}
                )
            )
            # A missing status counts as "open".
            statuses = [entry.get("status", "open") for entry in siblings]
            items = [
                {
                    "todo_id": entry.get("todo_id"),
                    "title": entry.get("title", ""),
                    "status": status,
                    "created_at": _iso(entry, "created_at"),
                    "updated_at": _iso(entry, "updated_at"),
                }
                for entry, status in zip(siblings, statuses)
            ]
            todo_artifact = {
                "artifact_type": "todo_list",
                "data": {
                    "items": items,
                    "total_count": len(items),
                    "open_count": statuses.count("open"),
                    "completed_count": statuses.count("completed"),
                },
            }
            return make_response(
                jsonify({"success": True, "artifact": todo_artifact}), 200
            )

        return make_response(
            jsonify({"success": False, "message": "Artifact not found"}), 404
        )
|
||||
@@ -1,387 +0,0 @@
|
||||
"""Centralized utilities for API routes."""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from bson.errors import InvalidId
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Response,
|
||||
current_app,
|
||||
has_app_context,
|
||||
jsonify,
|
||||
make_response,
|
||||
request,
|
||||
)
|
||||
from pymongo.collection import Collection
|
||||
|
||||
|
||||
def get_user_id() -> Optional[str]:
    """
    Read the authenticated user's ID from the current request.

    Returns:
        The "sub" entry of ``request.decoded_token``, or None when the
        request carries no (or an empty) decoded token.
    """
    token = getattr(request, "decoded_token", None)
    if not token:
        return None
    return token.get("sub")
|
||||
|
||||
|
||||
def require_auth(func: Callable) -> Callable:
    """
    Guard a route handler so it only runs for authenticated requests.

    The wrapped handler is invoked unchanged when ``get_user_id()`` yields a
    truthy value; otherwise a 401 "Unauthorized" error response is returned
    and the handler is not called.

    Usage:
        @require_auth
        def get(self):
            user_id = get_user_id()
            ...
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if get_user_id():
            return func(*args, **kwargs)
        return error_response("Unauthorized", 401)

    return wrapper
|
||||
|
||||
|
||||
def success_response(
    data: Optional[Dict[str, Any]] = None, status: int = 200
) -> Response:
    """
    Build the standard JSON success envelope.

    Args:
        data: Extra key/value pairs merged into the body after
            ``"success": True``; ignored when empty or None.
        status: HTTP status code (default: 200)

    Returns:
        Flask Response object

    Example:
        return success_response({"users": [...], "total": 10})
    """
    body = {"success": True}
    body.update(data or {})
    return make_response(jsonify(body), status)
|
||||
|
||||
|
||||
def error_response(message: str, status: int = 400, **kwargs) -> Response:
    """
    Build the standard JSON error envelope.

    Args:
        message: Error message string
        status: HTTP status code (default: 400)
        **kwargs: Additional fields merged into the body; a keyword named
            like a base field replaces that field.

    Returns:
        Flask Response object

    Example:
        return error_response("Resource not found", 404)
        return error_response("Invalid input", 400, errors=["field1", "field2"])
    """
    body = {"success": False, "message": message, **kwargs}
    return make_response(jsonify(body), status)
|
||||
|
||||
|
||||
def validate_object_id(
    id_string: str, resource_name: str = "Resource"
) -> Tuple[Optional[ObjectId], Optional[Response]]:
    """
    Parse a string into an ObjectId, reporting failure as an error response.

    Args:
        id_string: Candidate ID
        resource_name: Label used in the error message

    Returns:
        ``(ObjectId, None)`` on success, or ``(None, error_response)`` when
        ObjectId construction raises InvalidId or TypeError.

    Example:
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error
    """
    try:
        parsed = ObjectId(id_string)
    except (InvalidId, TypeError):
        return None, error_response(f"Invalid {resource_name} ID format")
    return parsed, None
|
||||
|
||||
|
||||
def validate_pagination(
    default_limit: int = 20, max_limit: int = 100
) -> Tuple[int, int, Optional[Response]]:
    """
    Read "limit" and "skip" from the query string and sanity-check them.

    Args:
        default_limit: Limit used when the query string has none
        max_limit: Upper bound the limit is clamped to

    Returns:
        ``(limit, skip, None)`` when both values are integers with
        ``limit >= 1`` and ``skip >= 0``; otherwise ``(0, 0, error_response)``.

    Example:
        limit, skip, error = validate_pagination()
        if error:
            return error
    """
    query_args = request.args
    try:
        limit = min(int(query_args.get("limit", default_limit)), max_limit)
        skip = int(query_args.get("skip", 0))
    except ValueError:
        # Non-numeric query values.
        return 0, 0, error_response("Invalid pagination parameters")
    if limit < 1 or skip < 0:
        return 0, 0, error_response("Invalid pagination parameters")
    return limit, skip, None
|
||||
|
||||
|
||||
def check_resource_ownership(
    collection: Collection,
    resource_id: ObjectId,
    user_id: str,
    resource_name: str = "Resource",
) -> Tuple[Optional[Dict], Optional[Response]]:
    """
    Fetch a document by ID, restricted to the given user.

    The lookup filters on ``_id`` and ``user`` together, so a document owned
    by someone else is indistinguishable from a missing one.

    Args:
        collection: MongoDB collection
        resource_id: Resource ObjectId
        user_id: User ID string
        resource_name: Label used in the 404 message

    Returns:
        ``(document, None)`` when found, else ``(None, 404 error_response)``.

    Example:
        workflow, error = check_resource_ownership(
            workflows_collection,
            workflow_id,
            user_id,
            "Workflow"
        )
        if error:
            return error
    """
    document = collection.find_one({"_id": resource_id, "user": user_id})
    if document:
        return document, None
    return None, error_response(f"{resource_name} not found", 404)
|
||||
|
||||
|
||||
def serialize_object_id(
    obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
) -> Dict[str, Any]:
    """
    Replace an ID entry with its string form, in place.

    Args:
        obj: Dictionary to modify (mutated and returned)
        id_field: Key holding the value to stringify
        new_field: Key that receives the string; ``id_field`` is removed
            unless both names are equal

    Returns:
        The same dictionary object; untouched when ``id_field`` is absent.

    Example:
        user = serialize_object_id(user_doc)
        # user["id"] = "507f1f77bcf86cd799439011"
    """
    if id_field not in obj:
        return obj
    obj[new_field] = str(obj[id_field])
    if new_field != id_field:
        obj.pop(id_field, None)
    return obj
|
||||
|
||||
|
||||
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
    """
    Run ``serializer`` over every item, preserving order.

    Args:
        items: List of dictionaries
        serializer: Function applied to each item

    Returns:
        New list holding the serializer's results

    Example:
        workflows = serialize_list(workflow_docs, serialize_workflow)
    """
    return list(map(serializer, items))
|
||||
|
||||
|
||||
def paginated_response(
    collection: Collection,
    query: Dict[str, Any],
    serializer: Callable[[Dict], Dict],
    limit: int,
    skip: int,
    sort_field: str = "created_at",
    sort_order: int = -1,
    response_key: str = "items",
) -> Response:
    """
    Query one page of documents and wrap it in a success response.

    The page is fetched with find/sort/skip/limit; the total is obtained
    from a separate ``count_documents`` call using the same query.

    Args:
        collection: MongoDB collection
        query: Query dictionary
        serializer: Function to serialize each item
        limit: Items per page
        skip: Number of items to skip
        sort_field: Field to sort by
        sort_order: Sort order (1=asc, -1=desc)
        response_key: Key under which the serialized page is returned

    Returns:
        Flask Response containing the page plus "total", "limit" and "skip".

    Example:
        return paginated_response(
            workflows_collection,
            {"user": user_id},
            serialize_workflow,
            limit, skip,
            response_key="workflows"
        )
    """
    cursor = collection.find(query).sort(sort_field, sort_order)
    page = list(cursor.skip(skip).limit(limit))
    total = collection.count_documents(query)

    payload = {
        response_key: serialize_list(page, serializer),
        "total": total,
        "limit": limit,
        "skip": skip,
    }
    return success_response(payload)
|
||||
|
||||
|
||||
def require_fields(required: List[str]) -> Callable:
    """
    Decorator to validate required fields in request JSON.

    The wrapped handler only runs when the request body is a non-empty JSON
    object and every name in ``required`` maps to a truthy value. Note that
    falsy values (``""``, ``0``, ``False``, ``None``) count as missing.

    Args:
        required: List of required field names

    Returns:
        Decorator function

    Example:
        @require_fields(["name", "description"])
        def post(self):
            data = request.get_json()
            ...
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            data = request.get_json()
            # A JSON array/scalar body has no fields to look up; reject it
            # here rather than failing with AttributeError on data.get().
            if not data or not isinstance(data, dict):
                return error_response("Request body required")
            missing = [field for field in required if not data.get(field)]
            if missing:
                return error_response(f"Missing required fields: {', '.join(missing)}")
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def safe_db_operation(
    operation: Callable, error_message: str = "Database operation failed"
) -> Tuple[Any, Optional[Response]]:
    """
    Call ``operation`` and convert any exception into an error response.

    Args:
        operation: Zero-argument callable to execute
        error_message: Message used for both the log entry and the response

    Returns:
        ``(result, None)`` on success, or ``(None, error_response)`` when the
        callable raises. The failure is logged only when a Flask application
        context is active.

    Example:
        result, error = safe_db_operation(
            lambda: collection.insert_one(doc),
            "Failed to create resource"
        )
        if error:
            return error
    """
    try:
        outcome = operation()
    except Exception as err:
        if has_app_context():
            current_app.logger.error(f"{error_message}: {err}", exc_info=True)
        return None, error_response(error_message)
    return outcome, None
|
||||
|
||||
|
||||
def validate_enum(
    value: Any, allowed: List[Any], field_name: str
) -> Optional[Response]:
    """
    Check that ``value`` is one of the ``allowed`` options.

    Args:
        value: Value to validate
        allowed: List of allowed values
        field_name: Field name used in the error message

    Returns:
        None when valid; otherwise an error response listing the allowed
        values, each wrapped in single quotes.

    Example:
        error = validate_enum(status, ["draft", "published"], "status")
        if error:
            return error
    """
    if value in allowed:
        return None
    quoted_options = [f"'{option}'" for option in allowed]
    allowed_str = ", ".join(quoted_options)
    return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
|
||||
|
||||
|
||||
def extract_sort_params(
    default_field: str = "created_at",
    default_order: str = "desc",
    allowed_fields: Optional[List[str]] = None,
) -> Tuple[str, int]:
    """
    Read the "sort" and "order" query parameters of the current request.

    Args:
        default_field: Field used when "sort" is absent or not allowed
        default_order: Order used when "order" is absent ("asc" or "desc")
        allowed_fields: Accepted sort fields; a falsy value disables the check

    Returns:
        ``(sort_field, sort_order)`` where the order is -1 for "desc"
        (case-insensitive) and 1 for any other value.

    Example:
        sort_field, sort_order = extract_sort_params(
            allowed_fields=["name", "date", "status"]
        )
    """
    query_args = request.args
    sort_field = query_args.get("sort", default_field)
    requested_order = query_args.get("order", default_order).lower()

    field_rejected = bool(allowed_fields) and sort_field not in allowed_fields
    if field_rejected:
        sort_field = default_field
    if requested_order == "desc":
        return sort_field, -1
    return sort_field, 1
|
||||
@@ -1,3 +0,0 @@
|
||||
from .routes import workflows_ns
|
||||
|
||||
__all__ = ["workflows_ns"]
|
||||
@@ -1,546 +0,0 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api.user.base import (
|
||||
workflow_edges_collection,
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
from application.api.user.utils import (
|
||||
check_resource_ownership,
|
||||
error_response,
|
||||
get_user_id,
|
||||
require_auth,
|
||||
require_fields,
|
||||
safe_db_operation,
|
||||
success_response,
|
||||
validate_object_id,
|
||||
)
|
||||
|
||||
workflows_ns = Namespace("workflows", path="/api")
|
||||
|
||||
|
||||
def _workflow_error_response(message: str, err: Exception):
    """Log ``err`` (with traceback) under ``message`` and return the matching error response."""
    log_line = f"{message}: {err}"
    current_app.logger.error(log_line, exc_info=True)
    return error_response(message)
|
||||
|
||||
|
||||
def serialize_workflow(w: Dict) -> Dict:
    """Build the API representation of a workflow document.

    ``_id`` is required and stringified; timestamps are emitted in ISO
    format, or as None when missing or falsy.
    """

    def _iso(key):
        stamp = w.get(key)
        return stamp.isoformat() if stamp else None

    payload = {"id": str(w["_id"])}
    payload["name"] = w.get("name")
    payload["description"] = w.get("description")
    payload["created_at"] = _iso("created_at")
    payload["updated_at"] = _iso("updated_at")
    return payload
|
||||
|
||||
|
||||
def serialize_node(n: Dict) -> Dict:
    """Build the API representation of a workflow node document.

    ``id`` and ``type`` are required; the stored ``config`` is exposed
    under the ``data`` key (empty dict when absent).
    """
    payload = {"id": n["id"], "type": n["type"]}
    for optional_key in ("title", "description", "position"):
        payload[optional_key] = n.get(optional_key)
    payload["data"] = n.get("config", {})
    return payload
|
||||
|
||||
|
||||
def serialize_edge(e: Dict) -> Dict:
    """Build the API representation of a workflow edge document.

    ``id`` is required; the snake_case storage keys are mapped to the
    names used in the API payload.
    """
    key_mapping = (
        ("source", "source_id"),
        ("target", "target_id"),
        ("sourceHandle", "source_handle"),
        ("targetHandle", "target_handle"),
    )
    payload = {"id": e["id"]}
    for api_key, stored_key in key_mapping:
        payload[api_key] = e.get(stored_key)
    return payload
|
||||
|
||||
|
||||
def get_workflow_graph_version(workflow: Dict) -> int:
    """Get current graph version with legacy fallback.

    Reads ``current_graph_version`` from the workflow document and coerces it
    to a positive integer.

    Args:
        workflow: Workflow document.

    Returns:
        The stored version when it converts to an integer greater than zero;
        otherwise 1 (missing key, non-numeric value, zero/negative value, or
        a float that cannot be represented as an int).
    """
    raw_version = workflow.get("current_graph_version", 1)
    try:
        version = int(raw_version)
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers int(float("inf")), which would otherwise
        # escape the fallback and crash the caller.
        return 1
    return version if version > 0 else 1
|
||||
|
||||
|
||||
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
    """Load the graph documents of a workflow for one graph version.

    When nothing is stored for version 1, documents that carry no
    ``graph_version`` field at all are returned instead (legacy data).
    """
    versioned = list(
        collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
    )
    if versioned or graph_version != 1:
        return versioned
    legacy_filter = {"workflow_id": workflow_id, "graph_version": {"$exists": False}}
    return list(collection.find(legacy_filter))
|
||||
|
||||
|
||||
def validate_json_schema_payload(
    json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
    """Normalize an optional structured-output JSON schema.

    Returns ``(None, None)`` when no schema is given, ``(normalized, None)``
    on success, and ``(None, message)`` when normalization raises
    JsonSchemaValidationError.
    """
    if json_schema is None:
        return None, None
    try:
        normalized = normalize_json_schema_payload(json_schema)
    except JsonSchemaValidationError as exc:
        return None, str(exc)
    return normalized, None
|
||||
|
||||
|
||||
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
    """Return the nodes with agent ``json_schema`` payloads normalized.

    Dict nodes are shallow-copied; non-dict entries are passed through as-is.
    Only agent nodes whose ``data`` dict contains a ``json_schema`` key are
    rewritten, and their ``data`` dict is copied before being changed.
    """

    def _normalize(node):
        if not isinstance(node, dict):
            return node

        clone = dict(node)
        config = clone.get("data")
        is_agent = clone.get("type") == "agent"
        if not is_agent or not isinstance(config, dict) or "json_schema" not in config:
            return clone

        schema = config.get("json_schema")
        new_config = dict(config)
        try:
            new_config["json_schema"] = normalize_json_schema_payload(schema)
        except JsonSchemaValidationError:
            # Validation runs before normalization; keep original on unexpected shape.
            new_config["json_schema"] = schema
        clone["data"] = new_config
        return clone

    return [_normalize(node) for node in nodes]
|
||||
|
||||
|
||||
def _strip_handle(raw_handle: Any) -> str:
    """Return the stripped handle text, or "" when the value is not a string."""
    return raw_handle.strip() if isinstance(raw_handle, str) else ""


def _validate_condition_node(
    cnode: Dict, edges: List[Dict], node_map: Dict, end_ids: set
) -> List[str]:
    """Collect the validation errors of a single condition node."""
    errors: List[str] = []
    cnode_id = cnode.get("id")
    title = cnode.get("title", cnode_id)
    outgoing = [e for e in edges if e.get("source") == cnode_id]
    if len(outgoing) < 2:
        errors.append(f"Condition node '{title}' must have at least 2 outgoing edges")

    # A truthy non-dict "data" (e.g. a string) used to crash on .get();
    # treat it like missing data so it is reported as a node without cases.
    node_data = cnode.get("data") or {}
    if not isinstance(node_data, dict):
        node_data = {}
    cases = node_data.get("cases", [])
    if not isinstance(cases, list):
        cases = []
    if not cases or not any(
        isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
    ):
        errors.append(
            f"Condition node '{title}' must have at least one case with an expression"
        )

    # Every case needs a unique, non-empty branch handle.
    case_handles: Set[str] = set()
    duplicate_case_handles: Set[str] = set()
    for case in cases:
        if not isinstance(case, dict):
            continue
        handle = _strip_handle(case.get("sourceHandle", ""))
        if not handle:
            errors.append(f"Condition node '{title}' has a case without a branch handle")
            continue
        if handle in case_handles:
            duplicate_case_handles.add(handle)
        case_handles.add(handle)

    for handle in duplicate_case_handles:
        errors.append(f"Condition node '{title}' has duplicate case handle '{handle}'")

    # Group outgoing edges by the branch they leave from.
    outgoing_by_handle: Dict[str, List[Dict]] = {}
    for out_edge in outgoing:
        handle = _strip_handle(out_edge.get("sourceHandle", ""))
        outgoing_by_handle.setdefault(handle, []).append(out_edge)

    for handle, handle_edges in outgoing_by_handle.items():
        if not handle:
            errors.append(
                f"Condition node '{title}' has an outgoing edge without sourceHandle"
            )
            continue
        if handle != "else" and handle not in case_handles:
            errors.append(
                f"Condition node '{title}' has a connection from unknown branch '{handle}'"
            )
        if len(handle_edges) > 1:
            errors.append(
                f"Condition node '{title}' has multiple outgoing edges from branch '{handle}'"
            )

    if "else" not in outgoing_by_handle:
        errors.append(f"Condition node '{title}' must have an 'else' branch")

    # A case must have both an expression and an outgoing edge, or neither.
    for case in cases:
        if not isinstance(case, dict):
            continue
        handle = _strip_handle(case.get("sourceHandle", ""))
        if not handle:
            continue

        raw_expression = case.get("expression", "")
        has_expression = isinstance(raw_expression, str) and bool(
            raw_expression.strip()
        )
        has_outgoing = bool(outgoing_by_handle.get(handle))
        if has_expression and not has_outgoing:
            errors.append(
                f"Condition node '{title}' case '{handle}' has an expression but no outgoing edge"
            )
        if not has_expression and has_outgoing:
            errors.append(
                f"Condition node '{title}' case '{handle}' has an outgoing edge but no expression"
            )

    # Every connected branch must lead to some end node.
    for handle, handle_edges in outgoing_by_handle.items():
        if not handle:
            continue
        for out_edge in handle_edges:
            target = out_edge.get("target")
            if target and not _can_reach_end(target, edges, node_map, end_ids):
                errors.append(
                    f"Branch '{handle}' of condition '{title}' "
                    f"must eventually reach an end node"
                )
    return errors


def _validate_agent_node(agent_node: Dict) -> List[str]:
    """Collect the validation errors of a single agent node."""
    agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
    raw_config = agent_node.get("data", {}) or {}
    if not isinstance(raw_config, dict):
        return [f"Agent node '{agent_title}' has invalid configuration"]

    errors: List[str] = []
    normalized_schema, schema_error = validate_json_schema_payload(
        raw_config.get("json_schema")
    )
    has_json_schema = normalized_schema is not None

    # The capability check only applies when a valid schema and a model id
    # are both present.
    model_id = raw_config.get("model_id")
    if has_json_schema and isinstance(model_id, str) and model_id.strip():
        capabilities = get_model_capabilities(model_id.strip())
        if capabilities and not capabilities.get("supports_structured_output", False):
            errors.append(
                f"Agent node '{agent_title}' selected model does not support structured output"
            )
    if schema_error:
        errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
    return errors


def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
    """Validate workflow graph structure.

    Checks performed, in the order their errors are reported: presence of
    nodes; exactly one start node and at least one end node; every edge
    endpoint refers to a known node id; the start node has an outgoing edge;
    per-condition-node rules; per-agent-node rules; every node has an id and
    a type.

    Args:
        nodes: Node payloads (dicts with "id", "type", optional "title"/"data").
        edges: Edge payloads (dicts with "source", "target", optional
            "sourceHandle").

    Returns:
        Human-readable error messages; an empty list means the graph is valid.
    """
    errors: List[str] = []

    if not nodes:
        errors.append("Workflow must have at least one node")
        return errors

    start_nodes = [n for n in nodes if n.get("type") == "start"]
    if len(start_nodes) != 1:
        errors.append("Workflow must have exactly one start node")

    end_nodes = [n for n in nodes if n.get("type") == "end"]
    if not end_nodes:
        errors.append("Workflow must have at least one end node")

    node_map = {n.get("id"): n for n in nodes}
    end_ids = {n.get("id") for n in end_nodes}

    for edge in edges:
        source_id = edge.get("source")
        target_id = edge.get("target")
        if source_id not in node_map:
            errors.append(f"Edge references non-existent source: {source_id}")
        if target_id not in node_map:
            errors.append(f"Edge references non-existent target: {target_id}")

    if start_nodes:
        start_id = start_nodes[0].get("id")
        if not any(e.get("source") == start_id for e in edges):
            errors.append("Start node must have at least one outgoing edge")

    for node in nodes:
        if node.get("type") == "condition":
            errors.extend(_validate_condition_node(node, edges, node_map, end_ids))

    for node in nodes:
        if node.get("type") == "agent":
            errors.extend(_validate_agent_node(node))

    for node in nodes:
        if not node.get("id"):
            errors.append("All nodes must have an id")
        if not node.get("type"):
            errors.append(f"Node {node.get('id', 'unknown')} must have a type")

    return errors
|
||||
|
||||
|
||||
def _can_reach_end(
|
||||
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
|
||||
) -> bool:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
if node_id in end_ids:
|
||||
return True
|
||||
if node_id in visited or node_id not in node_map:
|
||||
return False
|
||||
visited.add(node_id)
|
||||
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
|
||||
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
|
||||
|
||||
|
||||
def create_workflow_nodes(
    workflow_id: str, nodes_data: List[Dict], graph_version: int
) -> None:
    """Persist the nodes of one workflow graph version.

    Nothing is written when ``nodes_data`` is empty. Each node's ``data``
    payload is stored under ``config``; ``id`` and ``type`` are required.
    """
    if not nodes_data:
        return

    documents = []
    for node in nodes_data:
        documents.append(
            {
                "id": node["id"],
                "workflow_id": workflow_id,
                "graph_version": graph_version,
                "type": node["type"],
                "title": node.get("title", ""),
                "description": node.get("description", ""),
                "position": node.get("position", {"x": 0, "y": 0}),
                "config": node.get("data", {}),
            }
        )
    workflow_nodes_collection.insert_many(documents)
|
||||
|
||||
|
||||
def create_workflow_edges(
    workflow_id: str, edges_data: List[Dict], graph_version: int
) -> None:
    """Store the edges of one workflow graph version.

    Each entry of ``edges_data`` must carry ``"id"``. The client-side keys
    ``source`` / ``target`` / ``sourceHandle`` / ``targetHandle`` are stored
    as ``source_id`` / ``target_id`` / ``source_handle`` / ``target_handle``
    (``None`` when absent). Nothing is written when ``edges_data`` is empty.
    """
    if not edges_data:
        return

    documents = []
    for edge in edges_data:
        documents.append(
            {
                "id": edge["id"],
                "workflow_id": workflow_id,
                "graph_version": graph_version,
                "source_id": edge.get("source"),
                "target_id": edge.get("target"),
                "source_handle": edge.get("sourceHandle"),
                "target_handle": edge.get("targetHandle"),
            }
        )
    workflow_edges_collection.insert_many(documents)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows")
class WorkflowList(Resource):
    # Collection-level endpoint: currently only supports creating a workflow.

    @require_auth
    @require_fields(["name"])
    def post(self):
        """Create a new workflow with nodes and edges."""
        user_id = get_user_id()
        payload = request.get_json()

        name = payload.get("name", "").strip()
        nodes_data = payload.get("nodes", [])
        edges_data = payload.get("edges", [])

        # Reject structurally invalid graphs before anything is written.
        problems = validate_workflow_structure(nodes_data, edges_data)
        if problems:
            return error_response("Workflow validation failed", errors=problems)
        nodes_data = normalize_agent_node_json_schemas(nodes_data)

        timestamp = datetime.now(timezone.utc)
        workflow_doc = {
            "name": name,
            "description": payload.get("description", ""),
            "user": user_id,
            "created_at": timestamp,
            "updated_at": timestamp,
            "current_graph_version": 1,
        }

        insert_result, db_error = safe_db_operation(
            lambda: workflows_collection.insert_one(workflow_doc),
            "Failed to create workflow",
        )
        if db_error:
            return db_error

        workflow_id = str(insert_result.inserted_id)

        try:
            create_workflow_nodes(workflow_id, nodes_data, 1)
            create_workflow_edges(workflow_id, edges_data, 1)
        except Exception as err:
            # Roll back: drop any partially written graph and the workflow itself.
            graph_filter = {"workflow_id": workflow_id}
            workflow_nodes_collection.delete_many(graph_filter)
            workflow_edges_collection.delete_many({"workflow_id": workflow_id})
            workflows_collection.delete_one({"_id": insert_result.inserted_id})
            return _workflow_error_response("Failed to create workflow structure", err)

        return success_response({"id": workflow_id}, 201)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows/<string:workflow_id>")
class WorkflowDetail(Resource):
    # Item-level endpoint: read, replace or delete a single owned workflow.

    @require_auth
    def get(self, workflow_id: str):
        """Get workflow details with nodes and edges."""
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        # Only the currently active graph version is returned.
        version = get_workflow_graph_version(workflow)
        node_docs = fetch_graph_documents(
            workflow_nodes_collection, workflow_id, version
        )
        edge_docs = fetch_graph_documents(
            workflow_edges_collection, workflow_id, version
        )

        body = {
            "workflow": serialize_workflow(workflow),
            "nodes": [serialize_node(doc) for doc in node_docs],
            "edges": [serialize_edge(doc) for doc in edge_docs],
        }
        return success_response(body)

    @require_auth
    @require_fields(["name"])
    def put(self, workflow_id: str):
        """Update workflow and replace nodes/edges."""
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        data = request.get_json()
        name = data.get("name", "").strip()
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])

        problems = validate_workflow_structure(nodes_data, edges_data)
        if problems:
            return error_response("Workflow validation failed", errors=problems)
        nodes_data = normalize_agent_node_json_schemas(nodes_data)

        # The new graph is written under the next version number first, so the
        # previous version stays intact until the workflow document is switched.
        next_graph_version = get_workflow_graph_version(workflow) + 1

        def discard_new_version():
            # Remove whatever was written for the not-yet-activated version.
            for collection in (workflow_nodes_collection, workflow_edges_collection):
                collection.delete_many(
                    {"workflow_id": workflow_id, "graph_version": next_graph_version}
                )

        try:
            create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
            create_workflow_edges(workflow_id, edges_data, next_graph_version)
        except Exception as err:
            discard_new_version()
            return _workflow_error_response("Failed to update workflow structure", err)

        now = datetime.now(timezone.utc)
        _, error = safe_db_operation(
            lambda: workflows_collection.update_one(
                {"_id": obj_id},
                {
                    "$set": {
                        "name": name,
                        "description": data.get("description", ""),
                        "updated_at": now,
                        "current_graph_version": next_graph_version,
                    }
                },
            ),
            "Failed to update workflow",
        )
        if error:
            discard_new_version()
            return error

        # Best-effort removal of superseded versions; a failure is only logged
        # because the update itself has already succeeded.
        try:
            for collection in (workflow_nodes_collection, workflow_edges_collection):
                collection.delete_many(
                    {
                        "workflow_id": workflow_id,
                        "graph_version": {"$ne": next_graph_version},
                    }
                )
        except Exception as cleanup_err:
            current_app.logger.warning(
                f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
            )

        return success_response()

    @require_auth
    def delete(self, workflow_id: str):
        """Delete workflow and its graph."""
        user_id = get_user_id()
        obj_id, error = validate_object_id(workflow_id, "Workflow")
        if error:
            return error

        workflow, error = check_resource_ownership(
            workflows_collection, obj_id, user_id, "Workflow"
        )
        if error:
            return error

        try:
            # Graph documents go first, then the workflow document itself.
            for collection in (workflow_nodes_collection, workflow_edges_collection):
                collection.delete_many({"workflow_id": workflow_id})
            workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
        except Exception as err:
            return _workflow_error_response("Failed to delete workflow", err)

        return success_response()
|
||||
@@ -19,10 +19,6 @@ from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
from application.celery_init import celery # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.stt.upload_limits import ( # noqa: E402
|
||||
build_stt_file_size_limit_message,
|
||||
should_reject_stt_request,
|
||||
)
|
||||
|
||||
|
||||
if platform.system() == "Windows":
|
||||
@@ -72,11 +68,6 @@ def home():
|
||||
return "Welcome to DocsGPT Backend!"
|
||||
|
||||
|
||||
@app.route("/api/health")
def health():
    """Report a static ``{"status": "ok"}`` JSON payload."""
    status_payload = {"status": "ok"}
    return jsonify(status_payload)
|
||||
|
||||
|
||||
@app.route("/api/config")
|
||||
def get_config():
|
||||
response = {
|
||||
@@ -97,23 +88,6 @@ def generate_token():
|
||||
return jsonify({"error": "Token generation not allowed in current auth mode"}), 400
|
||||
|
||||
|
||||
@app.before_request
def enforce_stt_request_size_limits():
    """Answer with HTTP 413 when ``should_reject_stt_request`` flags the request.

    OPTIONS requests are never checked. Returning ``None`` lets the request
    continue to the next handler.
    """
    if request.method == "OPTIONS":
        return None
    if not should_reject_stt_request(request.path, request.content_length):
        return None

    rejection = {
        "success": False,
        "message": build_stt_file_size_limit_message(),
    }
    return jsonify(rejection), 413
|
||||
|
||||
|
||||
@app.before_request
|
||||
def authenticate_request():
|
||||
if request.method == "OPTIONS":
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user