Compare commits

..

1 Commits

Author SHA1 Message Date
Pavel
9548364e05 answer endpoint attachments 2025-07-12 15:53:53 +02:00
1026 changed files with 42010 additions and 937093 deletions

View File

@@ -1,39 +1,9 @@
API_KEY=<LLM api key (for example, open ai key)>
LLM_NAME=docsgpt
VITE_API_STREAMING=true
INTERNAL_KEY=<internal key for worker-to-backend authentication>
# Provider-specific API keys (optional - use these to enable multiple providers)
# OPENAI_API_KEY=<your-openai-api-key>
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
# GOOGLE_API_KEY=<your-google-api-key>
# GROQ_API_KEY=<your-groq-api-key>
# NOVITA_API_KEY=<your-novita-api-key>
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
EMBEDDINGS_BASE_URL=
EMBEDDINGS_KEY=
#For Azure (you can delete it if you don't use Azure)
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
#Azure AD Application (client) ID
MICROSOFT_CLIENT_ID=your-azure-ad-client-id
#Azure AD Application client secret
MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
#Azure AD Tenant ID (or 'common' for multi-tenant)
MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
#If you are using a Microsoft Entra ID tenant,
#configure the AUTHORITY variable as
#"https://login.microsoftonline.com/TENANT_GUID"
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -1,99 +0,0 @@
# DocsGPT Incident Response Plan (IRP)
This playbook describes how maintainers respond to confirmed or suspected security incidents.
- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)
## Severity
| Severity | Definition | Typical examples |
|---|---|---|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | key leakage; prompt injection enabling cross-tenant access. |
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |
## Response workflow
### 1) Triage (target: initial response within 48 hours)
1. Acknowledge report.
2. Validate on latest release and `main`.
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.
### 2) Investigation
1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
3. Request a CVE through GHSA for **Medium+** issues.
### 3) Containment, fix, and disclosure
1. Implement and test fix in private security workflow (GHSA private fork/branch).
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.
### 4) Post-incident
1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
3. Track follow-up hardening actions with owners/dates.
4. Update this IRP and related runbooks as needed.
## Scenario playbooks
### Supply-chain compromise
1. Freeze releases and investigate blast radius.
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
4. Publish advisory with exact affected versions and required user actions.
### Data exposure
1. Determine scope (users, documents, keys, logs, time window).
2. Disable affected path or hotfix immediately for managed cloud.
3. Notify affected users with concrete remediation steps (for example, rotate keys).
4. Continue through standard fix/disclosure workflow.
### Critical regression with security impact
1. Identify introducing change (`git bisect` if needed).
2. Publish workaround within 24 hours (for example, pin to known-good version).
3. Ship patch release with regression test and close incident with public summary.
## AI-specific guidance
Treat confirmed AI-specific abuse as security incidents:
- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
- Cross-tenant retrieval/isolation failure -> **High**
- API key disclosure in output -> **High**
## Secret rotation quick reference
| Secret | Standard rotation action |
|---|---|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
| Database credentials | Rotate in DB platform; redeploy with new secrets |
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |
## Maintenance
Review this document:
- after every **Critical/High** incident, and
- at least annually.
Changes should be proposed via pull request to `main`.

View File

@@ -1,144 +0,0 @@
# DocsGPT Public Threat Model
**Classification:** Public
**Last updated:** 2026-04-15
**Applies to:** Open-source and self-hosted DocsGPT deployments
## 1) Overview
DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.
Core components:
- Backend API (`application/`)
- Workers/ingestion (`application/worker.py` and related modules)
- Datastores (MongoDB/Redis/vector stores)
- Frontend (`frontend/`)
- Optional extensions/integrations (`extensions/`)
## 2) Scope and assumptions
In scope:
- Application-level threats in this repository.
- Local and internet-exposed self-hosted deployments.
Assumptions:
- Internet-facing instances enable auth and use strong secrets.
- Datastores/internal services are not publicly exposed.
Out of scope:
- Cloud hardware/provider compromise.
- Security guarantees of external LLM vendors.
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).
## 3) Security objectives
- Protect document/conversation confidentiality.
- Preserve integrity of prompts, agents, tools, and indexed data.
- Maintain API/worker availability.
- Enforce tenant isolation in authenticated deployments.
## 4) Assets
- Documents, attachments, chunks/embeddings, summaries.
- Conversations, agents, workflows, prompt templates.
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
- Operational capacity (worker throughput, queue depth, model quota/cost).
## 5) Trust boundaries and untrusted input
Trust boundaries:
- Internet ↔ Frontend
- Frontend ↔ Backend API
- Backend ↔ Workers/internal APIs
- Backend/workers ↔ Datastores
- Backend ↔ External LLM/connectors/remote URLs
Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.
## 6) Main attack surfaces
1. Auth/authz paths and sharing tokens.
2. File upload + parsing pipeline.
3. Remote URL fetching and connectors (SSRF risk).
4. Agent/tool execution from LLM output.
5. Template/workflow rendering.
6. Frontend rendering + token storage.
7. Internal service endpoints (`INTERNAL_KEY`).
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).
## 7) Key threats and expected mitigations
### A. Auth/authz misconfiguration
- Threat: weak/no auth or leaked tokens leads to broad data access.
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.
### B. Untrusted file ingestion
- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.
### C. SSRF/outbound abuse
- Threat: URL loaders/tools access private/internal/metadata endpoints.
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists.
### D. Prompt injection + tool abuse
- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
- Threat: never rely on the model to "choose correctly" under adversarial input.
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.
### E. Dangerous tool capability chaining (SQL/API/MCP)
- Threat: write-capable SQL credentials allow destructive queries.
- Threat: API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
- Threat: remote MCP tools may expose privileged operations.
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging.
### F. Frontend/XSS + token theft
- Threat: XSS can steal local tokens and call APIs.
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.
### G. Internal endpoint exposure
- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
- Mitigations: fail closed, require strong random keys, keep internal APIs private.
### H. DoS and cost abuse
- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.
## 8) Example attacker stories
- Internet-exposed deployment runs with weak/no auth and receives unauthorized data access/abuse.
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
- Crafted archive attempts path traversal during extraction.
- Malicious URL/redirect chain targets internal services.
- Poisoned document causes data exfiltration through tool calls.
- Over-privileged SQL/API/MCP tool performs destructive side effects.
## 9) Severity calibration
- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
- **Medium:** DoS/cost amplification and non-critical information disclosure.
- **Low:** minor hardening gaps with limited impact.
## 10) Baseline controls for public deployments
1. Enforce authentication and secure defaults.
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys).
3. Restrict CORS and front API with a hardened proxy.
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
5. Enforce URL+redirect SSRF protections and egress restrictions.
6. Apply upload/archive/parsing hardening.
7. Require least-privilege tool credentials and auditable tool execution.
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
9. Keep dependencies/images patched and scanned.
10. Validate multi-tenant isolation with explicit tests.
## 11) Maintenance
Review this model after major auth, ingestion, connector, tool, or workflow changes.
## References
- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
- [DocsGPT SECURITY.md](../SECURITY.md)

View File

@@ -13,11 +13,7 @@ updates:
directory: "/frontend" # Location of package manifests
schedule:
interval: "daily"
- package-ecosystem: "npm"
directory: "/extensions/react-widget"
schedule:
interval: "daily"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
interval: "daily"

View File

@@ -1,11 +0,0 @@
extends: spelling
level: warning
message: "Did you really mean '%s'?"
ignore:
- "**/node_modules/**"
- "**/dist/**"
- "**/build/**"
- "**/coverage/**"
- "**/public/**"
- "**/static/**"
vocab: DocsGPT

View File

@@ -1,80 +0,0 @@
Agentic
Anthropic's
api
APIs
Atlassian
automations
autoescaping
Autoescaping
backfill
backfills
bool
boolean
brave_web_search
chatbot
Chatwoot
config
configs
CSVs
dev
diarization
Docling
docsgpt
docstrings
Entra
env
enqueues
EOL
ESLint
feedbacks
Figma
GPUs
Groq
hardcode
hardcoding
Idempotency
JSONPath
kubectl
Lightsail
llama_cpp
llm
LLM
LLMs
LMDeploy
Milvus
Mixtral
namespace
namespaces
needs_auth
Nextra
Novita
npm
OAuth
Ollama
opencode
parsable
passthrough
PDFs
pgvector
Postgres
Premade
Pydantic
pytest
Qdrant
qdrant
Repo
repo
Sanitization
SDKs
SGLang
Shareability
Signup
Supabase
UIs
uncomment
URl
vectorstore
Vite
VSCode
VSCode's
widget's

View File

@@ -7,9 +7,6 @@ on:
pull_request:
types: [ opened, synchronize ]
permissions:
contents: read
jobs:
ruff:
runs-on: ubuntu-latest

View File

@@ -1,114 +0,0 @@
name: Publish npm libraries
on:
workflow_dispatch:
inputs:
version:
description: >
Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
Applies to both docsgpt and docsgpt-react.
required: true
default: patch
permissions:
contents: write
pull-requests: write
jobs:
publish:
runs-on: ubuntu-latest
environment: npm-release
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
registry-url: https://registry.npmjs.org
- name: Install dependencies
run: npm ci
# ── docsgpt (HTML embedding bundle) ──────────────────────────────────
# Uses the `build` script (parcel build src/browser.tsx) and keeps
# the `targets` field so Parcel produces browser-optimised bundles.
- name: Set package name → docsgpt
run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt)
id: version_docsgpt
run: |
VERSION="${{ github.event.inputs.version }}"
NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"
- name: Build docsgpt
run: npm run build
- name: Publish docsgpt
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── docsgpt-react (React library bundle) ─────────────────────────────
# Uses `build:react` script (parcel build src/index.ts) and strips
# the `targets` field so Parcel treats the output as a plain library
# without browser-specific target resolution, producing a smaller bundle.
- name: Reset package.json from source control
run: git checkout -- package.json
- name: Set package name → docsgpt-react
run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json
- name: Remove targets field (react library build)
run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json
- name: Bump version (docsgpt-react) to match docsgpt
run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version
- name: Clean dist before react build
run: rm -rf dist
- name: Build docsgpt-react
run: npm run build:react
- name: Publish docsgpt-react
run: npm publish --verbose
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
# ── Commit the bumped version back to the repository ─────────────────
- name: Reset package.json and write final version
run: |
git checkout -- package.json
jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
package.json > _tmp.json && mv _tmp.json package.json
npm install --package-lock-only
- name: Commit version bump and create PR
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
git checkout -b "$BRANCH"
git add package.json package-lock.json
git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
git push origin "$BRANCH"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Create PR
run: |
gh pr create \
--title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
--body "Automated version bump after npm publish." \
--base main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -1,9 +1,5 @@
name: Run python tests with pytest
on: [push, pull_request]
permissions:
contents: read
jobs:
pytest_and_coverage:
name: Run tests and count coverage
@@ -20,15 +16,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-cov
cd application
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
cd ../tests
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest and generate coverage report
run: |
python -m pytest --cov=application --cov-report=xml --cov-report=term-missing
python -m pytest --cov=application --cov-report=xml
- name: Upload coverage reports to Codecov
if: github.event_name == 'pull_request' && matrix.python-version == '3.12'
uses: codecov/codecov-action@v5
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -1,34 +0,0 @@
name: React Widget Build
on:
push:
paths:
- 'extensions/react-widget/**'
pull_request:
paths:
- 'extensions/react-widget/**'
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-latest
defaults:
run:
working-directory: extensions/react-widget
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 20
cache: npm
cache-dependency-path: extensions/react-widget/package-lock.json
- name: Install dependencies
run: npm ci
- name: Build
run: npm run build

View File

@@ -1,34 +0,0 @@
name: Vale Documentation Linter
on:
pull_request:
paths:
- 'docs/**/*.md'
- 'docs/**/*.mdx'
- '**/*.md'
- '.vale.ini'
- '.github/styles/**'
permissions:
contents: read
jobs:
vale:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Vale
run: |
curl -fsSL -o vale.tar.gz \
https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
tar -xzf vale.tar.gz
sudo mv vale /usr/local/bin/vale
vale --version
- name: Sync Vale packages
run: vale sync
- name: Run Vale
run: vale --minAlertLevel=error docs

View File

@@ -1,25 +0,0 @@
name: GitHub Actions Security Analysis
on:
push:
branches: ["master"]
pull_request:
branches: ["**"]
permissions: {}
jobs:
zizmor:
runs-on: ubuntu-latest
permissions:
security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Run zizmor 🌈
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2

19
.gitignore vendored
View File

@@ -2,10 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
results.txt
experiments/
experiments
# C extensions
*.so
*.next
@@ -72,7 +69,6 @@ instance/
# Sphinx documentation
docs/_build/
docs/public/_pagefind/
# PyBuilder
target/
@@ -108,8 +104,6 @@ celerybeat.pid
# Environments
.env
.venv
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
CLAUDE.md
env/
venv/
ENV/
@@ -151,10 +145,6 @@ frontend/yarn-error.log*
frontend/pnpm-debug.log*
frontend/lerna-debug.log*
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
!frontend/src/lib/
!frontend/src/lib/**
frontend/node_modules
frontend/dist
frontend/dist-ssr
@@ -183,14 +173,5 @@ application/vectors/
node_modules/
.vscode/settings.json
.vscode/sftp.json
/models/
model/
# E2E test artifacts
.e2e-tmp/
/tmp/docsgpt-e2e/
tests/e2e/node_modules/
tests/e2e/playwright-report/
tests/e2e/test-results/
tests/e2e/.e2e-last-run.json

View File

@@ -1,6 +1,2 @@
# Allow lines to be as long as 120 characters.
line-length = 120
[lint.per-file-ignores]
# Integration tests use sys.path.insert() before imports for standalone execution
"tests/integration/*.py" = ["E402"]
line-length = 120

View File

@@ -1,7 +0,0 @@
MinAlertLevel = warning
StylesPath = .github/styles
Vocab = DocsGPT
[*.{md,mdx}]
BasedOnStyles = DocsGPT

33
.vscode/launch.json vendored
View File

@@ -2,11 +2,15 @@
"version": "0.2.0",
"configurations": [
{
"name": "Frontend Debug (npm)",
"type": "node-terminal",
"name": "Docker Debug Frontend",
"request": "launch",
"command": "npm run dev",
"cwd": "${workspaceFolder}/frontend"
"type": "chrome",
"preLaunchTask": "docker-compose: debug:frontend",
"url": "http://127.0.0.1:5173",
"webRoot": "${workspaceFolder}/frontend",
"skipFiles": [
"<node_internals>/**"
]
},
{
"name": "Flask Debugger",
@@ -45,27 +49,6 @@
"--pool=solo"
],
"cwd": "${workspaceFolder}"
},
{
"name": "Dev Containers (Mongo + Redis)",
"type": "node-terminal",
"request": "launch",
"command": "docker compose -f deployment/docker-compose-dev.yaml up --build",
"cwd": "${workspaceFolder}"
}
],
"compounds": [
{
"name": "DocsGPT: Full Stack",
"configurations": [
"Frontend Debug (npm)",
"Flask Debugger",
"Celery Debugger"
],
"presentation": {
"group": "DocsGPT",
"order": 1
}
}
]
}

21
.vscode/tasks.json vendored Normal file
View File

@@ -0,0 +1,21 @@
{
"version": "2.0.0",
"tasks": [
{
"type": "docker-compose",
"label": "docker-compose: debug:frontend",
"dockerCompose": {
"up": {
"detached": true,
"services": [
"frontend"
],
"build": true
},
"files": [
"${workspaceFolder}/docker-compose.yaml"
]
}
}
]
}

156
AGENTS.md
View File

@@ -1,156 +0,0 @@
# AGENTS.md
- Read `CONTRIBUTING.md` before making non-trivial changes.
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users configure `.env` usually.
- Try to follow red/green TDD
### Check existing dev prerequisites first
For feature work, do **not** assume the environment needs to be recreated.
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.
> MongoDB is **not** required for the default install. It is only needed if
> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
> or is running the one-shot `scripts/db/backfill.py` to migrate existing
> user data from the legacy Mongo-based install. In those cases, `pymongo`
> is available as an optional extra, not a core dependency.
## Normal local development commands
Use these commands once the dev prerequisites above are satisfied.
### Backend
```bash
source .venv/bin/activate # macOS/Linux
uv pip install -r application/requirements.txt # or: pip install -r application/requirements.txt
```
Run the Flask API (if needed):
```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```
That's the fast inner-loop option — quick startup, the Werkzeug interactive
debugger still works, and it hot-reloads on source changes. It serves the
Flask routes only (`/api/*`, `/stream`, etc.).
If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
or to match the production runtime exactly — run the ASGI composition under
uvicorn instead:
```bash
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
```
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
full flag set.
Run the Celery worker in a separate terminal (if needed):
```bash
celery -A application.app.celery worker -l INFO
```
On macOS, prefer the solo pool for Celery:
```bash
python -m celery -A application.app.celery worker -l INFO --pool=solo
```
### Frontend
Install dependencies only when needed, then run the dev server:
```bash
cd frontend
npm install --include=dev
npm run dev
```
### Docs site
```bash
cd docs
npm install
```
### Python / backend changes validation
```bash
ruff check .
python -m pytest
```
### Frontend changes
```bash
cd frontend && npm run lint
cd frontend && npm run build
```
### Documentation changes
```bash
cd docs && npm run build
```
If Vale is installed locally and you edited prose, also run:
```bash
vale .
```
## Repository map
- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
- `tests/`: backend unit/integration tests and test-only Python dependencies.
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
- `deployment/`: Docker Compose variants and Kubernetes manifests.
## Coding rules
### Backend
- Follow PEP 8 and keep Python line length at or under 120 characters.
- Use type hints for function arguments and return values.
- Add Google-style docstrings to new or substantially changed functions and classes.
- Add or update tests under `tests/` for backend behavior changes.
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.
### Backend Abstractions
- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
- Vector stores are abstracted in `application/vectorstore/`.
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.
### Frontend
- Follow the existing ESLint + Prettier setup.
- Prefer small, reusable functional components and hooks.
- If shared state must be added, use Redux rather than introducing a new global state library.
- Avoid broad UI refactors unless the task explicitly asks for them.
- Do not re-create components if we already have some in the app.
## PR readiness
Before opening a PR:
- run the relevant validation commands above
- confirm backend changes still work end-to-end after ingesting sample data when applicable
- clearly summarize user-visible behavior changes
- mention any config, dependency, or deployment implications
- Ask your user to attach a screenshot or a video to it

View File

@@ -22,11 +22,6 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!
- We have a frontend built on React (Vite) and a backend in Python.
> **Required for every PR:** Please attach screenshots or a short screen
> recording that shows the working version of your changes. This makes the
> requirement visible to reviewers and helps them quickly verify what you are
> submitting.
Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart) the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
@@ -130,7 +125,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
```
9. **Submit a Pull Request (PR):**
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.
- Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
10. **Collaborate:**
- Be responsive to comments and feedback on your PR.
@@ -152,5 +147,5 @@ Here's a step-by-step guide on how to contribute to DocsGPT:
Thank you for considering contributing to DocsGPT! 🙏
## Questions/collaboration
Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
Feel free to join our [Discord](https://discord.gg/n5BX8dh8rU). We're very friendly and welcoming to new contributors, so don't hesitate to reach out.
# Thank you so much for considering to contributing DocsGPT!🙏

View File

@@ -1,39 +0,0 @@
# **🎉 Join the Hacktoberfest with DocsGPT and win a Free T-shirt for a meaningful PR! 🎉**
Welcome, contributors! We're excited to announce that DocsGPT is participating in Hacktoberfest. Get involved by submitting meaningful pull requests.
All Meaningful contributors with accepted PRs that were created for issues with the `hacktoberfest` label (set by our maintainer team: dartpain, siiddhantt, pabik, ManishMadan2882) will receive a cool T-shirt! 🤩.
<img width="1331" height="678" alt="hacktoberfest-mocks-preview" src="https://github.com/user-attachments/assets/633f6377-38db-48f5-b519-a8b3855a9eb4" />
Fill in [this form](https://forms.gle/Npaba4n9Epfyx56S8
) after your PR was merged please
If you are in doubt don't hesitate to ping us on discord, ping me - Alex (dartpain).
## 📜 Here's How to Contribute:
```text
🛠️ Code: This is the golden ticket! Make meaningful contributions through PRs.
🧩 API extension: Build an app utilising DocsGPT API. We prefer submissions that showcase original ideas and turn the API into an AI agent.
They can be a completely separate repos.
For example:
https://github.com/arc53/tg-bot-docsgpt-extenstion or
https://github.com/arc53/DocsGPT-cli
Non-Code Contributions:
📚 Wiki: Improve our documentation, create a guide.
🖥️ Design: Improve the UI/UX or design a new feature.
```
### 📝 Guidelines for Pull Requests:
- Familiarize yourself with the current contributions and our [Roadmap](https://github.com/orgs/arc53/projects/2).
- Before contributing check existing [issues](https://github.com/arc53/DocsGPT/issues) or [create](https://github.com/arc53/DocsGPT/issues/new/choose) an issue and wait to get assigned.
- Once you are finished with your contribution, please fill in this [form](https://forms.gle/Npaba4n9Epfyx56S8).
- Refer to the [Documentation](https://docs.docsgpt.cloud/).
- Feel free to join our [Discord](https://discord.gg/vN7YFfdMpj) server. We're here to help newcomers, so don't hesitate to jump in! Join us [here](https://discord.gg/vN7YFfdMpj).
Thank you very much for considering contributing to DocsGPT during Hacktoberfest! 🙏 Your contributions (not just simple typos) could earn you a stylish new t-shirt.
We will publish a t-shirt design later into the October.

View File

@@ -3,11 +3,11 @@
</h1>
<p align="center">
<strong>Private AI for agents, assistants and enterprise search</strong>
<strong>Open-Source RAG Assistant</strong>
</p>
<p align="left">
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
<strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source genAI tool that helps users get reliable answers from any knowledge source, while avoiding hallucinations. It enables quick and reliable information retrieval, with tooling and agentic system capability built in.
</p>
<div align="center">
@@ -16,27 +16,23 @@
<a href="https://github.com/arc53/DocsGPT">![link to main GitHub showing Forks number](https://img.shields.io/github/forks/arc53/docsgpt?style=social)</a>
<a href="https://github.com/arc53/DocsGPT/blob/main/LICENSE">![link to license file](https://img.shields.io/github/license/arc53/docsgpt)</a>
<a href="https://www.bestpractices.dev/projects/9907"><img src="https://www.bestpractices.dev/projects/9907/badge"></a>
<a href="https://discord.gg/vN7YFfdMpj">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
<a href="https://x.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
<a href="https://discord.gg/n5BX8dh8rU">![link to discord](https://img.shields.io/discord/1070046503302877216)</a>
<a href="https://twitter.com/docsgptai">![X (formerly Twitter) URL](https://img.shields.io/twitter/follow/docsgptai)</a>
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/vN7YFfdMpj">💬 Discord</a>
<br>
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
<br>
<a href="https://docs.docsgpt.cloud/quickstart">⚡️ Quickstart</a><a href="https://app.docsgpt.cloud/">☁️ Cloud Version</a><a href="https://discord.gg/n5BX8dh8rU">💬 Discord</a>
<br>
<a href="https://docs.docsgpt.cloud/">📖 Documentation</a><a href="https://github.com/arc53/DocsGPT/blob/main/CONTRIBUTING.md">👫 Contribute</a><a href="https://blog.docsgpt.cloud/">🗞 Blog</a>
<br>
</div>
<div align="center">
<br>
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
</div>
<h3 align="left">
<strong>Key Features:</strong>
</h3>
<ul align="left">
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
<li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
<li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
<li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
<li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
<li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
@@ -47,13 +43,21 @@
</ul>
## Roadmap
- [x] Agent Workflow Builder with conditional nodes ( February 2026 )
- [x] SharePoint & Confluence connectors ( March April 2026 )
- [x] Research mode ( March 2026 )
- [x] Postgres migration for user data ( April 2026 )
- [x] OpenTelemetry observability ( April 2026 )
- [x] Bring Your Own Model (BYOM) ( April 2026 )
- [ ] Agent scheduling (RedBeat-backed) ( Q2 2026 )
- [x] Full GoogleAI compatibility (Jan 2025)
- [x] Add tools (Jan 2025)
- [x] Manually updating chunks in the app UI (Feb 2025)
- [x] Devcontainer for easy development (Feb 2025)
- [x] ReACT agent (March 2025)
- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
- [x] New input box in the conversation menu (April 2025)
- [x] Add triggerable actions / tools (webhook) (April 2025)
- [x] Agent optimisations (May 2025)
- [ ] Filesystem sources update (July 2025)
- [ ] Anthropic Tool compatibility (July 2025)
- [ ] MCP support (July 2025)
- [ ] Add OAuth 2.0 authentication for tools and sources (August 2025)
- [ ] Agent scheduling
You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -67,10 +71,11 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a
## Join the Lighthouse Program 🌟
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
Calling all developers and GenAI innovators! The **DocsGPT Lighthouse Program** connects technical leaders actively deploying or extending DocsGPT in real-world scenarios. Collaborate directly with our team to shape the roadmap, access priority support, and build enterprise-ready solutions with exclusive community insights.
[Learn More & Apply →](https://docs.google.com/forms/d/1KAADiJinUJ8EMQyfTXUIGyFbqINNClNR3jBNWq7DgTE)
## QuickStart
> [!Note]
@@ -101,7 +106,7 @@ A more detailed [Quickstart](https://docs.docsgpt.cloud/quickstart) is available
PowerShell -ExecutionPolicy Bypass -File .\setup.ps1
```
Either script will guide you through setting up DocsGPT. Five options available: using the public API, running locally, connecting to a local inference engine, using a cloud API provider, or build the docker image locally. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
Either script will guide you through setting up DocsGPT. Four options available: using the public API, running locally, connecting to a local inference engine, or using a cloud API provider. Scripts will automatically configure your `.env` file and handle necessary downloads and installations based on your chosen option.
**Navigate to http://localhost:5173/**
@@ -110,7 +115,6 @@ To stop DocsGPT, open a terminal in the `DocsGPT` directory and run:
```bash
docker compose -f deployment/docker-compose.yaml down
```
(or use the specific `docker compose down` command shown after running the setup script).
> [!Note]
@@ -138,6 +142,7 @@ Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file for information abou
We as members, contributors, and leaders, pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. Please refer to the [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md) file for more information about contributing.
## Many Thanks To Our Contributors⚡
<a href="https://github.com/arc53/DocsGPT/graphs/contributors" alt="View Contributors">
@@ -148,16 +153,9 @@ We as members, contributors, and leaders, pledge to make participation in our co
The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.
## This project is supported by:
<p>This project is supported by:</p>
<p>
<a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
<img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
</a>
</p>
<p>
<a href="https://get.neon.com/docsgpt">
<img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
</a>
</p>

View File

@@ -2,21 +2,13 @@
## Supported Versions
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
Supported Versions:
Currently, we support security patches by committing changes and bumping the version published on Github.
## Reporting a Vulnerability
Preferred method: use GitHub's private vulnerability reporting flow:
https://github.com/arc53/DocsGPT/security
Found a vulnerability? Please email us:
Then click **Report a vulnerability**.
Alternatively, email us at: security@arc53.com
We aim to acknowledge reports within 48 hours.
## Incident Handling
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
security@arc53.com

11
application/.env_sample Normal file
View File

@@ -0,0 +1,11 @@
API_KEY=your_api_key
EMBEDDINGS_KEY=your_api_key
API_URL=http://localhost:7091
FLASK_APP=application/app.py
FLASK_DEBUG=true
#For OPENAI on Azure
OPENAI_API_BASE=
OPENAI_API_VERSION=
AZURE_DEPLOYMENT_NAME=
AZURE_EMBEDDINGS_DEPLOYMENT_NAME=

View File

@@ -7,7 +7,7 @@ RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
rm -rf /var/lib/apt/lists/*
# Verify Python installation and setup symlink
@@ -48,12 +48,7 @@ FROM ubuntu:24.04 as final
RUN apt-get update && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && apt-get install -y --no-install-recommends \
python3.12 \
libgl1 \
libglib2.0-0 \
poppler-utils \
&& \
apt-get update && apt-get install -y --no-install-recommends python3.12 && \
ln -s /usr/bin/python3.12 /usr/bin/python && \
rm -rf /var/lib/apt/lists/*
@@ -88,15 +83,5 @@ EXPOSE 7091
# Switch to non-root user
USER appuser
CMD ["gunicorn", \
"-w", "1", \
"-k", "uvicorn_worker.UvicornWorker", \
"--bind", "0.0.0.0:7091", \
"--timeout", "180", \
"--graceful-timeout", "120", \
"--keep-alive", "5", \
"--worker-tmp-dir", "/dev/shm", \
"--max-requests", "1000", \
"--max-requests-jitter", "100", \
"--config", "application/gunicorn_conf.py", \
"application.asgi:asgi_app"]
# Start Gunicorn
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]

View File

@@ -1,20 +1,11 @@
import logging
from application.agents.agentic_agent import AgenticAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflow_agent import WorkflowAgent
logger = logging.getLogger(__name__)
from application.agents.react_agent import ReActAgent
class AgentCreator:
agents = {
"classic": ClassicAgent,
"react": ClassicAgent, # backwards compat: react falls back to classic
"agentic": AgenticAgent,
"research": ResearchAgent,
"workflow": WorkflowAgent,
"react": ReActAgent,
}
@classmethod

View File

@@ -1,63 +0,0 @@
import logging
from typing import Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.logging import LogContext
logger = logging.getLogger(__name__)
class AgenticAgent(BaseAgent):
"""Agent where the LLM controls retrieval via tools.
Unlike ClassicAgent which pre-fetches docs into the prompt,
AgenticAgent gives the LLM an internal_search tool so it can
decide when, what, and whether to search.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
self._prepare_tools(tools_dict)
# 4. Build messages (prompt has NO pre-fetched docs)
messages = self._build_messages(self.prompt, query)
# 5. Call LLM — the handler manages the tool loop
llm_response = self._llm_gen(messages, log_context)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# 6. Collect sources from internal search tool results
self._collect_internal_sources()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
def _collect_internal_sources(self):
"""Collect retrieved docs from the cached InternalSearchTool instance."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
self.retrieved_docs = tool.retrieved_docs

View File

@@ -1,21 +1,19 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional
from typing import Dict, Generator, List, Optional
from application.agents.tool_executor import ToolExecutor
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from bson.objectid import ObjectId
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.llm.handlers.base import ToolCall
from application.llm.handlers.handler_creator import LLMHandlerCreator
from application.llm.llm_creator import LLMCreator
from application.logging import build_stack_data, log_activity, LogContext
logger = logging.getLogger(__name__)
from application.retriever.base import BaseRetriever
class BaseAgent(ABC):
@@ -23,504 +21,261 @@ class BaseAgent(ABC):
self,
endpoint: str,
llm_name: str,
model_id: str,
gpt_model: str,
api_key: str,
agent_id: Optional[str] = None,
user_api_key: Optional[str] = None,
prompt: str = "",
chat_history: Optional[List[Dict]] = None,
retrieved_docs: Optional[List[Dict]] = None,
decoded_token: Optional[Dict] = None,
attachments: Optional[List[Dict]] = None,
json_schema: Optional[Dict] = None,
limited_token_mode: Optional[bool] = False,
token_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["token_limit"],
limited_request_mode: Optional[bool] = False,
request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
compressed_summary: Optional[str] = None,
llm=None,
llm_handler=None,
tool_executor: Optional[ToolExecutor] = None,
backup_models: Optional[List[str]] = None,
model_user_id: Optional[str] = None,
):
self.endpoint = endpoint
self.llm_name = llm_name
self.model_id = model_id
self.gpt_model = gpt_model
self.api_key = api_key
self.agent_id = agent_id
self.user_api_key = user_api_key
self.prompt = prompt
self.decoded_token = decoded_token or {}
self.user: str = self.decoded_token.get("sub")
# BYOM-resolution scope: owner for shared agents, caller for
# caller-owned BYOM, None for built-ins. Falls back to self.user
# for worker/legacy callers that don't thread model_user_id.
self.model_user_id = model_user_id
self.user: str = decoded_token.get("sub")
self.tool_config: Dict = {}
self.tools: List[Dict] = []
self.tool_calls: List[Dict] = []
self.chat_history: List[Dict] = chat_history if chat_history is not None else []
if llm is not None:
self.llm = llm
else:
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
backup_models=backup_models,
model_user_id=model_user_id,
)
# For BYOM, registry id (UUID) differs from upstream model id
# (e.g. ``mistral-large-latest``). LLMCreator resolved this onto
# the LLM instance; cache it for subsequent gen calls.
self.upstream_model_id = (
getattr(self.llm, "model_id", None) or model_id
self.llm = LLMCreator.create_llm(
llm_name,
api_key=api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
self.retrieved_docs = retrieved_docs or []
if llm_handler is not None:
self.llm_handler = llm_handler
else:
self.llm_handler = LLMHandlerCreator.create_handler(
llm_name if llm_name else "default"
)
# Tool executor — injected or created
if tool_executor is not None:
self.tool_executor = tool_executor
else:
self.tool_executor = ToolExecutor(
user_api_key=user_api_key,
user=self.user,
decoded_token=decoded_token,
)
self.attachments = attachments or []
self.json_schema = None
if json_schema is not None:
try:
self.json_schema = normalize_json_schema_payload(json_schema)
except JsonSchemaValidationError as exc:
logger.warning("Ignoring invalid JSON schema payload: %s", exc)
self.limited_token_mode = limited_token_mode
self.token_limit = token_limit
self.limited_request_mode = limited_request_mode
self.request_limit = request_limit
self.compressed_summary = compressed_summary
self.current_token_count = 0
self.context_limit_reached = False
@log_activity()
def gen(
self, query: str, log_context: LogContext = None
self, query: str, retriever: BaseRetriever, log_context: LogContext = None
) -> Generator[Dict, None, None]:
yield from self._gen_inner(query, log_context)
yield from self._gen_inner(query, retriever, log_context)
@abstractmethod
def _gen_inner(
self, query: str, log_context: LogContext
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
pass
def gen_continuation(
self,
messages: List[Dict],
tools_dict: Dict,
pending_tool_calls: List[Dict],
tool_actions: List[Dict],
) -> Generator[Dict, None, None]:
"""Resume generation after tool actions are resolved.
Processes the client-provided *tool_actions* (approvals, denials,
or client-side results), appends the resulting messages, then
hands back to the LLM to continue the conversation.
Args:
messages: The saved messages array from the pause point.
tools_dict: The saved tools dictionary.
pending_tool_calls: The pending tool call descriptors from the pause.
tool_actions: Client-provided actions resolving the pending calls.
"""
self._prepare_tools(tools_dict)
actions_by_id = {a["call_id"]: a for a in tool_actions}
# Build a single assistant message containing all tool calls so
# the message history matches the format LLM providers expect
# (one assistant message with N tool_calls, followed by N tool results).
tc_objects: List[Dict[str, Any]] = []
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
args_str = (
json.dumps(args) if isinstance(args, dict) else (args or "{}")
)
tc_obj: Dict[str, Any] = {
"id": call_id,
"type": "function",
"function": {
"name": pending["name"],
"arguments": args_str,
},
}
if pending.get("thought_signature"):
tc_obj["thought_signature"] = pending["thought_signature"]
tc_objects.append(tc_obj)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": tc_objects,
})
# Now process each pending call and append tool result messages
for pending in pending_tool_calls:
call_id = pending["call_id"]
args = pending["arguments"]
action = actions_by_id.get(call_id)
if not action:
action = {
"call_id": call_id,
"decision": "denied",
"comment": "No response provided",
}
if action.get("decision") == "approved":
# Execute the tool server-side
tc = ToolCall(
id=call_id,
name=pending["name"],
arguments=(
json.dumps(args) if isinstance(args, dict) else args
),
)
tool_gen = self._execute_tool_action(tools_dict, tc)
tool_response = None
while True:
try:
event = next(tool_gen)
yield event
except StopIteration as e:
tool_response, _ = e.value
break
messages.append(
self.llm_handler.create_tool_message(tc, tool_response)
)
elif action.get("decision") == "denied":
comment = action.get("comment", "")
denial = (
f"Tool execution denied by user. Reason: {comment}"
if comment
else "Tool execution denied by user."
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, denial)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"status": "denied",
},
}
elif "result" in action:
result = action["result"]
result_str = (
json.dumps(result)
if not isinstance(result, str)
else result
)
tc = ToolCall(
id=call_id, name=pending["name"], arguments=args
)
messages.append(
self.llm_handler.create_tool_message(tc, result_str)
)
yield {
"type": "tool_call",
"data": {
"tool_name": pending.get("tool_name", "unknown"),
"call_id": call_id,
"action_name": pending.get("llm_name", pending["name"]),
"arguments": args,
"result": (
result_str[:50] + "..."
if len(result_str) > 50
else result_str
),
"status": "completed",
},
}
# Resume the LLM loop with the updated messages
llm_response = self._llm_gen(messages)
yield from self._handle_response(
llm_response, tools_dict, messages, None
)
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
@property
def tool_calls(self) -> List[Dict]:
return self.tool_executor.tool_calls
@tool_calls.setter
def tool_calls(self, value: List[Dict]):
self.tool_executor.tool_calls = value
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
agents_collection = db["agents"]
tools_collection = db["user_tools"]
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
tool_ids = agent_data.get("tools", []) if agent_data else []
tools = (
tools_collection.find(
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
)
if tool_ids
else []
)
tools = list(tools)
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
return tools_by_id
def _get_user_tools(self, user="local"):
return self.tool_executor._get_user_tools(user)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
user_tools_collection = db["user_tools"]
user_tools = user_tools_collection.find({"user": user, "status": True})
user_tools = list(user_tools)
return {str(i): tool for i, tool in enumerate(user_tools)}
def _build_tool_parameters(self, action):
return self.tool_executor._build_tool_parameters(action)
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key != "filled_by_llm" and key != "value"
}
params["required"].append(k)
return params
def _prepare_tools(self, tools_dict):
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
self.tools = [
{
"type": "function",
"function": {
"name": f"{action['name']}_{tool_id}",
"description": action["description"],
"parameters": self._build_tool_parameters(action),
},
}
for tool_id, tool in tools_dict.items()
if (
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
or (tool["name"] != "api_tool" and "actions" in tool)
)
for action in (
tool["config"]["actions"].values()
if tool["name"] == "api_tool"
else tool["actions"]
)
if action.get("active", True)
]
def _execute_tool_action(self, tools_dict, call):
return self.tool_executor.execute(
tools_dict, call, self.llm.__class__.__name__
parser = ToolActionParser(self.llm.__class__.__name__)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": f"{action_name}_{tool_id}",
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if param not in call_args and "value" in details:
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
tm = ToolManager(config={})
tool = tm.load_tool(
tool_data["name"],
tool_config=(
{
"url": tool_data["config"]["actions"][action_name]["url"],
"method": tool_data["config"]["actions"][action_name]["method"],
"headers": headers,
"query_params": query_params,
}
if tool_data["name"] == "api_tool"
else tool_data["config"]
),
)
if tool_data["name"] == "api_tool":
print(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
print(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
tool_call_data["result"] = (
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
)
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_truncated_tool_calls(self):
return self.tool_executor.get_truncated_tool_calls()
# ---- Context / token management ----
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
from application.api.answer.services.compression.token_counter import (
TokenCounter,
)
return TokenCounter.count_message_tokens(messages)
def _check_context_limit(self, messages: List[Dict]) -> bool:
from application.core.model_utils import get_token_limit
try:
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
if current_tokens >= threshold:
logger.warning(
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
return False
def _validate_context_size(self, messages: List[Dict]) -> None:
from application.core.model_utils import get_token_limit
current_tokens = self._calculate_current_context_tokens(messages)
self.current_token_count = current_tokens
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
percentage = (current_tokens / context_limit) * 100
if current_tokens >= context_limit:
logger.warning(
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%). Model: {self.model_id}"
)
elif current_tokens >= int(
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
):
logger.info(
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
f"({percentage:.1f}%)"
)
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
from application.utils import num_tokens_from_string
current_tokens = num_tokens_from_string(text)
if current_tokens <= max_tokens:
return text
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
target_chars = int(max_tokens * chars_per_token * 0.95)
if target_chars <= 0:
return ""
start_chars = int(target_chars * 0.4)
end_chars = int(target_chars * 0.4)
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
logger.info(
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
f"(removed middle section)"
)
return truncated
# ---- Message building ----
return [
{
**tool_call,
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]
def _build_messages(
self,
system_prompt: str,
query: str,
retrieved_data: List[Dict],
) -> List[Dict]:
"""Build messages using pre-rendered system prompt"""
from application.core.model_utils import get_token_limit
from application.utils import num_tokens_from_string
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
p_chat_combine = system_prompt.replace("{summaries}", docs_together)
messages_combine = [{"role": "system", "content": p_chat_combine}]
if self.compressed_summary:
compression_context = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{self.compressed_summary}"
)
system_prompt = system_prompt + compression_context
context_limit = get_token_limit(
self.model_id, user_id=self.model_user_id or self.user
)
system_tokens = num_tokens_from_string(system_prompt)
safety_buffer = int(context_limit * 0.1)
available_after_system = context_limit - system_tokens - safety_buffer
max_query_tokens = int(available_after_system * 0.8)
query_tokens = num_tokens_from_string(query)
if query_tokens > max_query_tokens:
query = self._truncate_text_middle(query, max_query_tokens)
query_tokens = num_tokens_from_string(query)
available_for_history = max(available_after_system - query_tokens, 0)
working_history = self._truncate_history_to_fit(
self.chat_history,
available_for_history,
)
messages = [{"role": "system", "content": system_prompt}]
for i in working_history:
for i in self.chat_history:
if "prompt" in i and "response" in i:
messages.append({"role": "user", "content": i["prompt"]})
messages.append({"role": "assistant", "content": i["response"]})
messages_combine.append({"role": "user", "content": i["prompt"]})
messages_combine.append({"role": "assistant", "content": i["response"]})
if "tool_calls" in i:
for tool_call in i["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
messages.append({"role": "user", "content": query})
return messages
def _truncate_history_to_fit(
function_call_dict = {
"function_call": {
"name": tool_call.get("action_name"),
"args": tool_call.get("arguments"),
"call_id": call_id,
}
}
function_response_dict = {
"function_response": {
"name": tool_call.get("action_name"),
"response": {"result": tool_call.get("result")},
"call_id": call_id,
}
}
messages_combine.append(
{"role": "assistant", "content": [function_call_dict]}
)
messages_combine.append(
{"role": "tool", "content": [function_response_dict]}
)
messages_combine.append({"role": "user", "content": query})
return messages_combine
def _retriever_search(
self,
history: List[Dict],
max_tokens: int,
retriever: BaseRetriever,
query: str,
log_context: Optional[LogContext] = None,
) -> List[Dict]:
from application.utils import num_tokens_from_string
if not history or max_tokens <= 0:
return []
truncated = []
current_tokens = 0
for message in reversed(history):
message_tokens = 0
if "prompt" in message and "response" in message:
message_tokens += num_tokens_from_string(message["prompt"])
message_tokens += num_tokens_from_string(message["response"])
if "tool_calls" in message:
for tool_call in message["tool_calls"]:
tool_str = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
message_tokens += num_tokens_from_string(tool_str)
if current_tokens + message_tokens <= max_tokens:
current_tokens += message_tokens
truncated.insert(0, message)
else:
break
if len(truncated) < len(history):
logger.info(
f"Truncated chat history from {len(history)} to {len(truncated)} messages "
f"to fit within {max_tokens:,} token budget"
)
return truncated
# ---- LLM generation ----
retrieved_data = retriever.search(query)
if log_context:
data = build_stack_data(retriever, exclude_attributes=["llm"])
log_context.stacks.append({"component": "retriever", "data": data})
return retrieved_data
def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
self._validate_context_size(messages)
# Use the upstream id resolved by LLMCreator (see __init__).
# Built-in models: same as self.model_id. BYOM: the user's
# typed model name, not the internal UUID.
gen_kwargs = {"model": self.upstream_model_id, "messages": messages}
if self.attachments:
gen_kwargs["_usage_attachments"] = self.attachments
gen_kwargs = {"model": self.gpt_model, "messages": messages}
if (
hasattr(self.llm, "_supports_tools")
@@ -528,19 +283,6 @@ class BaseAgent(ABC):
and self.tools
):
gen_kwargs["tools"] = self.tools
if (
self.json_schema
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
):
structured_format = self.llm.prepare_structured_output_format(
self.json_schema
)
if structured_format:
if self.llm_name == "openai":
gen_kwargs["response_format"] = structured_format
elif self.llm_name == "google":
gen_kwargs["response_schema"] = structured_format
resp = self.llm.gen_stream(**gen_kwargs)
if log_context:
@@ -565,25 +307,11 @@ class BaseAgent(ABC):
return resp
def _handle_response(self, response, tools_dict, messages, log_context):
is_structured_output = (
self.json_schema is not None
and hasattr(self.llm, "_supports_structured_output")
and self.llm._supports_structured_output()
)
if isinstance(response, str):
answer_data = {"answer": response}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": response}
return
if hasattr(response, "message") and getattr(response.message, "content", None):
answer_data = {"answer": response.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": response.message.content}
return
processed_response_gen = self._llm_handler(
response, tools_dict, messages, log_context, self.attachments
@@ -591,16 +319,8 @@ class BaseAgent(ABC):
for event in processed_response_gen:
if isinstance(event, str):
answer_data = {"answer": event}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": event}
elif hasattr(event, "message") and getattr(event.message, "content", None):
answer_data = {"answer": event.message.content}
if is_structured_output:
answer_data["structured"] = True
answer_data["schema"] = self.json_schema
yield answer_data
yield {"answer": event.message.content}
elif isinstance(event, dict) and "type" in event:
yield event

View File

@@ -1,33 +1,53 @@
import logging
from typing import Dict, Generator
from application.agents.base import BaseAgent
from application.logging import LogContext
from application.retriever.base import BaseRetriever
import logging
logger = logging.getLogger(__name__)
class ClassicAgent(BaseAgent):
"""A simplified agent with clear execution flow"""
"""A simplified agent with clear execution flow.
Usage:
1. Processes a query through retrieval
2. Sets up available tools
3. Generates responses using LLM
4. Handles tool interactions if needed
5. Returns standardized outputs
Easy to extend by overriding specific steps.
"""
def _gen_inner(
self, query: str, log_context: LogContext
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
"""Core generator function for ClassicAgent execution flow"""
# Step 1: Retrieve relevant data
retrieved_data = self._retriever_search(retriever, query, log_context)
tools_dict = self.tool_executor.get_tools()
# Step 2: Prepare tools
tools_dict = (
self._get_user_tools(self.user)
if not self.user_api_key
else self._get_tools(self.user_api_key)
)
self._prepare_tools(tools_dict)
messages = self._build_messages(self.prompt, query)
# Step 3: Build and process messages
messages = self._build_messages(self.prompt, query, retrieved_data)
llm_response = self._llm_gen(messages, log_context)
# Step 4: Handle the response
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
yield {"sources": self.retrieved_docs}
# Step 5: Return metadata
yield {"sources": retrieved_data}
yield {"tool_calls": self._get_truncated_tool_calls()}
# Log tool calls for debugging
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)

View File

@@ -0,0 +1,229 @@
import os
from typing import Dict, Generator, List, Any
import logging
from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext
from application.retriever.base import BaseRetriever
logger = logging.getLogger(__name__)
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
planning_prompt_template = f.read()
with open(
os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"),
"r",
) as f:
final_prompt_template = f.read()
MAX_ITERATIONS_REASONING = 10
class ReActAgent(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.plan: str = ""
self.observations: List[str] = []
def _extract_content_from_llm_response(self, resp: Any) -> str:
"""
Helper to extract string content from various LLM response types.
Handles strings, message objects (OpenAI-like), and streams.
Adapt stream handling for your specific LLM client if not OpenAI.
"""
collected_content = []
if isinstance(resp, str):
collected_content.append(resp)
elif ( # OpenAI non-streaming or Anthropic non-streaming (older SDK style)
hasattr(resp, "message")
and hasattr(resp.message, "content")
and resp.message.content is not None
):
collected_content.append(resp.message.content)
elif ( # OpenAI non-streaming (Pydantic model), Anthropic new SDK non-streaming
hasattr(resp, "choices") and resp.choices and
hasattr(resp.choices[0], "message") and
hasattr(resp.choices[0].message, "content") and
resp.choices[0].message.content is not None
):
collected_content.append(resp.choices[0].message.content) # OpenAI
elif ( # Anthropic new SDK non-streaming content block
hasattr(resp, "content") and isinstance(resp.content, list) and resp.content and
hasattr(resp.content[0], "text")
):
collected_content.append(resp.content[0].text) # Anthropic
else:
# Assume resp is a stream if not a recognized object
try:
for chunk in resp: # This will fail if resp is not iterable (e.g. a non-streaming response object)
content_piece = ""
# OpenAI-like stream
if hasattr(chunk, 'choices') and len(chunk.choices) > 0 and \
hasattr(chunk.choices[0], 'delta') and \
hasattr(chunk.choices[0].delta, 'content') and \
chunk.choices[0].delta.content is not None:
content_piece = chunk.choices[0].delta.content
# Anthropic-like stream (ContentBlockDelta)
elif hasattr(chunk, 'type') and chunk.type == 'content_block_delta' and \
hasattr(chunk, 'delta') and hasattr(chunk.delta, 'text'):
content_piece = chunk.delta.text
elif isinstance(chunk, str): # Simplest case: stream of strings
content_piece = chunk
if content_piece:
collected_content.append(content_piece)
except TypeError: # If resp is not iterable (e.g. a final response object that wasn't caught above)
logger.debug(f"Response type {type(resp)} could not be iterated as a stream. It might be a non-streaming object not handled by specific checks.")
except Exception as e:
logger.error(f"Error processing potential stream chunk: {e}, chunk was: {getattr(chunk, '__dict__', chunk)}")
return "".join(collected_content)
def _gen_inner(
self, query: str, retriever: BaseRetriever, log_context: LogContext
) -> Generator[Dict, None, None]:
# Reset state for this generation call
self.plan = ""
self.observations = []
retrieved_data = self._retriever_search(retriever, query, log_context)
if self.user_api_key:
tools_dict = self._get_tools(self.user_api_key)
else:
tools_dict = self._get_user_tools(self.user)
self._prepare_tools(tools_dict)
docs_together = "\n".join([doc["text"] for doc in retrieved_data])
iterating_reasoning = 0
while iterating_reasoning < MAX_ITERATIONS_REASONING:
iterating_reasoning += 1
# 1. Create Plan
logger.info("ReActAgent: Creating plan...")
plan_stream = self._create_plan(query, docs_together, log_context)
current_plan_parts = []
yield {"thought": f"Reasoning... (iteration {iterating_reasoning})\n\n"}
for line_chunk in plan_stream:
current_plan_parts.append(line_chunk)
yield {"thought": line_chunk}
self.plan = "".join(current_plan_parts)
if self.plan:
self.observations.append(f"Plan: {self.plan} Iteration: {iterating_reasoning}")
max_obs_len = 20000
obs_str = "\n".join(self.observations)
if len(obs_str) > max_obs_len:
obs_str = obs_str[:max_obs_len] + "\n...[observations truncated]"
execution_prompt_str = (
(self.prompt or "")
+ f"\n\nFollow this plan:\n{self.plan}"
+ f"\n\nObservations:\n{obs_str}"
+ f"\n\nIf there is enough data to complete user query '{query}', Respond with 'SATISFIED' only. Otherwise, continue. Dont Menstion 'SATISFIED' in your response if you are not ready. "
)
messages = self._build_messages(execution_prompt_str, query, retrieved_data)
resp_from_llm_gen = self._llm_gen(messages, log_context)
initial_llm_thought_content = self._extract_content_from_llm_response(resp_from_llm_gen)
if initial_llm_thought_content:
self.observations.append(f"Initial thought/response: {initial_llm_thought_content}")
else:
logger.info("ReActAgent: Initial LLM response (before handler) had no textual content (might be only tool calls).")
resp_after_handler = self._llm_handler(resp_from_llm_gen, tools_dict, messages, log_context)
for tool_call_info in self.tool_calls: # Iterate over self.tool_calls populated by _llm_handler
observation_string = (
f"Executed Action: Tool '{tool_call_info.get('tool_name', 'N/A')}' "
f"with arguments '{tool_call_info.get('arguments', '{}')}'. Result: '{str(tool_call_info.get('result', ''))[:200]}...'"
)
self.observations.append(observation_string)
content_after_handler = self._extract_content_from_llm_response(resp_after_handler)
if content_after_handler:
self.observations.append(f"Response after tool execution: {content_after_handler}")
else:
logger.info("ReActAgent: LLM response after handler had no textual content.")
if log_context:
log_context.stacks.append(
{"component": "agent_tool_calls", "data": {"tool_calls": self.tool_calls.copy()}}
)
yield {"sources": retrieved_data}
display_tool_calls = []
for tc in self.tool_calls:
cleaned_tc = tc.copy()
if len(str(cleaned_tc.get("result", ""))) > 50:
cleaned_tc["result"] = str(cleaned_tc["result"])[:50] + "..."
display_tool_calls.append(cleaned_tc)
if display_tool_calls:
yield {"tool_calls": display_tool_calls}
if "SATISFIED" in content_after_handler:
logger.info("ReActAgent: LLM satisfied with the plan and data. Stopping reasoning.")
break
# 3. Create Final Answer based on all observations
final_answer_stream = self._create_final_answer(query, self.observations, log_context)
for answer_chunk in final_answer_stream:
yield {"answer": answer_chunk}
logger.info("ReActAgent: Finished generating final answer.")
def _create_plan(
self, query: str, docs_data: str, log_context: LogContext = None
) -> Generator[str, None, None]:
plan_prompt_filled = planning_prompt_template.replace("{query}", query)
if "{summaries}" in plan_prompt_filled:
summaries = docs_data if docs_data else "No documents retrieved."
plan_prompt_filled = plan_prompt_filled.replace("{summaries}", summaries)
plan_prompt_filled = plan_prompt_filled.replace("{prompt}", self.prompt or "")
plan_prompt_filled = plan_prompt_filled.replace("{observations}", "\n".join(self.observations))
messages = [{"role": "user", "content": plan_prompt_filled}]
plan_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=getattr(self, 'tools', None) # Use self.tools
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "planning_llm", "data": data})
for chunk in plan_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece
def _create_final_answer(
self, query: str, observations: List[str], log_context: LogContext = None
) -> Generator[str, None, None]:
observation_string = "\n".join(observations)
max_obs_len = 10000
if len(observation_string) > max_obs_len:
observation_string = observation_string[:max_obs_len] + "\n...[observations truncated]"
logger.warning("ReActAgent: Truncated observations for final answer prompt due to length.")
final_answer_prompt_filled = final_prompt_template.format(
query=query, observations=observation_string
)
messages = [{"role": "user", "content": final_answer_prompt_filled}]
# Final answer should synthesize, not call tools.
final_answer_stream_from_llm = self.llm.gen_stream(
model=self.gpt_model, messages=messages, tools=None
)
if log_context:
data = build_stack_data(self.llm)
log_context.stacks.append({"component": "final_answer_llm", "data": data})
for chunk in final_answer_stream_from_llm:
content_piece = self._extract_content_from_llm_response(chunk)
if content_piece:
yield content_piece

View File

@@ -1,698 +0,0 @@
import json
import logging
import os
import time
from typing import Dict, Generator, List, Optional
from application.agents.base import BaseAgent
from application.agents.tool_executor import ToolExecutor
from application.agents.tools.internal_search import (
INTERNAL_TOOL_ID,
add_internal_search_tool,
)
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
from application.logging import LogContext
logger = logging.getLogger(__name__)
# Defaults (can be overridden via constructor)
DEFAULT_MAX_STEPS = 6
DEFAULT_MAX_SUB_ITERATIONS = 5
DEFAULT_TIMEOUT_SECONDS = 300 # 5 minutes
DEFAULT_TOKEN_BUDGET = 100_000
DEFAULT_PARALLEL_WORKERS = 3
# Adaptive depth caps per complexity level
COMPLEXITY_CAPS = {
"simple": 2,
"moderate": 4,
"complex": 6,
}
_PROMPTS_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"prompts",
"research",
)
def _load_prompt(name: str) -> str:
with open(os.path.join(_PROMPTS_DIR, name), "r") as f:
return f.read()
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
PLANNING_PROMPT = _load_prompt("planning.txt")
STEP_PROMPT = _load_prompt("step.txt")
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
# ---------------------------------------------------------------------------
# CitationManager
# ---------------------------------------------------------------------------
class CitationManager:
"""Tracks and deduplicates citations across research steps."""
def __init__(self):
self.citations: Dict[int, Dict] = {}
self._counter = 0
def add(self, doc: Dict) -> int:
"""Register a source, return its citation number. Deduplicates by source."""
source = doc.get("source", "")
title = doc.get("title", "")
for num, existing in self.citations.items():
if existing.get("source") == source and existing.get("title") == title:
return num
self._counter += 1
self.citations[self._counter] = doc
return self._counter
def add_docs(self, docs: List[Dict]) -> str:
"""Register multiple docs, return formatted citation mapping text."""
mapping_lines = []
for doc in docs:
num = self.add(doc)
title = doc.get("title", "Untitled")
mapping_lines.append(f"[{num}] {title}")
return "\n".join(mapping_lines)
def format_references(self) -> str:
"""Generate [N] -> source mapping for report footer."""
if not self.citations:
return "No sources found."
lines = []
for num, doc in sorted(self.citations.items()):
title = doc.get("title", "Untitled")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
display = filename or title
lines.append(f"[{num}] {display}{source}")
return "\n".join(lines)
def get_all_docs(self) -> List[Dict]:
return list(self.citations.values())
# ---------------------------------------------------------------------------
# ResearchAgent
# ---------------------------------------------------------------------------
class ResearchAgent(BaseAgent):
"""Multi-step research agent with parallel execution and budget controls.
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
"""
def __init__(
self,
retriever_config: Optional[Dict] = None,
max_steps: int = DEFAULT_MAX_STEPS,
max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
token_budget: int = DEFAULT_TOKEN_BUDGET,
parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.retriever_config = retriever_config or {}
self.max_steps = max_steps
self.max_sub_iterations = max_sub_iterations
self.timeout_seconds = timeout_seconds
self.token_budget = token_budget
self.parallel_workers = parallel_workers
self.citations = CitationManager()
self._start_time: float = 0
self._tokens_used: int = 0
self._last_token_snapshot: int = 0
# ------------------------------------------------------------------
# Budget & timeout helpers
# ------------------------------------------------------------------
def _is_timed_out(self) -> bool:
return (time.monotonic() - self._start_time) >= self.timeout_seconds
def _elapsed(self) -> float:
return round(time.monotonic() - self._start_time, 1)
def _track_tokens(self, count: int):
self._tokens_used += count
def _budget_remaining(self) -> int:
return max(self.token_budget - self._tokens_used, 0)
def _is_over_budget(self) -> bool:
return self._tokens_used >= self.token_budget
def _snapshot_llm_tokens(self) -> int:
"""Read current token usage from LLM and return delta since last snapshot."""
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
delta = current - self._last_token_snapshot
self._last_token_snapshot = current
return delta
# ------------------------------------------------------------------
# Main orchestration
# ------------------------------------------------------------------
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict, None, None]:
self._start_time = time.monotonic()
tools_dict = self._setup_tools()
# Phase 0: Clarification (skip if user is responding to a prior clarification)
if not self._is_follow_up():
clarification = self._clarification_phase(query)
if clarification:
yield {"metadata": {"is_clarification": True}}
yield {"answer": clarification}
yield {"sources": []}
yield {"tool_calls": []}
log_context.stacks.append(
{"component": "agent", "data": {"clarification": True}}
)
return
# Phase 1: Planning (with adaptive depth)
yield {"type": "research_progress", "data": {"status": "planning"}}
plan, complexity = self._planning_phase(query)
if not plan:
logger.warning("ResearchAgent: Planning produced no steps, falling back")
plan = [{"query": query, "rationale": "Direct investigation"}]
complexity = "simple"
yield {
"type": "research_plan",
"data": {"steps": plan, "complexity": complexity},
}
# Phase 2: Research each step (yields progress events in real-time)
intermediate_reports = []
for i, step in enumerate(plan):
step_num = i + 1
step_query = step.get("query", query)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
f"({self._elapsed()}s)"
)
break
if self._is_over_budget():
logger.warning(
f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
)
break
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "researching",
},
}
report = self._research_step(step_query, tools_dict)
intermediate_reports.append({"step": step, "content": report})
yield {
"type": "research_progress",
"data": {
"step": step_num,
"total": len(plan),
"query": step_query,
"status": "complete",
},
}
# Phase 3: Synthesis (streaming)
if self._is_timed_out():
logger.warning(
f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
f"synthesizing with {len(intermediate_reports)} reports"
)
yield {
"type": "research_progress",
"data": {
"status": "synthesizing",
"elapsed_seconds": self._elapsed(),
"tokens_used": self._tokens_used,
},
}
yield from self._synthesis_phase(
query, plan, intermediate_reports, tools_dict, log_context
)
# Sources and tool calls
self.retrieved_docs = self.citations.get_all_docs()
yield {"sources": self.retrieved_docs}
yield {"tool_calls": self._get_truncated_tool_calls()}
logger.info(
f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
f"{self._elapsed()}s, ~{self._tokens_used} tokens"
)
log_context.stacks.append(
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
)
# ------------------------------------------------------------------
# Tool setup
# ------------------------------------------------------------------
def _setup_tools(self) -> Dict:
"""Build tools_dict with user tools + internal search + think."""
tools_dict = self.tool_executor.get_tools()
add_internal_search_tool(tools_dict, self.retriever_config)
think_entry = dict(THINK_TOOL_ENTRY)
think_entry["config"] = {}
tools_dict[THINK_TOOL_ID] = think_entry
self._prepare_tools(tools_dict)
return tools_dict
# ------------------------------------------------------------------
# Phase 0: Clarification
# ------------------------------------------------------------------
def _is_follow_up(self) -> bool:
"""Check if the user is responding to a prior clarification.
Uses the metadata flag stored in the conversation DB — no string matching.
Only skip clarification when the last query was explicitly flagged
as a clarification by this agent.
"""
if not self.chat_history:
return False
last = self.chat_history[-1]
meta = last.get("metadata", {})
return bool(meta.get("is_clarification"))
def _clarification_phase(self, question: str) -> Optional[str]:
"""Ask the LLM whether the question needs clarification.
Returns formatted clarification text if needed, or None to proceed.
Uses response_format to force valid JSON output.
"""
messages = [
{"role": "system", "content": CLARIFICATION_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent clarification response: {text[:300]}")
data = self._parse_clarification_json(text)
if not data or not data.get("needs_clarification"):
return None
questions = data.get("questions", [])
if not questions:
return None
# Format as a friendly response
lines = [
"Before I begin researching, I'd like to clarify a few things:\n"
]
for i, q in enumerate(questions[:3], 1):
lines.append(f"{i}. {q}")
lines.append(
"\nPlease provide these details and I'll start the research."
)
return "\n".join(lines)
except Exception as e:
logger.error(f"Clarification phase failed: {e}", exc_info=True)
return None # proceed with research on failure
def _parse_clarification_json(self, text: str) -> Optional[Dict]:
"""Parse clarification JSON from LLM response."""
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try extracting from code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
return json.loads(text[start:end].strip())
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
return json.loads(text[i : j + 1])
except json.JSONDecodeError:
continue
break
return None
# ------------------------------------------------------------------
# Phase 1: Planning (with adaptive depth)
# ------------------------------------------------------------------
def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
"""Decompose the question into research steps via LLM.
Returns (steps, complexity) where complexity is simple/moderate/complex.
"""
messages = [
{"role": "system", "content": PLANNING_PROMPT},
{"role": "user", "content": question},
]
try:
response = self.llm.gen(
model=self.upstream_model_id,
messages=messages,
tools=None,
response_format={"type": "json_object"},
)
text = self._extract_text(response)
self._track_tokens(self._snapshot_llm_tokens())
logger.info(f"ResearchAgent planning LLM response: {text[:500]}")
plan_data = self._parse_plan_json(text)
if isinstance(plan_data, dict):
complexity = plan_data.get("complexity", "moderate")
steps = plan_data.get("steps", [])
else:
complexity = "moderate"
steps = plan_data
# Adaptive depth: cap steps based on assessed complexity
cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
cap = min(cap, self.max_steps)
steps = steps[:cap]
logger.info(
f"ResearchAgent plan: complexity={complexity}, "
f"steps={len(steps)} (cap={cap})"
)
return steps, complexity
except Exception as e:
logger.error(f"Planning phase failed: {e}", exc_info=True)
return (
[{"query": question, "rationale": "Direct investigation (planning failed)"}],
"simple",
)
def _parse_plan_json(self, text: str):
"""Extract JSON plan from LLM response. Returns dict or list."""
# Try direct parse
try:
data = json.loads(text)
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except json.JSONDecodeError:
pass
# Try extracting from markdown code fences
for marker in ["```json", "```"]:
if marker in text:
start = text.index(marker) + len(marker)
end = text.index("```", start) if "```" in text[start:] else len(text)
try:
data = json.loads(text[start:end].strip())
if isinstance(data, dict) and "steps" in data:
return data
if isinstance(data, list):
return data
except (json.JSONDecodeError, ValueError):
pass
# Try finding JSON object in text
for i, ch in enumerate(text):
if ch == "{":
for j in range(len(text) - 1, i, -1):
if text[j] == "}":
try:
data = json.loads(text[i : j + 1])
if isinstance(data, dict) and "steps" in data:
return data
except json.JSONDecodeError:
continue
break
logger.warning(f"Could not parse plan JSON from: {text[:200]}")
return []
# ------------------------------------------------------------------
# Phase 2: Research step (core loop)
# ------------------------------------------------------------------
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
"""Run a focused research loop for one sub-question (sequential path)."""
report = self._research_step_with_executor(
step_query, tools_dict, self.tool_executor
)
self._collect_step_sources()
return report
def _research_step_with_executor(
self, step_query: str, tools_dict: Dict, executor: ToolExecutor
) -> str:
"""Core research loop. Works with any ToolExecutor instance."""
system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": step_query},
]
last_search_empty = False
for iteration in range(self.max_sub_iterations):
# Check timeout and budget
if self._is_timed_out():
logger.info(
f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
)
break
if self._is_over_budget():
logger.info(
f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
)
break
try:
response = self.llm.gen(
model=self.upstream_model_id,
messages=messages,
tools=self.tools if self.tools else None,
)
self._track_tokens(self._snapshot_llm_tokens())
except Exception as e:
logger.error(
f"Research step LLM call failed (iteration {iteration}): {e}",
exc_info=True,
)
break
parsed = self.llm_handler.parse_response(response)
if not parsed.requires_tool_call:
return parsed.content or "No findings for this step."
# Execute tool calls
messages, last_search_empty = self._execute_step_tools_with_refinement(
parsed.tool_calls, tools_dict, messages, executor, last_search_empty
)
# Max iterations / timeout / budget — ask for summary
messages.append(
{
"role": "user",
"content": "Please summarize your findings so far based on the information gathered.",
}
)
try:
response = self.llm.gen(
model=self.upstream_model_id, messages=messages, tools=None
)
self._track_tokens(self._snapshot_llm_tokens())
text = self._extract_text(response)
return text or "Research step completed."
except Exception:
return "Research step completed."
def _execute_step_tools_with_refinement(
self,
tool_calls,
tools_dict: Dict,
messages: List[Dict],
executor: ToolExecutor,
last_search_empty: bool,
) -> tuple[List[Dict], bool]:
"""Execute tool calls with query refinement on empty results.
Returns (updated_messages, was_last_search_empty).
"""
search_returned_empty = False
for call in tool_calls:
gen = executor.execute(
tools_dict, call, self.llm.__class__.__name__
)
result = None
call_id = None
while True:
try:
event = next(gen)
# Log tool_call status events instead of discarding them
if isinstance(event, dict) and event.get("type") == "tool_call":
logger.debug(
"Tool %s status: %s",
event.get("data", {}).get("action_name", ""),
event.get("data", {}).get("status", ""),
)
except StopIteration as e:
result, call_id = e.value
break
# Detect empty search results for refinement
is_search = "search" in (call.name or "").lower()
result_str = str(result) if result else ""
if is_search and "No documents found" in result_str:
search_returned_empty = True
if last_search_empty:
# Two consecutive empty searches — inject refinement hint
result_str += (
"\n\nHint: Previous search also returned no results. "
"Try a very different query with different keywords, "
"or broaden your search terms."
)
result = result_str
import json as _json
args_str = (
_json.dumps(call.arguments)
if isinstance(call.arguments, dict)
else call.arguments
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {"name": call.name, "arguments": args_str},
}],
})
tool_message = self.llm_handler.create_tool_message(call, result)
messages.append(tool_message)
return messages, search_returned_empty
def _collect_step_sources(self):
"""Collect sources from InternalSearchTool and register with CitationManager."""
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
tool = self.tool_executor._loaded_tools.get(cache_key)
if tool and hasattr(tool, "retrieved_docs"):
for doc in tool.retrieved_docs:
self.citations.add(doc)
# ------------------------------------------------------------------
# Phase 3: Synthesis
# ------------------------------------------------------------------
def _synthesis_phase(
self,
question: str,
plan: List[Dict],
intermediate_reports: List[Dict],
tools_dict: Dict,
log_context: LogContext,
) -> Generator[Dict, None, None]:
"""Compile all findings into a final cited report (streaming)."""
plan_lines = []
for i, step in enumerate(plan, 1):
plan_lines.append(
f"{i}. {step.get('query', 'Unknown')}{step.get('rationale', '')}"
)
plan_summary = "\n".join(plan_lines)
findings_parts = []
for i, report in enumerate(intermediate_reports, 1):
step_query = report["step"].get("query", "Unknown")
content = report["content"]
findings_parts.append(
f"--- Step {i}: {step_query} ---\n{content}"
)
findings = "\n\n".join(findings_parts)
references = self.citations.format_references()
synthesis_prompt = SYNTHESIS_PROMPT.replace("{question}", question)
synthesis_prompt = synthesis_prompt.replace("{plan_summary}", plan_summary)
synthesis_prompt = synthesis_prompt.replace("{findings}", findings)
synthesis_prompt = synthesis_prompt.replace("{references}", references)
messages = [
{"role": "system", "content": synthesis_prompt},
{"role": "user", "content": f"Please write the research report for: {question}"},
]
llm_response = self.llm.gen_stream(
model=self.upstream_model_id, messages=messages, tools=None
)
if log_context:
from application.logging import build_stack_data
log_context.stacks.append(
{"component": "synthesis_llm", "data": build_stack_data(self.llm)}
)
yield from self._handle_response(
llm_response, tools_dict, messages, log_context
)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _extract_text(self, response) -> str:
"""Extract text content from a non-streaming LLM response."""
if isinstance(response, str):
return response
if hasattr(response, "message") and hasattr(response.message, "content"):
return response.message.content or ""
if hasattr(response, "choices") and response.choices:
choice = response.choices[0]
if hasattr(choice, "message") and hasattr(choice.message, "content"):
return choice.message.content or ""
if hasattr(response, "content") and isinstance(response.content, list):
if response.content and hasattr(response.content[0], "text"):
return response.content[0].text or ""
return str(response) if response else ""

View File

@@ -1,519 +0,0 @@
import logging
import uuid
from collections import Counter
from typing import Dict, List, Optional, Tuple
from application.agents.tools.tool_action_parser import ToolActionParser
from application.agents.tools.tool_manager import ToolManager
from application.security.encryption import decrypt_credentials
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.user_tools import UserToolsRepository
from application.storage.db.session import db_readonly
logger = logging.getLogger(__name__)
class ToolExecutor:
"""Handles tool discovery, preparation, and execution.
Extracted from BaseAgent to separate concerns and enable tool caching.
"""
def __init__(
self,
user_api_key: Optional[str] = None,
user: Optional[str] = None,
decoded_token: Optional[Dict] = None,
):
self.user_api_key = user_api_key
self.user = user
self.decoded_token = decoded_token
self.tool_calls: List[Dict] = []
self._loaded_tools: Dict[str, object] = {}
self.conversation_id: Optional[str] = None
self.client_tools: Optional[List[Dict]] = None
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
self._tool_to_name: Dict[Tuple[str, str], str] = {}
def get_tools(self) -> Dict[str, Dict]:
"""Load tool configs from DB based on user context.
If *client_tools* have been set on this executor, they are
automatically merged into the returned dict.
"""
if self.user_api_key:
tools = self._get_tools_by_api_key(self.user_api_key)
else:
tools = self._get_user_tools(self.user or "local")
if self.client_tools:
self.merge_client_tools(tools, self.client_tools)
return tools
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
# Per-operation session: the answer pipeline spans a long-lived
# generator; wrapping it in a single connection would pin a PG
# conn for the whole stream. Open, fetch, close.
with db_readonly() as conn:
agent_data = AgentsRepository(conn).find_by_key(api_key)
tool_ids = agent_data.get("tools", []) if agent_data else []
if not tool_ids:
return {}
tools_repo = UserToolsRepository(conn)
tools: List[Dict] = []
owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
for tid in tool_ids:
row = None
if owner:
row = tools_repo.get_any(str(tid), owner)
if row is not None:
tools.append(row)
return {str(tool["id"]): tool for tool in tools} if tools else {}
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
with db_readonly() as conn:
user_tools = UserToolsRepository(conn).list_active_for_user(user)
return {str(i): tool for i, tool in enumerate(user_tools)}
def merge_client_tools(
self, tools_dict: Dict, client_tools: List[Dict]
) -> Dict:
"""Merge client-provided tool definitions into tools_dict.
Client tools use the standard function-calling format::
[{"type": "function", "function": {"name": "get_weather",
"description": "...", "parameters": {...}}}]
They are stored in *tools_dict* with ``client_side: True`` so that
:meth:`check_pause` returns a pause signal instead of trying to
execute them server-side.
Args:
tools_dict: The mutable server tools dict (will be modified in place).
client_tools: List of tool definitions in function-calling format.
Returns:
The updated *tools_dict* (same reference, for convenience).
"""
for i, ct in enumerate(client_tools):
func = ct.get("function", ct) # tolerate bare {"name":..} too
name = func.get("name", f"clienttool{i}")
tool_id = f"ct{i}"
tools_dict[tool_id] = {
"name": name,
"client_side": True,
"actions": [
{
"name": name,
"description": func.get("description", ""),
"active": True,
"parameters": func.get("parameters", {}),
}
],
}
return tools_dict
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
"""Convert tool configs to LLM function schemas.
Action names are kept clean for the LLM:
- Unique action names appear as-is (e.g. ``get_weather``).
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
``search_2``).
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
can be routed back to the correct ``(tool_id, action_name)`` without
brittle string splitting.
"""
# Pass 1: collect entries and count action name occurrences
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
name_counts: Counter = Counter()
for tool_id, tool in tools_dict.items():
is_api = tool["name"] == "api_tool"
is_client = tool.get("client_side", False)
if is_api and "actions" not in tool.get("config", {}):
continue
if not is_api and "actions" not in tool:
continue
actions = (
tool["config"]["actions"].values()
if is_api
else tool["actions"]
)
for action in actions:
if not action.get("active", True):
continue
entries.append((tool_id, action["name"], action, is_client))
name_counts[action["name"]] += 1
# Pass 2: assign LLM-visible names and build mappings
self._name_to_tool = {}
self._tool_to_name = {}
collision_counters: Dict[str, int] = {}
all_llm_names: set = set()
result = []
for tool_id, action_name, action, is_client in entries:
if name_counts[action_name] == 1:
llm_name = action_name
else:
counter = collision_counters.get(action_name, 1)
candidate = f"{action_name}_{counter}"
# Skip if candidate collides with a unique action name
while candidate in all_llm_names or (
candidate in name_counts and name_counts[candidate] == 1
):
counter += 1
candidate = f"{action_name}_{counter}"
collision_counters[action_name] = counter + 1
llm_name = candidate
all_llm_names.add(llm_name)
self._name_to_tool[llm_name] = (tool_id, action_name)
self._tool_to_name[(tool_id, action_name)] = llm_name
if is_client:
params = action.get("parameters", {})
else:
params = self._build_tool_parameters(action)
result.append({
"type": "function",
"function": {
"name": llm_name,
"description": action.get("description", ""),
"parameters": params,
},
})
return result
def _build_tool_parameters(self, action: Dict) -> Dict:
params = {"type": "object", "properties": {}, "required": []}
for param_type in ["query_params", "headers", "body", "parameters"]:
if param_type in action and action[param_type].get("properties"):
for k, v in action[param_type]["properties"].items():
if v.get("filled_by_llm", True):
params["properties"][k] = {
key: value
for key, value in v.items()
if key not in ("filled_by_llm", "value", "required")
}
if v.get("required", False):
params["required"].append(k)
return params
def check_pause(
self, tools_dict: Dict, call, llm_class_name: str
) -> Optional[Dict]:
"""Check if a tool call requires pausing for approval or client execution.
Returns a dict describing the pending action if pause is needed, None otherwise.
"""
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
call_id = getattr(call, "id", None) or str(uuid.uuid4())
llm_name = getattr(call, "name", "")
if tool_id is None or action_name is None or tool_id not in tools_dict:
return None # Will be handled as error by execute()
tool_data = tools_dict[tool_id]
# Client-side tools
if tool_data.get("client_side"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "requires_client_execution",
"thought_signature": getattr(call, "thought_signature", None),
}
# Approval required
if tool_data["name"] == "api_tool":
action_data = tool_data.get("config", {}).get("actions", {}).get(
action_name, {}
)
else:
action_data = next(
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
{},
)
if action_data.get("require_approval"):
return {
"call_id": call_id,
"name": llm_name,
"tool_name": tool_data.get("name", "unknown"),
"tool_id": tool_id,
"action_name": action_name,
"llm_name": llm_name,
"arguments": call_args if isinstance(call_args, dict) else {},
"pause_type": "awaiting_approval",
"thought_signature": getattr(call, "thought_signature", None),
}
return None
def execute(self, tools_dict: Dict, call, llm_class_name: str):
"""Execute a tool call. Yields status events, returns (result, call_id)."""
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
tool_id, action_name, call_args = parser.parse_args(call)
llm_name = getattr(call, "name", "unknown")
call_id = getattr(call, "id", None) or str(uuid.uuid4())
if tool_id is None or action_name is None:
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
logger.error(
"tool_call_parse_failed",
extra={
"llm_class_name": llm_class_name,
"llm_tool_name": llm_name,
"call_id": call_id,
},
)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": llm_name,
"arguments": call_args or {},
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return "Failed to parse tool call.", call_id
if tool_id not in tools_dict:
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
logger.error(
"tool_id_not_found",
extra={
"tool_id": tool_id,
"llm_tool_name": llm_name,
"call_id": call_id,
"available_tool_count": len(tools_dict),
},
)
tool_call_data = {
"tool_name": "unknown",
"call_id": call_id,
"action_name": llm_name,
"arguments": call_args,
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return f"Tool with ID {tool_id} not found.", call_id
tool_call_data = {
"tool_name": tools_dict[tool_id]["name"],
"call_id": call_id,
"action_name": llm_name,
"arguments": call_args,
}
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
tool_data = tools_dict[tool_id]
action_data = (
tool_data["config"]["actions"][action_name]
if tool_data["name"] == "api_tool"
else next(
action
for action in tool_data["actions"]
if action["name"] == action_name
)
)
query_params, headers, body, parameters = {}, {}, {}, {}
param_types = {
"query_params": query_params,
"headers": headers,
"body": body,
"parameters": parameters,
}
for param_type, target_dict in param_types.items():
if param_type in action_data and action_data[param_type].get("properties"):
for param, details in action_data[param_type]["properties"].items():
if (
param not in call_args
and "value" in details
and details["value"]
):
target_dict[param] = details["value"]
for param, value in call_args.items():
for param_type, target_dict in param_types.items():
if param_type in action_data and param in action_data[param_type].get(
"properties", {}
):
target_dict[param] = value
# Load tool (with caching)
tool = self._get_or_load_tool(
tool_data, tool_id, action_name,
headers=headers, query_params=query_params,
)
if tool is None:
error_message = (
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
"missing 'id' on tool row."
)
logger.error(
"tool_load_failed",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
"call_id": call_id,
},
)
tool_call_data["result"] = error_message
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
self.tool_calls.append(tool_call_data)
return error_message, call_id
resolved_arguments = (
{"query_params": query_params, "headers": headers, "body": body}
if tool_data["name"] == "api_tool"
else parameters
)
if tool_data["name"] == "api_tool":
logger.debug(
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
)
result = tool.execute_action(action_name, **body)
else:
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
result = tool.execute_action(action_name, **parameters)
get_artifact_id = (
getattr(tool, "get_artifact_id", None)
if tool_data["name"] != "api_tool"
else None
)
artifact_id = None
if callable(get_artifact_id):
try:
artifact_id = get_artifact_id(action_name, **parameters)
except Exception:
logger.exception(
"Failed to extract artifact_id from tool %s for action %s",
tool_data["name"],
action_name,
)
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
if artifact_id:
tool_call_data["artifact_id"] = artifact_id
result_full = str(result)
tool_call_data["resolved_arguments"] = resolved_arguments
tool_call_data["result_full"] = result_full
tool_call_data["result"] = (
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
)
stream_tool_call_data = {
key: value
for key, value in tool_call_data.items()
if key not in {"result_full", "resolved_arguments"}
}
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
self.tool_calls.append(tool_call_data)
return result, call_id
def _get_or_load_tool(
self, tool_data: Dict, tool_id: str, action_name: str,
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
):
"""Load a tool, using cache when possible."""
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
if cache_key in self._loaded_tools:
return self._loaded_tools[cache_key]
tm = ToolManager(config={})
if tool_data["name"] == "api_tool":
action_config = tool_data["config"]["actions"][action_name]
tool_config = {
"url": action_config["url"],
"method": action_config["method"],
"headers": headers or {},
"query_params": query_params or {},
}
if "body_content_type" in action_config:
tool_config["body_content_type"] = action_config.get(
"body_content_type", "application/json"
)
tool_config["body_encoding_rules"] = action_config.get(
"body_encoding_rules", {}
)
else:
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
if tool_config.get("encrypted_credentials") and self.user:
decrypted = decrypt_credentials(
tool_config["encrypted_credentials"], self.user
)
tool_config.update(decrypted)
tool_config["auth_credentials"] = decrypted
tool_config.pop("encrypted_credentials", None)
row_id = tool_data.get("id")
if not row_id:
logger.error(
"tool_missing_row_id",
extra={
"tool_name": tool_data.get("name"),
"tool_id": tool_id,
"action_name": action_name,
},
)
return None
tool_config["tool_id"] = str(row_id)
if self.conversation_id:
tool_config["conversation_id"] = self.conversation_id
if tool_data["name"] == "mcp_tool":
tool_config["query_mode"] = True
tool = tm.load_tool(
tool_data["name"],
tool_config=tool_config,
user_id=self.user,
)
# Don't cache api_tool since config varies by action
if tool_data["name"] != "api_tool":
self._loaded_tools[cache_key] = tool
return tool
def get_truncated_tool_calls(self) -> List[Dict]:
return [
{
"tool_name": tool_call.get("tool_name"),
"call_id": tool_call.get("call_id"),
"action_name": tool_call.get("action_name"),
"arguments": tool_call.get("arguments"),
"artifact_id": tool_call.get("artifact_id"),
"result": (
f"{str(tool_call['result'])[:50]}..."
if len(str(tool_call["result"])) > 50
else tool_call["result"]
),
"status": "completed",
}
for tool_call in self.tool_calls
]

View File

@@ -1,323 +0,0 @@
import base64
import json
import logging
from enum import Enum
from typing import Any, Dict, Optional, Union
from urllib.parse import quote, urlencode
logger = logging.getLogger(__name__)
class ContentType(str, Enum):
"""Supported content types for request bodies."""
JSON = "application/json"
FORM_URLENCODED = "application/x-www-form-urlencoded"
MULTIPART_FORM_DATA = "multipart/form-data"
TEXT_PLAIN = "text/plain"
XML = "application/xml"
OCTET_STREAM = "application/octet-stream"
class RequestBodySerializer:
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
@staticmethod
def serialize(
body_data: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[Union[str, bytes], Dict[str, str]]:
"""
Serialize body data to appropriate format.
Args:
body_data: Dictionary of body parameters
content_type: Content-Type header value
encoding_rules: OpenAPI Encoding Object rules per field
Returns:
Tuple of (serialized_body, updated_headers_dict)
Raises:
ValueError: If serialization fails
"""
if not body_data:
return None, {}
try:
content_type_lower = content_type.lower().split(";")[0].strip()
if content_type_lower == ContentType.JSON:
return RequestBodySerializer._serialize_json(body_data)
elif content_type_lower == ContentType.FORM_URLENCODED:
return RequestBodySerializer._serialize_form_urlencoded(
body_data, encoding_rules
)
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
return RequestBodySerializer._serialize_multipart_form_data(
body_data, encoding_rules
)
elif content_type_lower == ContentType.TEXT_PLAIN:
return RequestBodySerializer._serialize_text_plain(body_data)
elif content_type_lower == ContentType.XML:
return RequestBodySerializer._serialize_xml(body_data)
elif content_type_lower == ContentType.OCTET_STREAM:
return RequestBodySerializer._serialize_octet_stream(body_data)
else:
logger.warning(
f"Unknown content type: {content_type}, treating as JSON"
)
return RequestBodySerializer._serialize_json(body_data)
except Exception as e:
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
raise ValueError(f"Failed to serialize request body: {str(e)}")
@staticmethod
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as JSON per OpenAPI spec."""
try:
serialized = json.dumps(
body_data, separators=(",", ":"), ensure_ascii=False
)
headers = {"Content-Type": ContentType.JSON.value}
return serialized, headers
except (TypeError, ValueError) as e:
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
@staticmethod
def _serialize_form_urlencoded(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[str, Dict[str, str]]:
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
encoding_rules = encoding_rules or {}
params = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
style = rule.get("style", "form")
explode = rule.get("explode", style == "form")
content_type = rule.get("contentType", "text/plain")
serialized_value = RequestBodySerializer._serialize_form_value(
value, style, explode, content_type, key
)
if isinstance(serialized_value, list):
for sv in serialized_value:
params.append((key, sv))
else:
params.append((key, serialized_value))
# Use standard urlencode (replaces space with +)
serialized = urlencode(params, safe="")
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
return serialized, headers
@staticmethod
def _serialize_form_value(
value: Any, style: str, explode: bool, content_type: str, key: str
) -> Union[str, list]:
"""Serialize individual form value with encoding rules."""
if isinstance(value, dict):
if content_type == "application/json":
return json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
return RequestBodySerializer._dict_to_xml(value)
else:
if style == "deepObject" and explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
elif explode:
return [
f"{RequestBodySerializer._percent_encode(str(v))}"
for v in value.values()
]
else:
pairs = [f"{k},{v}" for k, v in value.items()]
return RequestBodySerializer._percent_encode(",".join(pairs))
elif isinstance(value, (list, tuple)):
if explode:
return [
RequestBodySerializer._percent_encode(str(item)) for item in value
]
else:
return RequestBodySerializer._percent_encode(
",".join(str(v) for v in value)
)
else:
return RequestBodySerializer._percent_encode(str(value))
@staticmethod
def _serialize_multipart_form_data(
body_data: Dict[str, Any],
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> tuple[bytes, Dict[str, str]]:
"""
Serialize body as multipart/form-data per RFC7578.
Supports file uploads and encoding rules.
"""
import secrets
encoding_rules = encoding_rules or {}
boundary = f"----DocsGPT{secrets.token_hex(16)}"
parts = []
for key, value in body_data.items():
if value is None:
continue
rule = encoding_rules.get(key, {})
content_type = rule.get("contentType", "text/plain")
headers_rule = rule.get("headers", {})
part = RequestBodySerializer._create_multipart_part(
key, value, content_type, headers_rule
)
parts.append(part)
body_bytes = f"--{boundary}\r\n".encode("utf-8")
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
headers = {
"Content-Type": f"multipart/form-data; boundary={boundary}",
}
return body_bytes, headers
@staticmethod
def _create_multipart_part(
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
) -> str:
"""Create a single multipart/form-data part."""
headers = [
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
]
if isinstance(value, bytes):
if content_type == "application/octet-stream":
value_encoded = base64.b64encode(value).decode("utf-8")
else:
value_encoded = value.decode("utf-8", errors="replace")
headers.append(f"Content-Type: {content_type}")
headers.append("Content-Transfer-Encoding: base64")
elif isinstance(value, dict):
if content_type == "application/json":
value_encoded = json.dumps(value, separators=(",", ":"))
elif content_type == "application/xml":
value_encoded = RequestBodySerializer._dict_to_xml(value)
else:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
elif isinstance(value, str) and content_type != "text/plain":
try:
if content_type == "application/json":
json.loads(value)
value_encoded = value
elif content_type == "application/xml":
value_encoded = value
else:
value_encoded = str(value)
except json.JSONDecodeError:
value_encoded = str(value)
headers.append(f"Content-Type: {content_type}")
else:
value_encoded = str(value)
if content_type != "text/plain":
headers.append(f"Content-Type: {content_type}")
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
return part
@staticmethod
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as plain text."""
if len(body_data) == 1:
value = list(body_data.values())[0]
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
else:
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
@staticmethod
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
"""Serialize body as XML."""
xml_str = RequestBodySerializer._dict_to_xml(body_data)
return xml_str, {"Content-Type": ContentType.XML.value}
@staticmethod
def _serialize_octet_stream(
body_data: Dict[str, Any],
) -> tuple[bytes, Dict[str, str]]:
"""Serialize body as binary octet stream."""
if isinstance(body_data, bytes):
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
elif isinstance(body_data, str):
return body_data.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
else:
serialized = json.dumps(body_data)
return serialized.encode("utf-8"), {
"Content-Type": ContentType.OCTET_STREAM.value
}
@staticmethod
def _percent_encode(value: str, safe_chars: str = "") -> str:
"""
Percent-encode per RFC3986.
Args:
value: String to encode
safe_chars: Additional characters to not encode
"""
return quote(value, safe=safe_chars)
@staticmethod
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
"""
Convert dict to simple XML format.
"""
def build_xml(obj: Any, name: str) -> str:
if isinstance(obj, dict):
inner = "".join(build_xml(v, k) for k, v in obj.items())
return f"<{name}>{inner}</{name}>"
elif isinstance(obj, (list, tuple)):
items = "".join(
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
for item in obj
)
return items
else:
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
root = build_xml(data, root_name)
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
@staticmethod
def _escape_xml(value: str) -> str:
"""Escape XML special characters."""
return (
value.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
.replace('"', "&quot;")
.replace("'", "&apos;")
)

View File

@@ -1,280 +1,72 @@
import json
import logging
import re
from typing import Any, Dict, Optional
from urllib.parse import urlencode
import requests
from application.agents.tools.api_body_serializer import (
ContentType,
RequestBodySerializer,
)
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 90 # seconds
class APITool(Tool):
"""
API Tool
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
"""
def __init__(self, config):
self.config = config
self.url = config.get("url", "")
self.method = config.get("method", "GET")
self.headers = config.get("headers", {})
self.headers = config.get("headers", {"Content-Type": "application/json"})
self.query_params = config.get("query_params", {})
self.body_content_type = config.get("body_content_type", ContentType.JSON)
self.body_encoding_rules = config.get("body_encoding_rules", {})
def execute_action(self, action_name, **kwargs):
"""Execute an API action with the given arguments."""
return self._make_api_call(
self.url,
self.method,
self.headers,
self.query_params,
kwargs,
self.body_content_type,
self.body_encoding_rules,
self.url, self.method, self.headers, self.query_params, kwargs
)
def _make_api_call(
self,
url: str,
method: str,
headers: Dict[str, str],
query_params: Dict[str, Any],
body: Dict[str, Any],
content_type: str = ContentType.JSON,
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Dict[str, Any]:
"""
Make an API call with proper body serialization and error handling.
Args:
url: API endpoint URL
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
headers: Request headers dict
query_params: URL query parameters
body: Request body as dict
content_type: Content-Type for serialization
encoding_rules: OpenAPI encoding rules
Returns:
Dict with status_code, data, and message
"""
request_url = url
request_headers = headers.copy() if headers else {}
response = None
# Validate URL to prevent SSRF attacks
def _make_api_call(self, url, method, headers, query_params, body):
if query_params:
url = f"{url}?{requests.compat.urlencode(query_params)}"
# if isinstance(body, dict):
# body = json.dumps(body)
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
try:
path_params_used = set()
if query_params:
for match in re.finditer(r"\{([^}]+)\}", request_url):
param_name = match.group(1)
if param_name in query_params:
request_url = request_url.replace(
f"{{{param_name}}}", str(query_params[param_name])
)
path_params_used.add(param_name)
remaining_params = {
k: v for k, v in query_params.items() if k not in path_params_used
}
if remaining_params:
query_string = urlencode(remaining_params)
separator = "&" if "?" in request_url else "?"
request_url = f"{request_url}{separator}{query_string}"
# Re-validate URL after parameter substitution to prevent SSRF via path params
try:
validate_url(request_url)
except SSRFError as e:
logger.error(f"URL validation failed after parameter substitution: {e}")
return {
"status_code": None,
"message": f"URL validation error: {e}",
"data": None,
}
# Serialize body based on content type
if body and body != {}:
try:
serialized_body, body_headers = RequestBodySerializer.serialize(
body, content_type, encoding_rules
)
request_headers.update(body_headers)
except ValueError as e:
logger.error(f"Body serialization failed: {str(e)}")
return {
"status_code": None,
"message": f"Body serialization error: {str(e)}",
"data": None,
}
else:
serialized_body = None
if "Content-Type" not in request_headers and method not in [
"GET",
"HEAD",
"DELETE",
]:
request_headers["Content-Type"] = ContentType.JSON
logger.debug(
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
)
if method.upper() == "GET":
response = requests.get(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "POST":
response = requests.post(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "PUT":
response = requests.put(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "DELETE":
response = requests.delete(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "PATCH":
response = requests.patch(
request_url,
data=serialized_body,
headers=request_headers,
timeout=DEFAULT_TIMEOUT,
)
elif method.upper() == "HEAD":
response = requests.head(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
elif method.upper() == "OPTIONS":
response = requests.options(
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
)
else:
return {
"status_code": None,
"message": f"Unsupported HTTP method: {method}",
"data": None,
}
print(f"Making API call: {method} {url} with body: {body}")
if body == "{}":
body = None
response = requests.request(method, url, headers=headers, data=body)
response.raise_for_status()
data = self._parse_response(response)
content_type = response.headers.get(
"Content-Type", "application/json"
).lower()
if "application/json" in content_type:
try:
data = response.json()
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
return {
"status_code": response.status_code,
"message": f"API call returned invalid JSON. Error: {e}",
"data": response.text,
}
elif "text/" in content_type or "application/xml" in content_type:
data = response.text
elif not response.content:
data = None
else:
print(f"Unsupported content type: {content_type}")
data = response.content
return {
"status_code": response.status_code,
"data": data,
"message": "API call successful.",
}
except requests.exceptions.Timeout:
logger.error(f"Request timeout for {request_url}")
return {
"status_code": None,
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
"data": None,
}
except requests.exceptions.ConnectionError as e:
logger.error(f"Connection error: {str(e)}")
return {
"status_code": None,
"message": f"Connection error: {str(e)}",
"data": None,
}
except requests.exceptions.HTTPError as e:
logger.error(f"HTTP error {response.status_code}: {str(e)}")
try:
error_data = response.json()
except (json.JSONDecodeError, ValueError):
error_data = response.text
return {
"status_code": response.status_code,
"message": f"HTTP Error {response.status_code}",
"data": error_data,
}
except requests.exceptions.RequestException as e:
logger.error(f"Request failed: {str(e)}")
return {
"status_code": response.status_code if response else None,
"message": f"API call failed: {str(e)}",
"data": None,
}
except Exception as e:
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
return {
"status_code": None,
"message": f"Unexpected error: {str(e)}",
"data": None,
}
def _parse_response(self, response: requests.Response) -> Any:
"""
Parse response based on Content-Type header.
Supports: JSON, XML, plain text, binary data.
"""
content_type = response.headers.get("Content-Type", "").lower()
if not response.content:
return None
# JSON response
if "application/json" in content_type:
try:
return response.json()
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse JSON response: {str(e)}")
return response.text
# XML response
elif "application/xml" in content_type or "text/xml" in content_type:
return response.text
# Plain text response
elif "text/plain" in content_type or "text/html" in content_type:
return response.text
# Binary/unknown response
else:
# Try to decode as text first, fall back to base64
try:
return response.text
except (UnicodeDecodeError, AttributeError):
import base64
return base64.b64encode(response.content).decode("utf-8")
def get_actions_metadata(self):
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
return []
def get_config_requirements(self):
"""Return configuration requirements for the tool."""
return {}

View File

@@ -2,8 +2,6 @@ from abc import ABC, abstractmethod
class Tool(ABC):
internal: bool = False
@abstractmethod
def execute_action(self, action_name: str, **kwargs):
pass

View File

@@ -1,11 +1,6 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class BraveSearchTool(Tool):
"""
@@ -46,7 +41,7 @@ class BraveSearchTool(Tool):
"""
Performs a web search using the Brave Search API.
"""
logger.debug("Performing Brave web search for: %s", query)
print(f"Performing Brave web search for: {query}")
url = f"{self.base_url}/web/search"
@@ -73,7 +68,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers, timeout=100)
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
@@ -99,7 +94,7 @@ class BraveSearchTool(Tool):
"""
Performs an image search using the Brave Search API.
"""
logger.debug("Performing Brave image search for: %s", query)
print(f"Performing Brave image search for: {query}")
url = f"{self.base_url}/images/search"
@@ -118,7 +113,7 @@ class BraveSearchTool(Tool):
"X-Subscription-Token": self.token,
}
response = requests.get(url, params=params, headers=headers, timeout=100)
response = requests.get(url, params=params, headers=headers)
if response.status_code == 200:
return {
@@ -182,10 +177,6 @@ class BraveSearchTool(Tool):
return {
"token": {
"type": "string",
"label": "API Key",
"description": "Brave Search API key for authentication",
"required": True,
"secret": True,
"order": 1,
},
}

View File

@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
returns price in USD.
"""
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
response = requests.get(url, timeout=100)
response = requests.get(url)
if response.status_code == 200:
data = response.json()
if currency.upper() in data:

View File

@@ -1,14 +1,5 @@
import logging
import time
from typing import Any, Dict, Optional
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
MAX_RETRIES = 3
RETRY_DELAY = 2.0
DEFAULT_TIMEOUT = 15
from duckduckgo_search import DDGS
class DuckDuckGoSearchTool(Tool):
@@ -19,123 +10,71 @@ class DuckDuckGoSearchTool(Tool):
def __init__(self, config):
self.config = config
self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
def _get_ddgs_client(self):
from ddgs import DDGS
return DDGS(timeout=self.timeout)
def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
last_error = None
for attempt in range(1, MAX_RETRIES + 1):
try:
results = operation()
return {
"status_code": 200,
"results": list(results) if results else [],
"message": f"{operation_name} completed successfully.",
}
except Exception as e:
last_error = e
error_str = str(e).lower()
if "ratelimit" in error_str or "429" in error_str:
if attempt < MAX_RETRIES:
delay = RETRY_DELAY * attempt
logger.warning(
f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
)
time.sleep(delay)
continue
logger.error(f"{operation_name} failed: {e}")
break
return {
"status_code": 500,
"results": [],
"message": f"{operation_name} failed: {str(last_error)}",
}
def execute_action(self, action_name, **kwargs):
actions = {
"ddg_web_search": self._web_search,
"ddg_image_search": self._image_search,
"ddg_news_search": self._news_search,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _web_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo web search: {query}")
query,
max_results=5,
):
print(f"Performing DuckDuckGo web search for: {query}")
def operation():
client = self._get_ddgs_client()
return client.text(
try:
results = DDGS().text(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
max_results=max_results,
)
return self._execute_with_retry(operation, "Web search")
return {
"status_code": 200,
"results": results,
"message": "Web search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Web search failed: {str(e)}",
}
def _image_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo image search: {query}")
query,
max_results=5,
):
print(f"Performing DuckDuckGo image search for: {query}")
def operation():
client = self._get_ddgs_client()
return client.images(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 50),
try:
results = DDGS().images(
keywords=query,
max_results=max_results,
)
return self._execute_with_retry(operation, "Image search")
def _news_search(
self,
query: str,
max_results: int = 5,
region: str = "wt-wt",
safesearch: str = "moderate",
timelimit: Optional[str] = None,
) -> Dict[str, Any]:
logger.info(f"DuckDuckGo news search: {query}")
def operation():
client = self._get_ddgs_client()
return client.news(
query,
region=region,
safesearch=safesearch,
timelimit=timelimit,
max_results=min(max_results, 20),
)
return self._execute_with_retry(operation, "News search")
return {
"status_code": 200,
"results": results,
"message": "Image search completed successfully.",
}
except Exception as e:
return {
"status_code": 500,
"message": f"Image search failed: {str(e)}",
}
def get_actions_metadata(self):
return [
{
"name": "ddg_web_search",
"description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
"description": "Perform a web search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
@@ -145,15 +84,7 @@ class DuckDuckGoSearchTool(Tool):
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide, us-en for US)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month), y (year)",
"description": "Number of results to return (default: 5)",
},
},
"required": ["query"],
@@ -161,43 +92,17 @@ class DuckDuckGoSearchTool(Tool):
},
{
"name": "ddg_image_search",
"description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
"description": "Perform an image search using DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Image search query",
"description": "Search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 50)",
},
"region": {
"type": "string",
"description": "Region code (default: wt-wt for worldwide)",
},
},
"required": ["query"],
},
},
{
"name": "ddg_news_search",
"description": "Search for news articles using DuckDuckGo. Returns recent news.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "News search query",
},
"max_results": {
"type": "integer",
"description": "Number of results (default: 5, max: 20)",
},
"timelimit": {
"type": "string",
"description": "Time filter: d (day), w (week), m (month)",
"description": "Number of results to return (default: 5, max: 50)",
},
},
"required": ["query"],

View File

@@ -1,461 +0,0 @@
import json
import logging
from typing import Dict, List, Optional
from application.agents.tools.base import Tool
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator
logger = logging.getLogger(__name__)
class InternalSearchTool(Tool):
"""Wraps the ClassicRAG retriever as an LLM-callable tool.
Instead of pre-fetching docs into the prompt, the LLM decides
when and what to search. Supports multiple searches per session.
Optional capabilities (enabled when sources have directory_structure):
- path_filter on search: restrict results to a specific file/folder
- list_files action: browse the file/folder structure
"""
internal = True
def __init__(self, config: Dict):
self.config = config
self.retrieved_docs: List[Dict] = []
self._retriever = None
self._directory_structure: Optional[Dict] = None
self._dir_structure_loaded = False
def _get_retriever(self):
if self._retriever is None:
self._retriever = RetrieverCreator.create_retriever(
self.config.get("retriever_name", "classic"),
source=self.config.get("source", {}),
chat_history=[],
prompt="",
chunks=int(self.config.get("chunks", 2)),
doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
model_id=self.config.get("model_id", "docsgpt-local"),
model_user_id=self.config.get("model_user_id"),
user_api_key=self.config.get("user_api_key"),
agent_id=self.config.get("agent_id"),
llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
api_key=self.config.get("api_key", settings.API_KEY),
decoded_token=self.config.get("decoded_token"),
)
return self._retriever
def _get_directory_structure(self) -> Optional[Dict]:
"""Load directory structure from Postgres for the configured sources."""
if self._dir_structure_loaded:
return self._directory_structure
self._dir_structure_loaded = True
source = self.config.get("source", {})
active_docs = source.get("active_docs", [])
if not active_docs:
return None
try:
# Per-operation session: this tool runs inside the answer
# generator hot path, so we open a short-lived read
# connection for the batch lookup and release immediately.
from application.storage.db.repositories.sources import (
SourcesRepository,
)
from application.storage.db.session import db_readonly
if isinstance(active_docs, str):
active_docs = [active_docs]
decoded_token = self.config.get("decoded_token") or {}
user_id = decoded_token.get("sub") if decoded_token else None
merged_structure = {}
with db_readonly() as conn:
repo = SourcesRepository(conn)
for doc_id in active_docs:
try:
source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
if not source_doc:
continue
dir_str = source_doc.get("directory_structure")
if dir_str:
if isinstance(dir_str, str):
dir_str = json.loads(dir_str)
source_name = source_doc.get("name", doc_id)
if len(active_docs) > 1:
merged_structure[source_name] = dir_str
else:
merged_structure = dir_str
except Exception as e:
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
self._directory_structure = merged_structure if merged_structure else None
except Exception as e:
logger.debug(f"Failed to load directory structures: {e}")
return self._directory_structure
def execute_action(self, action_name: str, **kwargs):
if action_name == "search":
return self._execute_search(**kwargs)
elif action_name == "list_files":
return self._execute_list_files(**kwargs)
return f"Unknown action: {action_name}"
def _execute_search(self, **kwargs) -> str:
query = kwargs.get("query", "")
path_filter = kwargs.get("path_filter", "")
if not query:
return "Error: 'query' parameter is required."
try:
retriever = self._get_retriever()
docs = retriever.search(query)
except Exception as e:
logger.error(f"Internal search failed: {e}", exc_info=True)
return "Search failed: an internal error occurred."
if not docs:
return "No documents found matching your query."
# Apply path filter if specified
if path_filter:
path_lower = path_filter.lower()
docs = [
d
for d in docs
if path_lower in d.get("source", "").lower()
or path_lower in d.get("filename", "").lower()
or path_lower in d.get("title", "").lower()
]
if not docs:
return f"No documents found matching query '{query}' in path '{path_filter}'."
# Accumulate for source tracking
for doc in docs:
if doc not in self.retrieved_docs:
self.retrieved_docs.append(doc)
# Format results for the LLM
formatted = []
for i, doc in enumerate(docs, 1):
title = doc.get("title", "Untitled")
text = doc.get("text", "")
source = doc.get("source", "Unknown")
filename = doc.get("filename", "")
header = filename or title
formatted.append(f"[{i}] {header} (source: {source})\n{text}")
return "\n\n---\n\n".join(formatted)
def _execute_list_files(self, **kwargs) -> str:
path = kwargs.get("path", "")
dir_structure = self._get_directory_structure()
if not dir_structure:
return "No file structure available for the current sources."
# Navigate to the requested path
current = dir_structure
if path:
for part in path.strip("/").split("/"):
if not part:
continue
if isinstance(current, dict) and part in current:
current = current[part]
else:
return f"Path '{path}' not found in the file structure."
# Format the structure for the LLM
return self._format_structure(current, path or "/")
def _format_structure(self, node: Dict, current_path: str) -> str:
if not isinstance(node, dict):
return f"'{current_path}' is a file, not a directory."
lines = [f"File structure at '{current_path}':\n"]
folders = []
files = []
for name, value in sorted(node.items()):
if isinstance(value, dict):
# Check if it's a file metadata dict or a folder
if "type" in value or "size_bytes" in value or "token_count" in value:
# It's a file with metadata
size = value.get("token_count", "")
ftype = value.get("type", "")
info_parts = []
if ftype:
info_parts.append(ftype)
if size:
info_parts.append(f"{size} tokens")
info = f" ({', '.join(info_parts)})" if info_parts else ""
files.append(f" {name}{info}")
else:
# It's a folder
count = self._count_files(value)
folders.append(f" {name}/ ({count} items)")
else:
files.append(f" {name}")
if folders:
lines.append("Folders:")
lines.extend(folders)
if files:
lines.append("Files:")
lines.extend(files)
if not folders and not files:
lines.append(" (empty)")
return "\n".join(lines)
def _count_files(self, node: Dict) -> int:
count = 0
for value in node.values():
if isinstance(value, dict):
if "type" in value or "size_bytes" in value or "token_count" in value:
count += 1
else:
count += self._count_files(value)
else:
count += 1
return count
def get_actions_metadata(self):
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"parameters": {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
},
}
},
}
]
# Add path_filter and list_files only if directory structure exists
has_structure = self.config.get("has_directory_structure", False)
if has_structure:
actions[0]["parameters"]["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return actions
def get_config_requirements(self):
return {}
# Constants for building synthetic tools_dict entries
INTERNAL_TOOL_ID = "internal"
def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
"""Build the tools_dict entry for InternalSearchTool.
Dynamically includes list_files and path_filter based on
whether the sources have directory structure.
"""
search_params = {
"properties": {
"query": {
"type": "string",
"description": "The search query. Be specific and focused.",
"filled_by_llm": True,
"required": True,
}
}
}
actions = [
{
"name": "search",
"description": (
"Search the user's uploaded documents and knowledge base. "
"Use this to find relevant information before answering questions. "
"You can call this multiple times with different queries."
),
"active": True,
"parameters": search_params,
}
]
if has_directory_structure:
search_params["properties"]["path_filter"] = {
"type": "string",
"description": (
"Optional: filter results to a specific file or folder path. "
"Use list_files first to see available paths."
),
"filled_by_llm": True,
"required": False,
}
actions.append(
{
"name": "list_files",
"description": (
"Browse the file and folder structure of the knowledge base. "
"Use this to see what files are available before searching. "
"Optionally provide a path to browse a specific folder."
),
"active": True,
"parameters": {
"properties": {
"path": {
"type": "string",
"description": "Optional: folder path to browse. Leave empty for root.",
"filled_by_llm": True,
"required": False,
}
}
},
}
)
return {"name": "internal_search", "actions": actions}
# Keep backward compat
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
def sources_have_directory_structure(source: Dict) -> bool:
"""Check if any of the active sources have a ``directory_structure`` row."""
active_docs = source.get("active_docs", [])
if not active_docs:
return False
try:
# TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
# scoping, but callers in the agent build path don't always
# thread the decoded token through here. Use a direct
# short-lived SQL lookup instead of the repo until the call
# sites are updated to propagate user context.
from sqlalchemy import text as _text
from application.storage.db.session import db_readonly
if isinstance(active_docs, str):
active_docs = [active_docs]
with db_readonly() as conn:
for doc_id in active_docs:
try:
value = str(doc_id)
if len(value) == 36 and "-" in value:
row = conn.execute(
_text(
"SELECT directory_structure FROM sources "
"WHERE id = CAST(:id AS uuid)"
),
{"id": value},
).fetchone()
else:
row = conn.execute(
_text(
"SELECT directory_structure FROM sources "
"WHERE legacy_mongo_id = :lid"
),
{"lid": value},
).fetchone()
if row is not None and row[0]:
return True
except Exception:
continue
except Exception as e:
logger.debug(f"Could not check directory structure: {e}")
return False
def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
"""Add the internal search tool to tools_dict if sources are configured.
Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
Mutates tools_dict in place.
"""
source = retriever_config.get("source", {})
has_sources = bool(source.get("active_docs"))
if not retriever_config or not has_sources:
return
has_dir = sources_have_directory_structure(source)
internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
internal_entry["config"] = build_internal_tool_config(
**retriever_config,
has_directory_structure=has_dir,
)
tools_dict[INTERNAL_TOOL_ID] = internal_entry
def build_internal_tool_config(
source: Dict,
retriever_name: str = "classic",
chunks: int = 2,
doc_token_limit: int = 50000,
model_id: str = "docsgpt-local",
model_user_id: Optional[str] = None,
user_api_key: Optional[str] = None,
agent_id: Optional[str] = None,
llm_name: str = None,
api_key: str = None,
decoded_token: Optional[Dict] = None,
has_directory_structure: bool = False,
) -> Dict:
"""Build the config dict for InternalSearchTool."""
return {
"source": source,
"retriever_name": retriever_name,
"chunks": chunks,
"doc_token_limit": doc_token_limit,
"model_id": model_id,
"model_user_id": model_user_id,
"user_api_key": user_api_key,
"agent_id": agent_id,
"llm_name": llm_name or settings.LLM_PROVIDER,
"api_key": api_key or settings.API_KEY,
"decoded_token": decoded_token,
"has_directory_structure": has_directory_structure,
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,548 +0,0 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
import logging
import uuid
from .base import Tool
from application.storage.db.repositories.memories import MemoriesRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
class MemoryTool(Tool):
"""Memory
Stores and retrieves information across conversations through a memory file directory.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this memory tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated memories
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
# In production, tool_id is the UUID string from user_tools.id.
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
def _pg_enabled(self) -> bool:
"""Return True if this MemoryTool's tool_id is a real ``user_tools.id``.
The ``memories`` PG table has a UUID foreign key to ``user_tools``.
The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
has no row in ``user_tools``, so any storage operation would fail
the foreign-key check. After the Postgres cutover Postgres is the
only store, so for the sentinel case there is nowhere to read or
write — operations become no-ops and the tool returns an
explanatory error to the caller.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
logger.debug(
"Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
tool_id,
)
return False
from application.storage.db.base_repository import looks_like_uuid
if not looks_like_uuid(tool_id):
logger.debug(
"Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
tool_id,
)
return False
return True
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of view, create, str_replace, insert, delete, rename.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: MemoryTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: MemoryTool is not configured with a persistent tool_id; "
"memory storage is unavailable for this session."
)
if action_name == "view":
return self._view(
kwargs.get("path", "/"),
kwargs.get("view_range")
)
if action_name == "create":
return self._create(
kwargs.get("path", ""),
kwargs.get("file_text", "")
)
if action_name == "str_replace":
return self._str_replace(
kwargs.get("path", ""),
kwargs.get("old_str", ""),
kwargs.get("new_str", "")
)
if action_name == "insert":
return self._insert(
kwargs.get("path", ""),
kwargs.get("insert_line", 1),
kwargs.get("insert_text", "")
)
if action_name == "delete":
return self._delete(kwargs.get("path", ""))
if action_name == "rename":
return self._rename(
kwargs.get("old_path", ""),
kwargs.get("new_path", "")
)
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "view",
"description": "Shows directory contents or file contents with optional line ranges.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to file or directory (e.g., /notes.txt or /project/ or /)."
},
"view_range": {
"type": "array",
"items": {"type": "integer"},
"description": "Optional [start_line, end_line] to view specific lines (1-indexed)."
}
},
"required": ["path"]
},
},
{
"name": "create",
"description": "Create or overwrite a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path to create (e.g., /notes.txt or /project/task.txt)."
},
"file_text": {
"type": "string",
"description": "Content to write to the file."
}
},
"required": ["path", "file_text"]
},
},
{
"name": "str_replace",
"description": "Replace text in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"old_str": {
"type": "string",
"description": "String to find."
},
"new_str": {
"type": "string",
"description": "String to replace with."
}
},
"required": ["path", "old_str", "new_str"]
},
},
{
"name": "insert",
"description": "Insert text at a specific line in a file.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "File path (e.g., /notes.txt)."
},
"insert_line": {
"type": "integer",
"description": "Line number to insert at (1-indexed)."
},
"insert_text": {
"type": "string",
"description": "Text to insert."
}
},
"required": ["path", "insert_line", "insert_text"]
},
},
{
"name": "delete",
"description": "Delete a file or directory.",
"parameters": {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to delete (e.g., /notes.txt or /project/)."
}
},
"required": ["path"]
},
},
{
"name": "rename",
"description": "Rename or move a file/directory.",
"parameters": {
"type": "object",
"properties": {
"old_path": {
"type": "string",
"description": "Current path (e.g., /old.txt)."
},
"new_path": {
"type": "string",
"description": "New path (e.g., /new.txt)."
}
},
"required": ["old_path", "new_path"]
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements."""
return {}
# -----------------------------
# Path validation
# -----------------------------
def _validate_path(self, path: str) -> Optional[str]:
"""Validate and normalize path.
Args:
path: User-provided path.
Returns:
Normalized path or None if invalid.
"""
if not path:
return None
# Remove any leading/trailing whitespace
path = path.strip()
# Preserve whether path ends with / (indicates directory)
is_directory = path.endswith("/")
# Ensure path starts with / for consistency
if not path.startswith("/"):
path = "/" + path
# Check for directory traversal patterns
if ".." in path or path.count("//") > 0:
return None
# Normalize the path
try:
# Convert to Path object and resolve to canonical form
normalized = str(Path(path).as_posix())
# Ensure it still starts with /
if not normalized.startswith("/"):
return None
# Preserve trailing slash for directories
if is_directory and not normalized.endswith("/") and normalized != "/":
normalized = normalized + "/"
return normalized
except Exception:
return None
# -----------------------------
# Internal helpers
# -----------------------------
def _view(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View directory contents or file contents."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
# Check if viewing directory (ends with / or is root)
if validated_path == "/" or validated_path.endswith("/"):
return self._view_directory(validated_path)
# Otherwise view file
return self._view_file(validated_path, view_range)
def _view_directory(self, path: str) -> str:
"""List files in a directory."""
# Ensure path ends with / for proper prefix matching
search_path = path if path.endswith("/") else path + "/"
with db_readonly() as conn:
docs = MemoriesRepository(conn).list_by_prefix(
self.user_id, self.tool_id, search_path
)
if not docs:
return f"Directory: {path}\n(empty)"
# Extract filenames relative to the directory
files = []
for doc in docs:
file_path = doc["path"]
# Remove the directory prefix
if file_path.startswith(search_path):
relative = file_path[len(search_path):]
if relative:
files.append(relative)
files.sort()
file_list = "\n".join(f"- {f}" for f in files)
return f"Directory: {path}\n{file_list}"
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
"""View file contents with optional line range."""
with db_readonly() as conn:
doc = MemoriesRepository(conn).get_by_path(
self.user_id, self.tool_id, path
)
if not doc or not doc.get("content"):
return f"Error: File not found: {path}"
content = str(doc["content"])
# Apply view_range if specified
if view_range and len(view_range) == 2:
lines = content.split("\n")
start, end = view_range
# Convert to 0-indexed
start_idx = max(0, start - 1)
end_idx = min(len(lines), end)
if start_idx >= len(lines):
return f"Error: Line range out of bounds. File has {len(lines)} lines."
selected_lines = lines[start_idx:end_idx]
# Add line numbers (enumerate with 1-based start)
numbered_lines = [f"{i}: {line}" for i, line in enumerate(selected_lines, start=start)]
return "\n".join(numbered_lines)
return content
def _create(self, path: str, file_text: str) -> str:
"""Create or overwrite a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/" or validated_path.endswith("/"):
return "Error: Cannot create a file at directory path."
with db_session() as conn:
MemoriesRepository(conn).upsert(
self.user_id, self.tool_id, validated_path, file_text
)
return f"File created: {validated_path}"
def _str_replace(self, path: str, old_str: str, new_str: str) -> str:
"""Replace text in a file."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not old_str:
return "Error: old_str is required."
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
# Check if old_str exists (case-insensitive)
if old_str.lower() not in current_content.lower():
return f"Error: String '{old_str}' not found in file."
# Case-insensitive replace
import re as regex_module
updated_content = regex_module.sub(
regex_module.escape(old_str),
new_str,
current_content,
flags=regex_module.IGNORECASE,
)
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
return f"File updated: {validated_path}"
def _insert(self, path: str, insert_line: int, insert_text: str) -> str:
"""Insert text at a specific line."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if not insert_text:
return "Error: insert_text is required."
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
if not doc or not doc.get("content"):
return f"Error: File not found: {validated_path}"
current_content = str(doc["content"])
lines = current_content.split("\n")
# Convert to 0-indexed
index = insert_line - 1
if index < 0 or index > len(lines):
return f"Error: Invalid line number. File has {len(lines)} lines."
lines.insert(index, insert_text)
updated_content = "\n".join(lines)
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
return f"Text inserted at line {insert_line} in {validated_path}"
def _delete(self, path: str) -> str:
"""Delete a file or directory."""
validated_path = self._validate_path(path)
if not validated_path:
return "Error: Invalid path."
if validated_path == "/":
# Delete all files for this user and tool
with db_session() as conn:
deleted = MemoriesRepository(conn).delete_all(
self.user_id, self.tool_id
)
return f"Deleted {deleted} file(s) from memory."
# Check if it's a directory (ends with /)
if validated_path.endswith("/"):
with db_session() as conn:
deleted = MemoriesRepository(conn).delete_by_prefix(
self.user_id, self.tool_id, validated_path
)
return f"Deleted directory and {deleted} file(s)."
# Try as directory first (without trailing slash)
search_path = validated_path + "/"
with db_session() as conn:
repo = MemoriesRepository(conn)
directory_deleted = repo.delete_by_prefix(
self.user_id, self.tool_id, search_path
)
if directory_deleted > 0:
return f"Deleted directory and {directory_deleted} file(s)."
# Otherwise delete a single file
file_deleted = repo.delete_by_path(
self.user_id, self.tool_id, validated_path
)
if file_deleted:
return f"Deleted: {validated_path}"
return f"Error: File not found: {validated_path}"
def _rename(self, old_path: str, new_path: str) -> str:
"""Rename or move a file/directory."""
validated_old = self._validate_path(old_path)
validated_new = self._validate_path(new_path)
if not validated_old or not validated_new:
return "Error: Invalid path."
if validated_old == "/" or validated_new == "/":
return "Error: Cannot rename root directory."
# Directory rename: do all path updates inside one transaction so
# the rename is atomic from the caller's perspective.
if validated_old.endswith("/"):
# Ensure validated_new also ends with / for proper path replacement
if not validated_new.endswith("/"):
validated_new = validated_new + "/"
with db_session() as conn:
repo = MemoriesRepository(conn)
docs = repo.list_by_prefix(
self.user_id, self.tool_id, validated_old
)
if not docs:
return f"Error: Directory not found: {validated_old}"
for doc in docs:
old_file_path = doc["path"]
new_file_path = old_file_path.replace(
validated_old, validated_new, 1
)
repo.update_path(
self.user_id, self.tool_id, old_file_path, new_file_path
)
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
# Single-file rename: lookup, collision check, and update in one txn.
with db_session() as conn:
repo = MemoriesRepository(conn)
doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
if not doc:
return f"Error: File not found: {validated_old}"
existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
if existing:
return f"Error: File already exists at {validated_new}"
repo.update_path(
self.user_id, self.tool_id, validated_old, validated_new
)
return f"Renamed: {validated_old} -> {validated_new}"

View File

@@ -1,258 +0,0 @@
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.storage.db.repositories.notes import NotesRepository
from application.storage.db.session import db_readonly, db_session
# Stable synthetic title used in the Postgres ``notes.title`` column.
# The notes tool stores one note per (user_id, tool_id); there is no
# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
# constant alongside the actual note body in ``content``.
_NOTE_TITLE = "note"
class NotesTool(Tool):
"""Notepad
Single note. Supports viewing, overwriting, string replacement.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this notes tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated notes
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
self._last_artifact_id: Optional[str] = None
def _pg_enabled(self) -> bool:
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
fallback is neither a UUID nor a ``user_tools`` row, so any DB
operation would crash. Mirror MemoryTool's guard and no-op.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
return False
from application.storage.db.base_repository import looks_like_uuid
return looks_like_uuid(tool_id)
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of view, overwrite, str_replace, insert, delete.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: NotesTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: NotesTool is not configured with a persistent "
"tool_id; note storage is unavailable for this session."
)
self._last_artifact_id = None
if action_name == "view":
return self._get_note()
if action_name == "overwrite":
return self._overwrite_note(kwargs.get("text", ""))
if action_name == "str_replace":
return self._str_replace(kwargs.get("old_str", ""), kwargs.get("new_str", ""))
if action_name == "insert":
return self._insert(kwargs.get("line_number", 1), kwargs.get("text", ""))
if action_name == "delete":
return self._delete_note()
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "view",
"description": "Retrieve the user's note.",
"parameters": {"type": "object", "properties": {}},
},
{
"name": "overwrite",
"description": "Replace the entire note content (creates if doesn't exist).",
"parameters": {
"type": "object",
"properties": {
"text": {"type": "string", "description": "New note content."}
},
"required": ["text"],
},
},
{
"name": "str_replace",
"description": "Replace occurrences of old_str with new_str in the note.",
"parameters": {
"type": "object",
"properties": {
"old_str": {"type": "string", "description": "String to find."},
"new_str": {"type": "string", "description": "String to replace with."}
},
"required": ["old_str", "new_str"],
},
},
{
"name": "insert",
"description": "Insert text at the specified line number (1-indexed).",
"parameters": {
"type": "object",
"properties": {
"line_number": {"type": "integer", "description": "Line number to insert at (1-indexed)."},
"text": {"type": "string", "description": "Text to insert."}
},
"required": ["line_number", "text"],
},
},
{
"name": "delete",
"description": "Delete the user's note.",
"parameters": {"type": "object", "properties": {}},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements (none for now)."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers (single-note)
# -----------------------------
def _fetch_note(self) -> Optional[dict]:
"""Read the note row for this (user, tool) from Postgres."""
with db_readonly() as conn:
return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)
def _get_note(self) -> str:
doc = self._fetch_note()
# ``content`` is the PG column; expose as ``note`` to callers via the
# textual return value. Frontends that read the artifact via the
# repo dict get ``content`` (PG-native) plus the artifact id below.
body = (doc or {}).get("content")
if not doc or not body:
return "No note found."
if doc.get("id") is not None:
self._last_artifact_id = str(doc.get("id"))
return str(body)
def _overwrite_note(self, content: str) -> str:
content = (content or "").strip()
if not content:
return "Note content required."
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, content
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return "Note saved."
def _str_replace(self, old_str: str, new_str: str) -> str:
if not old_str:
return "old_str is required."
doc = self._fetch_note()
existing = (doc or {}).get("content")
if not doc or not existing:
return "No note found."
current_note = str(existing)
# Case-insensitive search
if old_str.lower() not in current_note.lower():
return f"String '{old_str}' not found in note."
# Case-insensitive replacement
import re
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return "Note updated."
def _insert(self, line_number: int, text: str) -> str:
if not text:
return "Text is required."
doc = self._fetch_note()
existing = (doc or {}).get("content")
if not doc or not existing:
return "No note found."
current_note = str(existing)
lines = current_note.split("\n")
# Convert to 0-indexed and validate
index = line_number - 1
if index < 0 or index > len(lines):
return f"Invalid line number. Note has {len(lines)} lines."
lines.insert(index, text)
updated_note = "\n".join(lines)
with db_session() as conn:
row = NotesRepository(conn).upsert(
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
)
if row and row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return "Text inserted."
def _delete_note(self) -> str:
# Capture the id (for artifact tracking) before deleting.
existing = self._fetch_note()
if not existing:
return "No note found to delete."
with db_session() as conn:
deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
if not deleted:
return "No note found to delete."
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return "Note deleted."

View File

@@ -71,7 +71,7 @@ class NtfyTool(Tool):
if self.token:
headers["Authorization"] = f"Basic {self.token}"
data = message.encode("utf-8")
response = requests.post(url, headers=headers, data=data, timeout=100)
response = requests.post(url, headers=headers, data=data)
return {"status_code": response.status_code, "message": "Message sent"}
def get_actions_metadata(self):
@@ -116,13 +116,12 @@ class NtfyTool(Tool):
]
def get_config_requirements(self):
"""
Specify the configuration requirements.
Returns:
dict: Dictionary describing required config parameters.
"""
return {
"token": {
"type": "string",
"label": "Access Token",
"description": "Ntfy access token for authentication",
"required": True,
"secret": True,
"order": 1,
},
"token": {"type": "string", "description": "Access token for authentication"},
}

View File

@@ -1,12 +1,6 @@
import logging
import psycopg
import psycopg2
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class PostgresTool(Tool):
"""
PostgreSQL Database Tool
@@ -23,25 +17,25 @@ class PostgresTool(Tool):
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None
conn = None # Initialize conn to None for error handling
try:
conn = psycopg.connect(self.connection_string)
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()
if sql_query.strip().lower().startswith("select"):
column_names = (
[desc[0] for desc in cur.description] if cur.description else []
)
column_names = [desc[0] for desc in cur.description] if cur.description else []
results = []
rows = cur.fetchall()
for row in rows:
@@ -49,9 +43,7 @@ class PostgresTool(Tool):
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {
"message": f"Query executed successfully, {row_count} rows affected."
}
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
cur.close()
return {
@@ -60,29 +52,28 @@ class PostgresTool(Tool):
"response_data": response_data,
}
except psycopg.Error as e:
except psycopg2.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL execute_sql error: %s", e)
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None
conn = None # Initialize conn to None for error handling
try:
conn = psycopg.connect(self.connection_string)
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute(
"""
cur.execute("""
SELECT
table_name,
column_name,
@@ -96,22 +87,19 @@ class PostgresTool(Tool):
ORDER BY
table_name,
ordinal_position;
"""
)
""")
schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append(
{
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable,
}
)
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})
cur.close()
return {
@@ -120,16 +108,16 @@ class PostgresTool(Tool):
"schema": schema_data,
}
except psycopg.Error as e:
except psycopg2.Error as e:
error_message = f"Database error: {e}"
logger.error("PostgreSQL get_schema error: %s", e)
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn:
if conn: # Ensure connection is closed even if errors occur
conn.close()
def get_actions_metadata(self):
@@ -170,10 +158,6 @@ class PostgresTool(Tool):
return {
"token": {
"type": "string",
"label": "Connection String",
"description": "PostgreSQL database connection string",
"required": True,
"secret": True,
"order": 1,
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
},
}
}

View File

@@ -1,7 +1,7 @@
import requests
from markdownify import markdownify
from application.agents.tools.base import Tool
from application.core.url_validation import validate_url, SSRFError
from urllib.parse import urlparse
class ReadWebpageTool(Tool):
"""
@@ -31,12 +31,11 @@ class ReadWebpageTool(Tool):
if not url:
return "Error: URL parameter is missing."
# Validate URL to prevent SSRF attacks
try:
url = validate_url(url)
except SSRFError as e:
return f"Error: URL validation failed - {e}"
# Ensure the URL has a scheme (if not, default to http)
parsed_url = urlparse(url)
if not parsed_url.scheme:
url = "http://" + url
try:
response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)

View File

@@ -1,342 +0,0 @@
"""
API Specification Parser
Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
to API Tool action definitions for use in DocsGPT.
"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
import yaml
logger = logging.getLogger(__name__)
SUPPORTED_METHODS = frozenset(
{"get", "post", "put", "delete", "patch", "head", "options"}
)
def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
"""
Parse an API specification and convert operations to action definitions.
Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
Args:
spec_content: Raw specification content as string
Returns:
Tuple of (metadata dict, list of action dicts)
Raises:
ValueError: If the spec is invalid or uses an unsupported format
"""
spec = _load_spec(spec_content)
_validate_spec(spec)
is_swagger = "swagger" in spec
metadata = _extract_metadata(spec, is_swagger)
actions = _extract_actions(spec, is_swagger)
return metadata, actions
def _load_spec(content: str) -> Dict[str, Any]:
"""Parse spec content from JSON or YAML string."""
content = content.strip()
if not content:
raise ValueError("Empty specification content")
try:
if content.startswith("{"):
return json.loads(content)
return yaml.safe_load(content)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {e.msg}")
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML format: {e}")
def _validate_spec(spec: Dict[str, Any]) -> None:
"""Validate spec version and required fields."""
if not isinstance(spec, dict):
raise ValueError("Specification must be a valid object")
openapi_version = spec.get("openapi", "")
swagger_version = spec.get("swagger", "")
if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
raise ValueError(
"Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
)
if "paths" not in spec or not spec["paths"]:
raise ValueError("No API paths defined in the specification")
def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
"""Extract API metadata from specification."""
info = spec.get("info", {})
base_url = _get_base_url(spec, is_swagger)
return {
"title": info.get("title", "Untitled API"),
"description": (info.get("description", "") or "")[:500],
"version": info.get("version", ""),
"base_url": base_url,
}
def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
"""Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
if is_swagger:
schemes = spec.get("schemes", ["https"])
host = spec.get("host", "")
base_path = spec.get("basePath", "")
if host:
scheme = schemes[0] if schemes else "https"
return f"{scheme}://{host}{base_path}".rstrip("/")
return ""
servers = spec.get("servers", [])
if servers and isinstance(servers, list) and servers[0].get("url"):
return servers[0]["url"].rstrip("/")
return ""
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
"""Extract all API operations as action definitions."""
actions = []
paths = spec.get("paths", {})
base_url = _get_base_url(spec, is_swagger)
components = spec.get("components", {})
definitions = spec.get("definitions", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
path_params = path_item.get("parameters", [])
for method in SUPPORTED_METHODS:
operation = path_item.get(method)
if not isinstance(operation, dict):
continue
try:
action = _build_action(
path=path,
method=method,
operation=operation,
path_params=path_params,
base_url=base_url,
components=components,
definitions=definitions,
is_swagger=is_swagger,
)
actions.append(action)
except Exception as e:
logger.warning(
f"Failed to parse operation {method.upper()} {path}: {e}"
)
continue
return actions
def _build_action(
path: str,
method: str,
operation: Dict[str, Any],
path_params: List[Dict],
base_url: str,
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Dict[str, Any]:
"""Build a single action from an API operation."""
action_name = _generate_action_name(operation, method, path)
full_url = f"{base_url}{path}" if base_url else path
all_params = path_params + operation.get("parameters", [])
query_params, headers = _categorize_parameters(all_params, components, definitions)
body, body_content_type = _extract_request_body(
operation, components, definitions, is_swagger
)
description = operation.get("summary", "") or operation.get("description", "")
return {
"name": action_name,
"url": full_url,
"method": method.upper(),
"description": (description or "")[:500],
"query_params": {"type": "object", "properties": query_params},
"headers": {"type": "object", "properties": headers},
"body": {"type": "object", "properties": body},
"body_content_type": body_content_type,
"active": True,
}
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
"""Generate a valid action name from operationId or method+path."""
if operation.get("operationId"):
name = operation["operationId"]
else:
path_slug = re.sub(r"[{}]", "", path)
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
name = f"{method}_{path_slug}"
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
return name[:64]
def _categorize_parameters(
parameters: List[Dict],
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Tuple[Dict, Dict]:
"""Categorize parameters into query params and headers."""
query_params = {}
headers = {}
for param in parameters:
resolved = _resolve_ref(param, components, definitions)
if not resolved or "name" not in resolved:
continue
location = resolved.get("in", "query")
prop = _param_to_property(resolved)
if location in ("query", "path"):
query_params[resolved["name"]] = prop
elif location == "header":
headers[resolved["name"]] = prop
return query_params, headers
def _param_to_property(param: Dict) -> Dict[str, Any]:
"""Convert an API parameter to an action property definition."""
schema = param.get("schema", {})
param_type = schema.get("type", param.get("type", "string"))
mapped_type = "integer" if param_type in ("integer", "number") else "string"
return {
"type": mapped_type,
"description": (param.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": param.get("required", False),
"required": param.get("required", False),
}
def _extract_request_body(
operation: Dict[str, Any],
components: Dict[str, Any],
definitions: Dict[str, Any],
is_swagger: bool,
) -> Tuple[Dict, str]:
"""Extract request body schema and content type."""
content_types = [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
"text/plain",
"application/xml",
]
if is_swagger:
consumes = operation.get("consumes", [])
body_param = next(
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
)
if not body_param:
return {}, "application/json"
selected_type = consumes[0] if consumes else "application/json"
schema = body_param.get("schema", {})
else:
request_body = operation.get("requestBody", {})
if not request_body:
return {}, "application/json"
request_body = _resolve_ref(request_body, components, definitions)
content = request_body.get("content", {})
selected_type = "application/json"
schema = {}
for ct in content_types:
if ct in content:
selected_type = ct
schema = content[ct].get("schema", {})
break
if not schema and content:
first_type = next(iter(content))
selected_type = first_type
schema = content[first_type].get("schema", {})
properties = _schema_to_properties(schema, components, definitions)
return properties, selected_type
def _schema_to_properties(
schema: Dict,
components: Dict[str, Any],
definitions: Dict[str, Any],
depth: int = 0,
) -> Dict[str, Any]:
"""Convert schema to action body properties (limited depth to prevent recursion)."""
if depth > 3:
return {}
schema = _resolve_ref(schema, components, definitions)
if not schema or not isinstance(schema, dict):
return {}
properties = {}
schema_type = schema.get("type", "object")
if schema_type == "object":
required_fields = set(schema.get("required", []))
for prop_name, prop_schema in schema.get("properties", {}).items():
resolved = _resolve_ref(prop_schema, components, definitions)
if not isinstance(resolved, dict):
continue
prop_type = resolved.get("type", "string")
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
properties[prop_name] = {
"type": mapped_type,
"description": (resolved.get("description", "") or "")[:200],
"value": "",
"filled_by_llm": prop_name in required_fields,
"required": prop_name in required_fields,
}
return properties
def _resolve_ref(
obj: Any,
components: Dict[str, Any],
definitions: Dict[str, Any],
) -> Optional[Dict]:
"""Resolve $ref references in the specification."""
if not isinstance(obj, dict):
return obj if isinstance(obj, dict) else None
if "$ref" not in obj:
return obj
ref_path = obj["$ref"]
if ref_path.startswith("#/components/"):
parts = ref_path.replace("#/components/", "").split("/")
return _traverse_path(components, parts)
elif ref_path.startswith("#/definitions/"):
parts = ref_path.replace("#/definitions/", "").split("/")
return _traverse_path(definitions, parts)
logger.debug(f"Unsupported ref path: {ref_path}")
return None
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
"""Traverse a nested dictionary using path parts."""
try:
for part in parts:
obj = obj[part]
return obj if isinstance(obj, dict) else None
except (KeyError, TypeError):
return None

View File

@@ -1,11 +1,6 @@
import logging
import requests
from application.agents.tools.base import Tool
logger = logging.getLogger(__name__)
class TelegramTool(Tool):
"""
@@ -23,22 +18,24 @@ class TelegramTool(Tool):
"telegram_send_message": self._send_message,
"telegram_send_image": self._send_image,
}
if action_name not in actions:
if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")
return actions[action_name](**kwargs)
def _send_message(self, text, chat_id):
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
print(f"Sending message: {text}")
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
payload = {"chat_id": chat_id, "text": text}
response = requests.post(url, data=payload, timeout=100)
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Message sent"}
def _send_image(self, image_url, chat_id):
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
print(f"Sending image: {image_url}")
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
payload = {"chat_id": chat_id, "photo": image_url}
response = requests.post(url, data=payload, timeout=100)
response = requests.post(url, data=payload)
return {"status_code": response.status_code, "message": "Image sent"}
def get_actions_metadata(self):
@@ -85,12 +82,5 @@ class TelegramTool(Tool):
def get_config_requirements(self):
return {
"token": {
"type": "string",
"label": "Bot Token",
"description": "Telegram bot token for authentication",
"required": True,
"secret": True,
"order": 1,
},
"token": {"type": "string", "description": "Bot token for authentication"},
}

View File

@@ -1,70 +0,0 @@
from application.agents.tools.base import Tool
THINK_TOOL_ID = "think"
THINK_TOOL_ENTRY = {
"name": "think",
"actions": [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"active": True,
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
],
}
class ThinkTool(Tool):
"""Pseudo-tool that captures chain-of-thought reasoning.
Returns a short acknowledgment so the LLM can continue.
The reasoning content is captured in tool_call data for transparency.
"""
internal = True
def __init__(self, config=None):
pass
def execute_action(self, action_name: str, **kwargs):
return "Continue."
def get_actions_metadata(self):
return [
{
"name": "reason",
"description": (
"Use this tool to think through your reasoning step by step "
"before deciding on your next action. Always reason before "
"searching or answering."
),
"parameters": {
"properties": {
"reasoning": {
"type": "string",
"description": "Your step-by-step reasoning and analysis",
"filled_by_llm": True,
"required": True,
}
}
},
}
]
def get_config_requirements(self):
return {}

View File

@@ -1,351 +0,0 @@
from typing import Any, Dict, List, Optional
import uuid
from .base import Tool
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.session import db_readonly, db_session
def _status_from_completed(completed: Any) -> str:
"""Translate the PG ``completed`` boolean to the legacy status string.
The frontend (and prior LLM-facing tool output) expects
``"open"`` / ``"completed"``. Keeping that contract at the tool
boundary insulates callers from the schema change.
"""
return "completed" if bool(completed) else "open"
class TodoListTool(Tool):
"""Todo List
Manages todo items for users. Supports creating, viewing, updating, and deleting todos.
"""
def __init__(self, tool_config: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> None:
"""Initialize the tool.
Args:
tool_config: Optional tool configuration. Should include:
- tool_id: Unique identifier for this todo list tool instance (from user_tools._id)
This ensures each user's tool configuration has isolated todos
user_id: The authenticated user's id (should come from decoded_token["sub"]).
"""
self.user_id: Optional[str] = user_id
# Get tool_id from configuration (passed from user_tools._id in production)
if tool_config and "tool_id" in tool_config:
self.tool_id = tool_config["tool_id"]
elif user_id:
# Fallback for backward compatibility or testing
self.tool_id = f"default_{user_id}"
else:
# Last resort fallback (shouldn't happen in normal use)
self.tool_id = str(uuid.uuid4())
self._last_artifact_id: Optional[str] = None
def _pg_enabled(self) -> bool:
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
``default_{uid}`` fallback is neither a UUID nor a row in
``user_tools`` — binding it would crash ``invalid input syntax for
type uuid`` and even if it didn't the FK would reject it. Mirror
the MemoryTool guard and no-op in that case.
"""
tool_id = getattr(self, "tool_id", None)
if not tool_id or not isinstance(tool_id, str):
return False
if tool_id.startswith("default_"):
return False
from application.storage.db.base_repository import looks_like_uuid
return looks_like_uuid(tool_id)
# -----------------------------
# Action implementations
# -----------------------------
def execute_action(self, action_name: str, **kwargs: Any) -> str:
"""Execute an action by name.
Args:
action_name: One of list, create, get, update, complete, delete.
**kwargs: Parameters for the action.
Returns:
A human-readable string result.
"""
if not self.user_id:
return "Error: TodoListTool requires a valid user_id."
if not self._pg_enabled():
return (
"Error: TodoListTool is not configured with a persistent "
"tool_id; todo storage is unavailable for this session."
)
self._last_artifact_id = None
if action_name == "list":
return self._list()
if action_name == "create":
return self._create(kwargs.get("title", ""))
if action_name == "get":
return self._get(kwargs.get("todo_id"))
if action_name == "update":
return self._update(
kwargs.get("todo_id"),
kwargs.get("title", "")
)
if action_name == "complete":
return self._complete(kwargs.get("todo_id"))
if action_name == "delete":
return self._delete(kwargs.get("todo_id"))
return f"Unknown action: {action_name}"
def get_actions_metadata(self) -> List[Dict[str, Any]]:
"""Return JSON metadata describing supported actions for tool schemas."""
return [
{
"name": "list",
"description": "List all todos for the user.",
"parameters": {"type": "object", "properties": {}},
},
{
"name": "create",
"description": "Create a new todo item.",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Title of the todo item."
}
},
"required": ["title"],
},
},
{
"name": "get",
"description": "Get a specific todo by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to retrieve."
}
},
"required": ["todo_id"],
},
},
{
"name": "update",
"description": "Update a todo's title by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to update."
},
"title": {
"type": "string",
"description": "The new title for the todo."
}
},
"required": ["todo_id", "title"],
},
},
{
"name": "complete",
"description": "Mark a todo as completed.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to mark as completed."
}
},
"required": ["todo_id"],
},
},
{
"name": "delete",
"description": "Delete a specific todo by ID.",
"parameters": {
"type": "object",
"properties": {
"todo_id": {
"type": "integer",
"description": "The ID of the todo to delete."
}
},
"required": ["todo_id"],
},
},
]
def get_config_requirements(self) -> Dict[str, Any]:
"""Return configuration requirements."""
return {}
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
return self._last_artifact_id
# -----------------------------
# Internal helpers
# -----------------------------
def _coerce_todo_id(self, value: Optional[Any]) -> Optional[int]:
"""Convert todo identifiers to sequential integers."""
if value is None:
return None
if isinstance(value, int):
return value if value > 0 else None
if isinstance(value, str):
stripped = value.strip()
if stripped.isdigit():
numeric_value = int(stripped)
return numeric_value if numeric_value > 0 else None
return None
def _list(self) -> str:
"""List all todos for the user."""
with db_readonly() as conn:
todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)
if not todos:
return "No todos found."
result_lines = ["Todos:"]
for doc in todos:
todo_id = doc.get("todo_id")
title = doc.get("title", "Untitled")
status = _status_from_completed(doc.get("completed"))
line = f"[{todo_id}] {title} ({status})"
result_lines.append(line)
return "\n".join(result_lines)
def _create(self, title: str) -> str:
"""Create a new todo item.
``TodosRepository.create`` allocates the per-tool monotonic
``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
scoped to ``tool_id``), so we no longer need a separate read-then-
write step here.
"""
title = (title or "").strip()
if not title:
return "Error: Title is required."
with db_session() as conn:
row = TodosRepository(conn).create(self.user_id, self.tool_id, title)
todo_id = row.get("todo_id")
if row.get("id") is not None:
self._last_artifact_id = str(row.get("id"))
return f"Todo created with ID {todo_id}: {title}"
def _get(self, todo_id: Optional[Any]) -> str:
"""Get a specific todo by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
with db_readonly() as conn:
doc = TodosRepository(conn).get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not doc:
return f"Error: Todo with ID {parsed_todo_id} not found."
if doc.get("id") is not None:
self._last_artifact_id = str(doc.get("id"))
title = doc.get("title", "Untitled")
status = _status_from_completed(doc.get("completed"))
return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
def _update(self, todo_id: Optional[Any], title: str) -> str:
"""Update a todo's title by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
title = (title or "").strip()
if not title:
return "Error: Title is required."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.update_title_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id, title
)
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return f"Todo {parsed_todo_id} updated to: {title}"
def _complete(self, todo_id: Optional[Any]) -> str:
"""Mark a todo as completed."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return f"Todo {parsed_todo_id} marked as completed."
def _delete(self, todo_id: Optional[Any]) -> str:
"""Delete a specific todo by ID."""
parsed_todo_id = self._coerce_todo_id(todo_id)
if parsed_todo_id is None:
return "Error: todo_id must be a positive integer."
with db_session() as conn:
repo = TodosRepository(conn)
existing = repo.get_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if not existing:
return f"Error: Todo with ID {parsed_todo_id} not found."
repo.delete_by_tool_and_todo_id(
self.user_id, self.tool_id, parsed_todo_id
)
if existing.get("id") is not None:
self._last_artifact_id = str(existing.get("id"))
return f"Todo {parsed_todo_id} deleted."

View File

@@ -5,9 +5,8 @@ logger = logging.getLogger(__name__)
class ToolActionParser:
def __init__(self, llm_type, name_mapping=None):
def __init__(self, llm_type):
self.llm_type = llm_type
self.name_mapping = name_mapping
self.parsers = {
"OpenAILLM": self._parse_openai_llm,
"GoogleLLM": self._parse_google_llm,
@@ -17,39 +16,12 @@ class ToolActionParser:
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
return parser(call)
def _resolve_via_mapping(self, call_name):
"""Look up (tool_id, action_name) from the name mapping if available."""
if self.name_mapping and call_name in self.name_mapping:
return self.name_mapping[call_name]
return None
def _parse_openai_llm(self, call):
try:
call_args = json.loads(call.arguments)
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
)
except (AttributeError, TypeError, json.JSONDecodeError) as e:
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
except (AttributeError, TypeError) as e:
logger.error(f"Error parsing OpenAI LLM call: {e}")
return None, None, None
return tool_id, action_name, call_args
@@ -57,29 +29,8 @@ class ToolActionParser:
def _parse_google_llm(self, call):
try:
call_args = call.arguments
resolved = self._resolve_via_mapping(call.name)
if resolved:
return resolved[0], resolved[1], call_args
# Fallback: legacy split on "_" for backward compatibility
tool_parts = call.name.split("_")
if len(tool_parts) < 2:
logger.warning(
f"Invalid tool name format: {call.name}. "
"Could not resolve via mapping or legacy parsing."
)
return None, None, None
tool_id = tool_parts[-1]
action_name = "_".join(tool_parts[:-1])
if not tool_id.isdigit():
logger.warning(
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
)
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]
except (AttributeError, TypeError) as e:
logger.error(f"Error parsing Google LLM call: {e}")
return None, None, None

View File

@@ -19,27 +19,20 @@ class ToolManager:
continue
module = importlib.import_module(f"application.agents.tools.{name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
if issubclass(obj, Tool) and obj is not Tool:
tool_config = self.config.get(name, {})
self.tools[name] = obj(tool_config)
def load_tool(self, tool_name, tool_config, user_id=None):
def load_tool(self, tool_name, tool_config):
self.config[tool_name] = tool_config
module = importlib.import_module(f"application.agents.tools.{tool_name}")
for member_name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, Tool) and obj is not Tool:
if tool_name in {"mcp_tool", "notes", "memory", "todo_list"} and user_id:
return obj(tool_config, user_id)
else:
return obj(tool_config)
return obj(tool_config)
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
def execute_action(self, tool_name, action_name, **kwargs):
if tool_name not in self.tools:
raise ValueError(f"Tool '{tool_name}' not loaded")
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
tool_config = self.config.get(tool_name, {})
tool = self.load_tool(tool_name, tool_config, user_id)
return tool.execute_action(action_name, **kwargs)
return self.tools[tool_name].execute_action(action_name, **kwargs)
def get_all_actions_metadata(self):

View File

@@ -1,254 +0,0 @@
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Optional
from application.agents.base import BaseAgent
from application.agents.workflows.schemas import (
ExecutionStatus,
Workflow,
WorkflowEdge,
WorkflowGraph,
WorkflowNode,
WorkflowRun,
)
from application.agents.workflows.workflow_engine import WorkflowEngine
from application.logging import log_activity, LogContext
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
class WorkflowAgent(BaseAgent):
"""A specialized agent that executes predefined workflows."""
def __init__(
self,
*args,
workflow_id: Optional[str] = None,
workflow: Optional[Dict[str, Any]] = None,
workflow_owner: Optional[str] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.workflow_id = workflow_id
self.workflow_owner = workflow_owner
self._workflow_data = workflow
self._engine: Optional[WorkflowEngine] = None
@log_activity()
def gen(
self, query: str, log_context: LogContext = None
) -> Generator[Dict[str, str], None, None]:
yield from self._gen_inner(query, log_context)
def _gen_inner(
self, query: str, log_context: LogContext
) -> Generator[Dict[str, str], None, None]:
graph = self._load_workflow_graph()
if not graph:
yield {"type": "error", "error": "Failed to load workflow configuration."}
return
self._engine = WorkflowEngine(graph, self)
yield from self._engine.execute({}, query)
self._save_workflow_run(query)
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
if self._workflow_data:
return self._parse_embedded_workflow()
if self.workflow_id:
return self._load_from_database()
return None
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
try:
nodes_data = self._workflow_data.get("nodes", [])
edges_data = self._workflow_data.get("edges", [])
workflow = Workflow(
name=self._workflow_data.get("name", "Embedded Workflow"),
description=self._workflow_data.get("description"),
)
nodes = []
for n in nodes_data:
node_config = n.get("data", {})
nodes.append(
WorkflowNode(
id=n["id"],
workflow_id=self.workflow_id or "embedded",
type=n["type"],
title=n.get("title", "Node"),
description=n.get("description"),
position=n.get("position", {"x": 0, "y": 0}),
config=node_config,
)
)
edges = []
for e in edges_data:
edges.append(
WorkflowEdge(
id=e["id"],
workflow_id=self.workflow_id or "embedded",
source=e.get("source") or e.get("source_id"),
target=e.get("target") or e.get("target_id"),
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
targetHandle=e.get("targetHandle") or e.get("target_handle"),
)
)
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Invalid embedded workflow: {e}")
return None
def _load_from_database(self) -> Optional[WorkflowGraph]:
try:
if not self.workflow_id:
logger.error("Missing workflow ID for load")
return None
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
if not owner_id:
logger.error(
f"Workflow owner not available for workflow load: {self.workflow_id}"
)
return None
with db_readonly() as conn:
wf_repo = WorkflowsRepository(conn)
if looks_like_uuid(self.workflow_id):
workflow_row = wf_repo.get(self.workflow_id, owner_id)
else:
workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
if workflow_row is None:
logger.error(
f"Workflow {self.workflow_id} not found or inaccessible "
f"for user {owner_id}"
)
return None
pg_workflow_id = str(workflow_row["id"])
graph_version = workflow_row.get("current_graph_version", 1)
try:
graph_version = int(graph_version)
if graph_version <= 0:
graph_version = 1
except (ValueError, TypeError):
graph_version = 1
node_rows = WorkflowNodesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
edge_rows = WorkflowEdgesRepository(conn).find_by_version(
pg_workflow_id, graph_version,
)
workflow = Workflow(
name=workflow_row.get("name"),
description=workflow_row.get("description"),
)
nodes = [
WorkflowNode(
id=n["node_id"],
workflow_id=pg_workflow_id,
type=n["node_type"],
title=n.get("title") or "Node",
description=n.get("description"),
position=n.get("position") or {"x": 0, "y": 0},
config=n.get("config") or {},
)
for n in node_rows
]
edges = [
WorkflowEdge(
id=e["edge_id"],
workflow_id=pg_workflow_id,
source=e.get("source_id"),
target=e.get("target_id"),
sourceHandle=e.get("source_handle"),
targetHandle=e.get("target_handle"),
)
for e in edge_rows
]
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
except Exception as e:
logger.error(f"Failed to load workflow from database: {e}")
return None
def _save_workflow_run(self, query: str) -> None:
if not self._engine:
return
owner_id = self.workflow_owner
if not owner_id and isinstance(self.decoded_token, dict):
owner_id = self.decoded_token.get("sub")
try:
run = WorkflowRun(
workflow_id=self.workflow_id or "unknown",
user=owner_id,
status=self._determine_run_status(),
inputs={"query": query},
outputs=self._serialize_state(self._engine.state),
steps=self._engine.get_execution_summary(),
created_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
if not self.workflow_id or not owner_id:
return
with db_session() as conn:
wf_repo = WorkflowsRepository(conn)
if looks_like_uuid(self.workflow_id):
workflow_row = wf_repo.get(self.workflow_id, owner_id)
else:
workflow_row = wf_repo.get_by_legacy_id(
self.workflow_id, owner_id,
)
if workflow_row is None:
return
WorkflowRunsRepository(conn).create(
str(workflow_row["id"]),
owner_id,
run.status.value,
inputs=run.inputs,
result=run.outputs,
steps=[step.model_dump(mode="json") for step in run.steps],
started_at=run.created_at,
ended_at=run.completed_at,
)
except Exception as e:
logger.error(f"Failed to save workflow run: {e}")
def _determine_run_status(self) -> ExecutionStatus:
if not self._engine or not self._engine.execution_log:
return ExecutionStatus.COMPLETED
for log in self._engine.execution_log:
if log.get("status") == ExecutionStatus.FAILED.value:
return ExecutionStatus.FAILED
return ExecutionStatus.COMPLETED
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
serialized: Dict[str, Any] = {}
for key, value in state.items():
serialized[key] = self._serialize_state_value(value)
return serialized
def _serialize_state_value(self, value: Any) -> Any:
if isinstance(value, dict):
return {
str(dict_key): self._serialize_state_value(dict_value)
for dict_key, dict_value in value.items()
}
if isinstance(value, list):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, tuple):
return [self._serialize_state_value(item) for item in value]
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (str, int, float, bool, type(None))):
return value
return str(value)

View File

@@ -1,64 +0,0 @@
from typing import Any, Dict
import celpy
import celpy.celtypes
class CelEvaluationError(Exception):
pass
def _convert_value(value: Any) -> Any:
if isinstance(value, bool):
return celpy.celtypes.BoolType(value)
if isinstance(value, int):
return celpy.celtypes.IntType(value)
if isinstance(value, float):
return celpy.celtypes.DoubleType(value)
if isinstance(value, str):
return celpy.celtypes.StringType(value)
if isinstance(value, list):
return celpy.celtypes.ListType([_convert_value(item) for item in value])
if isinstance(value, dict):
return celpy.celtypes.MapType(
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
)
if value is None:
return celpy.celtypes.BoolType(False)
return celpy.celtypes.StringType(str(value))
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
return {k: _convert_value(v) for k, v in state.items()}
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
if not expression or not expression.strip():
raise CelEvaluationError("Empty expression")
try:
env = celpy.Environment()
ast = env.compile(expression)
program = env.program(ast)
activation = build_activation(state)
result = program.evaluate(activation)
except celpy.CELEvalError as exc:
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
except Exception as exc:
raise CelEvaluationError(f"CEL error: {exc}") from exc
return cel_to_python(result)
def cel_to_python(value: Any) -> Any:
if isinstance(value, celpy.celtypes.BoolType):
return bool(value)
if isinstance(value, celpy.celtypes.IntType):
return int(value)
if isinstance(value, celpy.celtypes.DoubleType):
return float(value)
if isinstance(value, celpy.celtypes.StringType):
return str(value)
if isinstance(value, celpy.celtypes.ListType):
return [cel_to_python(item) for item in value]
if isinstance(value, celpy.celtypes.MapType):
return {str(k): cel_to_python(v) for k, v in value.items()}
return value

View File

@@ -1,104 +0,0 @@
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
from typing import Any, Dict, List, Optional, Type
from application.agents.agentic_agent import AgenticAgent
from application.agents.base import BaseAgent
from application.agents.classic_agent import ClassicAgent
from application.agents.research_agent import ResearchAgent
from application.agents.workflows.schemas import AgentType
class ToolFilterMixin:
"""Mixin that filters fetched tools to only those specified in tool_ids."""
_allowed_tool_ids: List[str]
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_user_tools(user)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
all_tools = super()._get_tools(api_key)
if not self._allowed_tool_ids:
return {}
filtered_tools = {
tool_id: tool
for tool_id, tool in all_tools.items()
if str(tool.get("_id", "")) in self._allowed_tool_ids
}
return filtered_tools
class _WorkflowNodeMixin:
"""Common __init__ for all workflow node agents."""
def __init__(
self,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
**kwargs,
)
self._allowed_tool_ids = tool_ids or []
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
pass
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
pass
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
pass
class WorkflowNodeAgentFactory:
_agents: Dict[AgentType, Type[BaseAgent]] = {
AgentType.CLASSIC: WorkflowNodeClassicAgent,
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
AgentType.RESEARCH: WorkflowNodeResearchAgent,
}
@classmethod
def create(
cls,
agent_type: AgentType,
endpoint: str,
llm_name: str,
model_id: str,
api_key: str,
tool_ids: Optional[List[str]] = None,
**kwargs,
) -> BaseAgent:
agent_class = cls._agents.get(agent_type)
if not agent_class:
raise ValueError(f"Unsupported agent type: {agent_type}")
return agent_class(
endpoint=endpoint,
llm_name=llm_name,
model_id=model_id,
api_key=api_key,
tool_ids=tool_ids,
**kwargs,
)

View File

@@ -1,168 +0,0 @@
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator
class NodeType(str, Enum):
START = "start"
END = "end"
AGENT = "agent"
NOTE = "note"
STATE = "state"
CONDITION = "condition"
class AgentType(str, Enum):
CLASSIC = "classic"
REACT = "react"
AGENTIC = "agentic"
RESEARCH = "research"
class ExecutionStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class Position(BaseModel):
model_config = ConfigDict(extra="forbid")
x: float = 0.0
y: float = 0.0
class AgentNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
agent_type: AgentType = AgentType.CLASSIC
llm_name: Optional[str] = None
system_prompt: str = "You are a helpful assistant."
prompt_template: str = ""
output_variable: Optional[str] = None
stream_to_user: bool = True
tools: List[str] = Field(default_factory=list)
sources: List[str] = Field(default_factory=list)
chunks: str = "2"
retriever: str = ""
model_id: Optional[str] = None
json_schema: Optional[Dict[str, Any]] = None
class ConditionCase(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
name: Optional[str] = None
expression: str = ""
source_handle: str = Field(..., alias="sourceHandle")
class ConditionNodeConfig(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal["simple", "advanced"] = "simple"
cases: List[ConditionCase] = Field(default_factory=list)
class StateOperation(BaseModel):
model_config = ConfigDict(extra="forbid")
expression: str = ""
target_variable: str = ""
class WorkflowEdgeCreate(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: str
workflow_id: str
source_id: str = Field(..., alias="source")
target_id: str = Field(..., alias="target")
source_handle: Optional[str] = Field(None, alias="sourceHandle")
target_handle: Optional[str] = Field(None, alias="targetHandle")
class WorkflowEdge(WorkflowEdgeCreate):
pass
class WorkflowNodeCreate(BaseModel):
model_config = ConfigDict(extra="allow")
id: str
workflow_id: str
type: NodeType
title: str = "Node"
description: Optional[str] = None
position: Position = Field(default_factory=Position)
config: Dict[str, Any] = Field(default_factory=dict)
@field_validator("position", mode="before")
@classmethod
def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
if isinstance(v, dict):
return Position(**v)
return v
class WorkflowNode(WorkflowNodeCreate):
pass
class WorkflowCreate(BaseModel):
model_config = ConfigDict(extra="allow")
name: str = "New Workflow"
description: Optional[str] = None
user: Optional[str] = None
class Workflow(WorkflowCreate):
id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class WorkflowGraph(BaseModel):
workflow: Workflow
nodes: List[WorkflowNode] = Field(default_factory=list)
edges: List[WorkflowEdge] = Field(default_factory=list)
def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.id == node_id:
return node
return None
def get_start_node(self) -> Optional[WorkflowNode]:
for node in self.nodes:
if node.type == NodeType.START:
return node
return None
def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
return [edge for edge in self.edges if edge.source_id == node_id]
class NodeExecutionLog(BaseModel):
model_config = ConfigDict(extra="forbid")
node_id: str
node_type: str
status: ExecutionStatus
started_at: datetime
completed_at: Optional[datetime] = None
error: Optional[str] = None
state_snapshot: Dict[str, Any] = Field(default_factory=dict)
class WorkflowRunCreate(BaseModel):
workflow_id: str
inputs: Dict[str, str] = Field(default_factory=dict)
class WorkflowRun(BaseModel):
model_config = ConfigDict(extra="allow")
id: Optional[str] = None
workflow_id: str
user: Optional[str] = None
status: ExecutionStatus = ExecutionStatus.PENDING
inputs: Dict[str, str] = Field(default_factory=dict)
outputs: Dict[str, Any] = Field(default_factory=dict)
steps: List[NodeExecutionLog] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
completed_at: Optional[datetime] = None

View File

@@ -1,508 +0,0 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
AgentNodeConfig,
AgentType,
ConditionNodeConfig,
ExecutionStatus,
NodeExecutionLog,
NodeType,
WorkflowGraph,
WorkflowNode,
)
from application.core.json_schema_utils import (
JsonSchemaValidationError,
normalize_json_schema_payload,
)
from application.error import sanitize_api_error
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
try:
import jsonschema
except ImportError: # pragma: no cover - optional dependency in some deployments.
jsonschema = None
if TYPE_CHECKING:
from application.agents.base import BaseAgent
logger = logging.getLogger(__name__)
StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}
class WorkflowEngine:
MAX_EXECUTION_STEPS = 50
def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
self.graph = graph
self.agent = agent
self.state: WorkflowState = {}
self.execution_log: List[Dict[str, Any]] = []
self._condition_result: Optional[str] = None
self._template_engine = TemplateEngine()
self._namespace_manager = NamespaceManager()
def execute(
self, initial_inputs: WorkflowState, query: str
) -> Generator[Dict[str, str], None, None]:
self._initialize_state(initial_inputs, query)
start_node = self.graph.get_start_node()
if not start_node:
yield {"type": "error", "error": "No start node found in workflow."}
return
current_node_id: Optional[str] = start_node.id
steps = 0
while current_node_id and steps < self.MAX_EXECUTION_STEPS:
node = self.graph.get_node_by_id(current_node_id)
if not node:
yield {"type": "error", "error": f"Node {current_node_id} not found."}
break
log_entry = self._create_log_entry(node)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "running",
}
try:
yield from self._execute_node(node)
log_entry["status"] = ExecutionStatus.COMPLETED.value
log_entry["completed_at"] = datetime.now(timezone.utc)
output_key = f"node_{node.id}_output"
node_output = self.state.get(output_key)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "completed",
"state_snapshot": dict(self.state),
"output": node_output,
}
except Exception as e:
logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
log_entry["status"] = ExecutionStatus.FAILED.value
log_entry["error"] = str(e)
log_entry["completed_at"] = datetime.now(timezone.utc)
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
user_friendly_error = sanitize_api_error(e)
yield {
"type": "workflow_step",
"node_id": node.id,
"node_type": node.type.value,
"node_title": node.title,
"status": "failed",
"state_snapshot": dict(self.state),
"error": user_friendly_error,
}
yield {"type": "error", "error": user_friendly_error}
break
log_entry["state_snapshot"] = dict(self.state)
self.execution_log.append(log_entry)
if node.type == NodeType.END:
break
current_node_id = self._get_next_node_id(current_node_id)
if current_node_id is None and node.type != NodeType.END:
logger.warning(
f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
)
steps += 1
if steps >= self.MAX_EXECUTION_STEPS:
logger.warning(
f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
)
def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
self.state.update(initial_inputs)
self.state["query"] = query
self.state["chat_history"] = str(self.agent.chat_history)
def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
return {
"node_id": node.id,
"node_type": node.type.value,
"started_at": datetime.now(timezone.utc),
"completed_at": None,
"status": ExecutionStatus.RUNNING.value,
"error": None,
"state_snapshot": {},
}
def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
node = self.graph.get_node_by_id(current_node_id)
edges = self.graph.get_outgoing_edges(current_node_id)
if not edges:
return None
if node and node.type == NodeType.CONDITION and self._condition_result:
target_handle = self._condition_result
self._condition_result = None
for edge in edges:
if edge.source_handle == target_handle:
return edge.target_id
return None
return edges[0].target_id
def _execute_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
logger.info(f"Executing node {node.id} ({node.type.value})")
node_handlers = {
NodeType.START: self._execute_start_node,
NodeType.NOTE: self._execute_note_node,
NodeType.AGENT: self._execute_agent_node,
NodeType.STATE: self._execute_state_node,
NodeType.CONDITION: self._execute_condition_node,
NodeType.END: self._execute_end_node,
}
handler = node_handlers.get(node.type)
if handler:
yield from handler(node)
def _execute_start_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_note_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
yield from ()
def _execute_agent_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
from application.core.model_utils import (
get_api_key_for_provider,
get_model_capabilities,
get_provider_from_model_id,
)
node_config = AgentNodeConfig(**node.config.get("config", node.config))
if node_config.sources:
self._retrieve_node_sources(node_config)
if node_config.prompt_template:
formatted_prompt = self._format_template(node_config.prompt_template)
else:
formatted_prompt = self.state.get("query", "")
node_json_schema = self._normalize_node_json_schema(
node_config.json_schema, node.title
)
node_model_id = node_config.model_id or self.agent.model_id
# Inherit BYOM scope from parent agent so owner-stored BYOM
# resolves on shared workflows.
node_user_id = getattr(self.agent, "model_user_id", None) or (
self.agent.decoded_token.get("sub")
if isinstance(self.agent.decoded_token, dict)
else None
)
node_llm_name = (
node_config.llm_name
or get_provider_from_model_id(
node_model_id or "", user_id=node_user_id
)
or self.agent.llm_name
)
node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key
if node_json_schema and node_model_id:
model_capabilities = get_model_capabilities(
node_model_id, user_id=node_user_id
)
if model_capabilities and not model_capabilities.get(
"supports_structured_output", False
):
raise ValueError(
f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
)
factory_kwargs = {
"agent_type": node_config.agent_type,
"endpoint": self.agent.endpoint,
"llm_name": node_llm_name,
"model_id": node_model_id,
"model_user_id": getattr(self.agent, "model_user_id", None),
"api_key": node_api_key,
"tool_ids": node_config.tools,
"prompt": node_config.system_prompt,
"chat_history": self.agent.chat_history,
"decoded_token": self.agent.decoded_token,
"json_schema": node_json_schema,
}
# Agentic/research agents need retriever_config for on-demand search
if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
factory_kwargs["retriever_config"] = {
"source": {"active_docs": node_config.sources} if node_config.sources else {},
"retriever_name": node_config.retriever or "classic",
"chunks": int(node_config.chunks) if node_config.chunks else 2,
"model_id": node_model_id,
"llm_name": node_llm_name,
"api_key": node_api_key,
"decoded_token": self.agent.decoded_token,
}
node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)
full_response_parts: List[str] = []
structured_response_parts: List[str] = []
has_structured_response = False
first_chunk = True
for event in node_agent.gen(formatted_prompt):
if "answer" in event:
chunk = str(event["answer"])
full_response_parts.append(chunk)
if event.get("structured"):
has_structured_response = True
structured_response_parts.append(chunk)
if node_config.stream_to_user:
if first_chunk and hasattr(self, "_has_streamed"):
yield {"answer": "\n\n"}
first_chunk = False
yield event
if node_config.stream_to_user:
self._has_streamed = True
full_response = "".join(full_response_parts).strip()
output_value: Any = full_response
if has_structured_response:
structured_response = "".join(structured_response_parts).strip()
response_to_parse = structured_response or full_response
parsed_success, parsed_structured = self._parse_structured_output(
response_to_parse
)
output_value = parsed_structured if parsed_success else response_to_parse
if node_json_schema:
self._validate_structured_output(node_json_schema, output_value)
elif node_json_schema:
parsed_success, parsed_structured = self._parse_structured_output(
full_response
)
if not parsed_success:
raise ValueError(
"Structured output was expected but response was not valid JSON"
)
output_value = parsed_structured
self._validate_structured_output(node_json_schema, output_value)
default_output_key = f"node_{node.id}_output"
self.state[default_output_key] = output_value
if node_config.output_variable:
self.state[node_config.output_variable] = output_value
def _execute_state_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
for op in config.get("operations", []):
expression = op.get("expression", "")
target_variable = op.get("target_variable", "")
if expression and target_variable:
self.state[target_variable] = evaluate_cel(expression, self.state)
yield from ()
def _execute_condition_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = ConditionNodeConfig(**node.config.get("config", node.config))
matched_handle = None
for case in config.cases:
if not case.expression.strip():
continue
try:
if evaluate_cel(case.expression, self.state):
matched_handle = case.source_handle
break
except CelEvaluationError:
continue
self._condition_result = matched_handle or "else"
yield from ()
def _execute_end_node(
self, node: WorkflowNode
) -> Generator[Dict[str, str], None, None]:
config = node.config.get("config", node.config)
output_template = str(config.get("output_template", ""))
if output_template:
formatted_output = self._format_template(output_template)
yield {"answer": formatted_output}
def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
normalized_response = raw_response.strip()
if not normalized_response:
return False, None
try:
return True, json.loads(normalized_response)
except json.JSONDecodeError:
logger.warning(
"Workflow agent returned structured output that was not valid JSON"
)
return False, None
def _normalize_node_json_schema(
self, schema: Optional[Dict[str, Any]], node_title: str
) -> Optional[Dict[str, Any]]:
if schema is None:
return None
try:
return normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(
f'Invalid JSON schema for node "{node_title}": {exc}'
) from exc
def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
if jsonschema is None:
logger.warning(
"jsonschema package is not available, skipping structured output validation"
)
return
try:
normalized_schema = normalize_json_schema_payload(schema)
except JsonSchemaValidationError as exc:
raise ValueError(f"Invalid JSON schema: {exc}") from exc
try:
jsonschema.validate(instance=output_value, schema=normalized_schema)
except jsonschema.exceptions.ValidationError as exc:
raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
except jsonschema.exceptions.SchemaError as exc:
raise ValueError(f"Invalid JSON schema: {exc.message}") from exc
def _format_template(self, template: str) -> str:
context = self._build_template_context()
try:
return self._template_engine.render(template, context)
except TemplateRenderError as e:
logger.warning(
"Workflow template rendering failed, using raw template: %s", str(e)
)
return template
def _build_template_context(self) -> Dict[str, Any]:
docs, docs_together = self._get_source_template_data()
passthrough_data = (
self.state.get("passthrough")
if isinstance(self.state.get("passthrough"), dict)
else None
)
tools_data = (
self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
)
context = self._namespace_manager.build_context(
user_id=getattr(self.agent, "user", None),
request_id=getattr(self.agent, "request_id", None),
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
)
agent_context: Dict[str, Any] = {}
for key, value in self.state.items():
if not isinstance(key, str):
continue
normalized_key = key.strip()
if not normalized_key:
continue
agent_context[normalized_key] = value
context["agent"] = agent_context
# Keep legacy top-level variables working while namespaced variables are adopted.
for key, value in agent_context.items():
if key in TEMPLATE_RESERVED_NAMESPACES:
context[f"agent_{key}"] = value
continue
if key not in context:
context[key] = value
return context
def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
docs = getattr(self.agent, "retrieved_docs", None)
if not isinstance(docs, list) or len(docs) == 0:
return None, None
docs_together_parts: List[str] = []
for doc in docs:
if not isinstance(doc, dict):
continue
text = doc.get("text")
if not isinstance(text, str):
continue
filename = doc.get("filename") or doc.get("title") or doc.get("source")
if isinstance(filename, str) and filename.strip():
docs_together_parts.append(f"{filename}\n{text}")
else:
docs_together_parts.append(text)
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
return docs, docs_together
def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
"""Retrieve documents from the node's sources for template resolution."""
from application.retriever.retriever_creator import RetrieverCreator
query = self.state.get("query", "")
if not query:
return
try:
retriever = RetrieverCreator.create_retriever(
node_config.retriever or "classic",
source={"active_docs": node_config.sources},
chat_history=[],
prompt="",
chunks=int(node_config.chunks) if node_config.chunks else 2,
decoded_token=self.agent.decoded_token,
)
docs = retriever.search(query)
if docs:
self.agent.retrieved_docs = docs
except Exception:
logger.exception("Failed to retrieve docs for workflow node")
def get_execution_summary(self) -> List[NodeExecutionLog]:
return [
NodeExecutionLog(
node_id=log["node_id"],
node_type=log["node_type"],
status=ExecutionStatus(log["status"]),
started_at=log["started_at"],
completed_at=log.get("completed_at"),
error=log.get("error"),
state_snapshot=log.get("state_snapshot", {}),
)
for log in self.execution_log
]

View File

@@ -1,52 +0,0 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
# alembic -c application/alembic.ini upgrade head
[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os
# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@@ -1,82 +0,0 @@
"""Alembic environment for the DocsGPT user-data Postgres database.
The URL is pulled from ``application.core.settings`` rather than
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
running app and ``alembic`` CLI invocations.
"""
import sys
from logging.config import fileConfig
from pathlib import Path
# Make the project root importable regardless of cwd. env.py lives at
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(_PROJECT_ROOT))
from alembic import context # noqa: E402
from sqlalchemy import engine_from_config, pool # noqa: E402
from application.core.settings import settings # noqa: E402
from application.storage.db.models import metadata as target_metadata # noqa: E402
config = context.config
# Populate the runtime URL from settings.
if settings.POSTGRES_URI:
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
if config.config_file_name is not None:
fileConfig(config.config_file_name)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
url = config.get_main_option("sqlalchemy.url")
if not url:
raise RuntimeError(
"POSTGRES_URI is not configured. Set it in your .env to a "
"psycopg3 URI such as "
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
)
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode against a live connection."""
if not config.get_main_option("sqlalchemy.url"):
raise RuntimeError(
"POSTGRES_URI is not configured. Set it in your .env to a "
"psycopg3 URI such as "
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -1,26 +0,0 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -1,927 +0,0 @@
"""0001 initial schema — consolidated Phase-1..3 baseline.
Revision ID: 0001_initial
Revises:
Create Date: 2026-04-13
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0001_initial"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ------------------------------------------------------------------
# Extensions
# ------------------------------------------------------------------
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
# ------------------------------------------------------------------
# Trigger functions
# ------------------------------------------------------------------
op.execute(
"""
CREATE FUNCTION set_updated_at() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
NEW.updated_at = now();
RETURN NEW;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION ensure_user_exists() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
IF NEW.user_id IS NOT NULL THEN
INSERT INTO users (user_id) VALUES (NEW.user_id)
ON CONFLICT (user_id) DO NOTHING;
END IF;
RETURN NEW;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
UPDATE conversation_messages
SET attachments = array_remove(attachments, OLD.id)
WHERE OLD.id = ANY(attachments);
RETURN OLD;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
UPDATE agents
SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
WHERE OLD.id = ANY(extra_source_ids);
RETURN OLD;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
LANGUAGE plpgsql AS $$
DECLARE
agent_id_text text := OLD.id::text;
BEGIN
UPDATE users
SET agent_preferences = jsonb_set(
jsonb_set(
agent_preferences,
'{pinned}',
COALESCE((
SELECT jsonb_agg(e)
FROM jsonb_array_elements(
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
) e
WHERE (e #>> '{}') <> agent_id_text
), '[]'::jsonb)
),
'{shared_with_me}',
COALESCE((
SELECT jsonb_agg(e)
FROM jsonb_array_elements(
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
) e
WHERE (e #>> '{}') <> agent_id_text
), '[]'::jsonb)
)
WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
RETURN OLD;
END;
$$;
"""
)
op.execute(
"""
CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
LANGUAGE plpgsql AS $$
BEGIN
IF NEW.user_id IS NULL THEN
SELECT user_id INTO NEW.user_id
FROM conversations
WHERE id = NEW.conversation_id;
END IF;
RETURN NEW;
END;
$$;
"""
)
# ------------------------------------------------------------------
# Tables
# ------------------------------------------------------------------
op.execute(
"""
CREATE TABLE users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL UNIQUE,
agent_preferences JSONB NOT NULL
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE prompts (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE user_tools (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
custom_name TEXT,
display_name TEXT,
description TEXT,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
actions JSONB NOT NULL DEFAULT '[]'::jsonb,
status BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE token_usage (
id BIGSERIAL PRIMARY KEY,
user_id TEXT,
api_key TEXT,
agent_id UUID,
prompt_tokens INTEGER NOT NULL DEFAULT 0,
generated_tokens INTEGER NOT NULL DEFAULT 0,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
mongo_id TEXT
);
"""
)
op.execute(
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
"CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
)
op.execute(
"""
CREATE TABLE user_logs (
id BIGSERIAL PRIMARY KEY,
user_id TEXT,
endpoint TEXT,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
data JSONB,
mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE stack_logs (
id BIGSERIAL PRIMARY KEY,
activity_id TEXT NOT NULL,
endpoint TEXT,
level TEXT,
user_id TEXT,
api_key TEXT,
query TEXT,
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE agent_folders (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE sources (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
language TEXT,
date TIMESTAMPTZ NOT NULL DEFAULT now(),
model TEXT,
type TEXT,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
retriever TEXT,
sync_frequency TEXT,
tokens TEXT,
file_path TEXT,
remote_data JSONB,
directory_structure JSONB,
file_name_map JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE agents (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
agent_type TEXT,
status TEXT NOT NULL,
key CITEXT UNIQUE,
image TEXT,
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
chunks INTEGER,
retriever TEXT,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
json_schema JSONB,
models JSONB,
default_model_id TEXT,
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
workflow_id UUID,
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
token_limit INTEGER,
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
request_limit INTEGER,
allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
shared BOOLEAN NOT NULL DEFAULT false,
shared_token CITEXT UNIQUE,
shared_metadata JSONB,
incoming_webhook_token CITEXT UNIQUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
last_used_at TIMESTAMPTZ,
legacy_mongo_id TEXT
);
"""
)
op.execute(
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
"FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
)
op.execute(
"""
CREATE TABLE attachments (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
filename TEXT NOT NULL,
upload_path TEXT NOT NULL,
mime_type TEXT,
size BIGINT,
content TEXT,
token_count INTEGER,
openai_file_id TEXT,
google_file_uri TEXT,
metadata JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE memories (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
path TEXT NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE todos (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
todo_id INTEGER,
title TEXT NOT NULL,
completed BOOLEAN NOT NULL DEFAULT false,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE notes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
title TEXT NOT NULL,
content TEXT NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE connector_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
provider TEXT NOT NULL,
server_url TEXT,
session_token TEXT UNIQUE,
user_email TEXT,
status TEXT,
token_info JSONB,
session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
expires_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
legacy_mongo_id TEXT
);
"""
)
op.execute(
"""
CREATE TABLE conversations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
name TEXT,
api_key TEXT,
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
shared_token TEXT,
date TIMESTAMPTZ NOT NULL DEFAULT now(),
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
compression_metadata JSONB,
legacy_mongo_id TEXT,
CONSTRAINT conversations_api_key_nonempty_chk
CHECK (api_key IS NULL OR api_key <> '')
);
"""
)
op.execute(
"""
CREATE TABLE conversation_messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
position INTEGER NOT NULL,
prompt TEXT,
response TEXT,
thought TEXT,
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
model_id TEXT,
message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
feedback JSONB,
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
user_id TEXT NOT NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"""
CREATE TABLE shared_conversations (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
is_promptable BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
uuid UUID NOT NULL,
first_n_queries INTEGER NOT NULL DEFAULT 0,
api_key TEXT,
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
chunks INTEGER,
CONSTRAINT shared_conversations_api_key_nonempty_chk
CHECK (api_key IS NULL OR api_key <> '')
);
"""
)
op.execute(
"""
CREATE TABLE pending_tool_state (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
messages JSONB NOT NULL,
pending_tool_calls JSONB NOT NULL,
tools_dict JSONB NOT NULL,
tool_schemas JSONB NOT NULL,
agent_config JSONB NOT NULL,
client_tools JSONB,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
expires_at TIMESTAMPTZ NOT NULL
);
"""
)
op.execute(
"""
CREATE TABLE workflows (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
current_graph_version INTEGER NOT NULL DEFAULT 1,
legacy_mongo_id TEXT
);
"""
)
# Backfill the agents.workflow_id FK now that workflows exists.
# The column was created without a FK (forward reference to a table
# that hadn't been declared yet); add the constraint here so workflow
# deletion still cascades through to agent unset.
op.execute(
"ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
"FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
)
op.execute(
"""
CREATE TABLE workflow_nodes (
id UUID DEFAULT gen_random_uuid() NOT NULL,
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
graph_version INTEGER NOT NULL,
node_type TEXT NOT NULL,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
node_id TEXT NOT NULL,
title TEXT,
description TEXT,
position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
legacy_mongo_id TEXT,
PRIMARY KEY (id),
CONSTRAINT workflow_nodes_id_wf_ver_key
UNIQUE (id, workflow_id, graph_version)
);
"""
)
op.execute(
"""
CREATE TABLE workflow_edges (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
graph_version INTEGER NOT NULL,
from_node_id UUID NOT NULL,
to_node_id UUID NOT NULL,
config JSONB NOT NULL DEFAULT '{}'::jsonb,
edge_id TEXT NOT NULL,
source_handle TEXT,
target_handle TEXT,
CONSTRAINT workflow_edges_from_node_fk
FOREIGN KEY (from_node_id, workflow_id, graph_version)
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
CONSTRAINT workflow_edges_to_node_fk
FOREIGN KEY (to_node_id, workflow_id, graph_version)
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
);
"""
)
op.execute(
"""
CREATE TABLE workflow_runs (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
user_id TEXT NOT NULL,
status TEXT NOT NULL,
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
ended_at TIMESTAMPTZ,
result JSONB,
inputs JSONB,
steps JSONB NOT NULL DEFAULT '[]'::jsonb,
legacy_mongo_id TEXT,
CONSTRAINT workflow_runs_status_chk
CHECK (status IN ('pending', 'running', 'completed', 'failed'))
);
"""
)
# ------------------------------------------------------------------
# Indexes
# ------------------------------------------------------------------
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
op.execute(
"CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
"ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
op.execute(
"CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
"ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
# MCP and OAuth connectors share the ``provider`` slot, so the
# dedup key is ``(user_id, server_url, provider)``: MCP rows
# differentiate by server_url (one per MCP server), OAuth rows
# have server_url = NULL and differentiate by provider alone.
# COALESCE lets NULL server_url participate in the constraint.
"CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
"ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
)
op.execute(
"CREATE INDEX connector_sessions_expiry_idx "
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
)
op.execute(
"CREATE INDEX connector_sessions_server_url_idx "
"ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
"ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
"ON conversation_messages (conversation_id, position);"
)
op.execute(
"CREATE INDEX conversation_messages_user_ts_idx "
"ON conversation_messages (user_id, timestamp DESC);"
)
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
op.execute(
"CREATE UNIQUE INDEX conversations_shared_token_uidx "
"ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
)
op.execute(
"CREATE INDEX conversations_api_key_date_idx "
"ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
"ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX memories_user_tool_path_uidx "
"ON memories (user_id, tool_id, path);"
)
op.execute(
"CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
"ON memories (user_id, path) WHERE tool_id IS NULL;"
)
op.execute(
"CREATE INDEX memories_path_prefix_idx "
"ON memories (user_id, tool_id, path text_pattern_ops);"
)
op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
op.execute(
"CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
"ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
"ON pending_tool_state (conversation_id, user_id);"
)
op.execute(
"CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
)
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
op.execute(
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
op.execute(
"CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
)
op.execute(
"CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
)
op.execute(
"CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
"ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
)
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
op.execute(
"CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
"ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
"ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
"ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
op.execute(
"CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
"ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
op.execute(
"CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
"ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute(
"CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
"ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
)
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
op.execute(
"CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
"ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
op.execute(
"CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
"ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
op.execute(
"CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
"ON workflow_edges (workflow_id, graph_version, edge_id);"
)
op.execute(
"CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
"ON workflow_nodes (workflow_id, graph_version, node_id);"
)
op.execute(
"CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
"ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
op.execute(
"CREATE INDEX workflow_runs_status_started_idx "
"ON workflow_runs (status, started_at DESC);"
)
op.execute(
"CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
"ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
op.execute(
"CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
"ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
)
# ------------------------------------------------------------------
# user_id foreign keys (deferrable so backfills can stage rows)
# ------------------------------------------------------------------
user_fk_tables = (
"agent_folders",
"agents",
"attachments",
"connector_sessions",
"conversation_messages",
"conversations",
"memories",
"notes",
"pending_tool_state",
"prompts",
"shared_conversations",
"sources",
"stack_logs",
"todos",
"token_usage",
"user_logs",
"user_tools",
"workflow_runs",
"workflows",
)
for table in user_fk_tables:
op.execute(
f"ALTER TABLE {table} "
f"ADD CONSTRAINT {table}_user_id_fk "
f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
)
# ------------------------------------------------------------------
# Triggers
# ------------------------------------------------------------------
updated_at_tables = (
"agent_folders",
"agents",
"conversation_messages",
"conversations",
"memories",
"notes",
"prompts",
"sources",
"todos",
"user_tools",
"users",
"workflows",
)
for table in updated_at_tables:
op.execute(
f"CREATE TRIGGER {table}_set_updated_at "
f"BEFORE UPDATE ON {table} "
f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
f"EXECUTE FUNCTION set_updated_at();"
)
ensure_user_tables = (
"agent_folders",
"agents",
"attachments",
"connector_sessions",
"conversation_messages",
"conversations",
"memories",
"notes",
"pending_tool_state",
"prompts",
"shared_conversations",
"sources",
"stack_logs",
"todos",
"token_usage",
"user_logs",
"user_tools",
"workflow_runs",
"workflows",
)
for table in ensure_user_tables:
op.execute(
f"CREATE TRIGGER {table}_ensure_user "
f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
)
op.execute(
"CREATE TRIGGER conversation_messages_fill_user "
"BEFORE INSERT ON conversation_messages "
"FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
)
op.execute(
"CREATE TRIGGER attachments_cleanup_message_refs "
"AFTER DELETE ON attachments "
"FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
)
op.execute(
"CREATE TRIGGER agents_cleanup_user_prefs "
"AFTER DELETE ON agents "
"FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
)
op.execute(
"CREATE TRIGGER sources_cleanup_agent_extra_refs "
"AFTER DELETE ON sources "
"FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
)
# ------------------------------------------------------------------
# Seed sentinel __system__ user (system/template sources attribute here)
# ------------------------------------------------------------------
op.execute(
"INSERT INTO users (user_id) VALUES ('__system__') "
"ON CONFLICT (user_id) DO NOTHING;"
)
def downgrade() -> None:
# Nuclear downgrade: drop everything this migration created. The
# ordering drops FK-bearing children before parents; CASCADE would
# also work but explicit ordering is easier to reason about in code
# review.
tables_in_drop_order = (
"workflow_edges",
"workflow_runs",
"workflow_nodes",
"workflows",
"pending_tool_state",
"shared_conversations",
"conversation_messages",
"conversations",
"connector_sessions",
"notes",
"todos",
"memories",
"attachments",
"agents",
"sources",
"agent_folders",
"stack_logs",
"user_logs",
"token_usage",
"user_tools",
"prompts",
"users",
)
for table in tables_in_drop_order:
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
for fn in (
"conversation_messages_fill_user_id",
"cleanup_user_agent_prefs",
"cleanup_agent_extra_source_refs",
"cleanup_message_attachment_refs",
"ensure_user_exists",
"set_updated_at",
):
op.execute(f"DROP FUNCTION IF EXISTS {fn}();")

View File

@@ -1,37 +0,0 @@
"""0002 app_metadata — singleton key/value table for instance-wide state.
Used by the startup version-check client to persist the anonymous
instance UUID and a one-shot "notice shown" flag. Both values are tiny
plain-text strings; this is a deliberate generic-config table rather
than dedicated columns so future one-off settings (telemetry opt-in
timestamps, feature-flag overrides, etc.) don't each need their own
migration.
Revision ID: 0002_app_metadata
Revises: 0001_initial
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0002_app_metadata"
down_revision: Union[str, None] = "0001_initial"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE app_metadata (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
);
"""
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS app_metadata;")

View File

@@ -1,65 +0,0 @@
"""0003 user_custom_models — per-user OpenAI-compatible model registrations.
Revision ID: 0003_user_custom_models
Revises: 0002_app_metadata
"""
from typing import Sequence, Union
from alembic import op
revision: str = "0003_user_custom_models"
down_revision: Union[str, None] = "0002_app_metadata"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.execute(
"""
CREATE TABLE user_custom_models (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id TEXT NOT NULL,
upstream_model_id TEXT NOT NULL,
display_name TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
base_url TEXT NOT NULL,
api_key_encrypted TEXT NOT NULL,
capabilities JSONB NOT NULL DEFAULT '{}'::jsonb,
enabled BOOLEAN NOT NULL DEFAULT true,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
"""
)
op.execute(
"CREATE INDEX user_custom_models_user_id_idx "
"ON user_custom_models (user_id);"
)
# Mirror the project-wide invariants set up in 0001_initial:
# * user_id FK with ON DELETE RESTRICT (deferrable),
# * ensure_user_exists() trigger so the parent users row autocreates,
# * set_updated_at() trigger.
op.execute(
"ALTER TABLE user_custom_models "
"ADD CONSTRAINT user_custom_models_user_id_fk "
"FOREIGN KEY (user_id) REFERENCES users(user_id) "
"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
)
op.execute(
"CREATE TRIGGER user_custom_models_ensure_user "
"BEFORE INSERT OR UPDATE OF user_id ON user_custom_models "
"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
)
op.execute(
"CREATE TRIGGER user_custom_models_set_updated_at "
"BEFORE UPDATE ON user_custom_models "
"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
"EXECUTE FUNCTION set_updated_at();"
)
def downgrade() -> None:
op.execute("DROP TABLE IF EXISTS user_custom_models;")

View File

@@ -1,7 +0,0 @@
from flask_restx import Api
api = Api(
version="1.0",
title="DocsGPT API",
description="API for DocsGPT",
)

View File

@@ -1,21 +0,0 @@
from flask import Blueprint
from application.api import api
from application.api.answer.routes.answer import AnswerResource
from application.api.answer.routes.base import answer_ns
from application.api.answer.routes.search import SearchResource
from application.api.answer.routes.stream import StreamResource
answer = Blueprint("answer", __name__)
api.add_namespace(answer_ns)
def init_answer_routes():
api.add_resource(StreamResource, "/stream")
api.add_resource(AnswerResource, "/api/answer")
api.add_resource(SearchResource, "/api/search")
init_answer_routes()

View File

@@ -0,0 +1,925 @@
import asyncio
import datetime
import json
import logging
import os
import traceback
from bson.dbref import DBRef
from bson.objectid import ObjectId
from flask import Blueprint, make_response, request, Response
from flask_restx import fields, Namespace, Resource
from application.agents.agent_creator import AgentCreator
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.error import bad_request
from application.extensions import api
from application.llm.llm_creator import LLMCreator
from application.retriever.retriever_creator import RetrieverCreator
from application.utils import check_required_fields, limit_chat_history
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
agents_collection = db["agents"]
user_logs_collection = db["user_logs"]
attachments_collection = db["attachments"]
answer = Blueprint("answer", __name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
api.add_namespace(answer_ns)
gpt_model = ""
# to have some kind of default behaviour
if settings.LLM_PROVIDER == "openai":
gpt_model = "gpt-4o-mini"
elif settings.LLM_PROVIDER == "anthropic":
gpt_model = "claude-2"
elif settings.LLM_PROVIDER == "groq":
gpt_model = "llama3-8b-8192"
elif settings.LLM_PROVIDER == "novita":
gpt_model = "deepseek/deepseek-r1"
if settings.LLM_NAME: # in case there is particular model name configured
gpt_model = settings.LLM_NAME
# load the prompts
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_reduce_prompt.txt"), "r") as f:
chat_reduce_template = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r") as f:
chat_combine_creative = f.read()
with open(os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r") as f:
chat_combine_strict = f.read()
api_key_set = settings.API_KEY is not None
embeddings_key_set = settings.EMBEDDINGS_KEY is not None
async def async_generate(chain, question, chat_history):
result = await chain.arun({"question": question, "chat_history": chat_history})
return result
def run_async_chain(chain, question, chat_history):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = {}
try:
answer = loop.run_until_complete(async_generate(chain, question, chat_history))
finally:
loop.close()
result["answer"] = answer
return result
def get_agent_key(agent_id, user_id):
if not agent_id:
return None, False, None
try:
agent = agents_collection.find_one({"_id": ObjectId(agent_id)})
if agent is None:
raise Exception("Agent not found", 404)
is_owner = agent.get("user") == user_id
if is_owner:
agents_collection.update_one(
{"_id": ObjectId(agent_id)},
{"$set": {"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)}},
)
return str(agent["key"]), False, None
is_shared_with_user = agent.get(
"shared_publicly", False
) or user_id in agent.get("shared_with", [])
if is_shared_with_user:
return str(agent["key"]), True, agent.get("shared_token")
raise Exception("Unauthorized access to the agent", 403)
except Exception as e:
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
raise
def get_data_from_api_key(api_key):
data = agents_collection.find_one({"key": api_key})
if not data:
raise Exception("Invalid API Key, please generate a new key", 401)
source = data.get("source")
if isinstance(source, DBRef):
source_doc = db.dereference(source)
data["source"] = str(source_doc["_id"])
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
else:
data["source"] = {}
return data
def get_retriever(source_id: str):
doc = sources_collection.find_one({"_id": ObjectId(source_id)})
if doc is None:
raise Exception("Source document does not exist", 404)
retriever_name = None if "retriever" not in doc else doc["retriever"]
return retriever_name
def is_azure_configured():
return (
settings.OPENAI_API_BASE
and settings.OPENAI_API_VERSION
and settings.AZURE_DEPLOYMENT_NAME
)
def save_conversation(
conversation_id,
question,
response,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
index=None,
api_key=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
attachment_ids=None,
):
current_time = datetime.datetime.now(datetime.timezone.utc)
if conversation_id is not None and index is not None:
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{
"$set": {
f"queries.{index}.prompt": question,
f"queries.{index}.response": response,
f"queries.{index}.thought": thought,
f"queries.{index}.sources": source_log_docs,
f"queries.{index}.tool_calls": tool_calls,
f"queries.{index}.timestamp": current_time,
f"queries.{index}.attachments": attachment_ids,
}
},
)
##remove following queries from the array
conversations_collection.update_one(
{"_id": ObjectId(conversation_id), f"queries.{index}": {"$exists": True}},
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
)
elif conversation_id is not None and conversation_id != "None":
conversations_collection.update_one(
{"_id": ObjectId(conversation_id)},
{
"$push": {
"queries": {
"prompt": question,
"response": response,
"thought": thought,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
}
},
)
else:
# create new conversation
# generate summary
messages_summary = [
{
"role": "assistant",
"content": "Summarise following conversation in no more than 3 "
"words, respond ONLY with the summary, use the same "
"language as the system",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"system \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
completion = llm.gen(model=gpt_model, messages=messages_summary, max_tokens=30)
conversation_data = {
"user": decoded_token.get("sub"),
"date": datetime.datetime.utcnow(),
"name": completion,
"queries": [
{
"prompt": question,
"response": response,
"thought": thought,
"sources": source_log_docs,
"tool_calls": tool_calls,
"timestamp": current_time,
"attachments": attachment_ids,
}
],
}
if api_key:
if agent_id:
conversation_data["agent_id"] = agent_id
if is_shared_usage:
conversation_data["is_shared_usage"] = is_shared_usage
conversation_data["shared_token"] = shared_token
api_key_doc = agents_collection.find_one({"key": api_key})
if api_key_doc:
conversation_data["api_key"] = api_key_doc["key"]
conversation_id = conversations_collection.insert_one(
conversation_data
).inserted_id
return conversation_id
def get_prompt(prompt_id):
if prompt_id == "default":
prompt = chat_combine_template
elif prompt_id == "creative":
prompt = chat_combine_creative
elif prompt_id == "strict":
prompt = chat_combine_strict
else:
prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)})["content"]
return prompt
def complete_stream(
question,
agent,
retriever,
conversation_id,
user_api_key,
decoded_token,
isNoneDoc=False,
index=None,
should_save_conversation=True,
attachment_ids=None,
agent_id=None,
is_shared_usage=False,
shared_token=None,
):
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
answer = agent.gen(query=question, retriever=retriever)
for line in answer:
if "answer" in line:
response_full += str(line["answer"])
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if len(truncated_sources) > 0:
data = json.dumps({"type": "source", "source": truncated_sources})
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
data = json.dumps(line)
yield f"data: {data}\n\n"
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
if should_save_conversation:
conversation_id = save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
index,
api_key=user_api_key,
attachment_ids=attachment_ids,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
else:
conversation_id = None
# send data.type = "end" to indicate that the stream has ended as json
data = json.dumps({"type": "id", "id": str(conversation_id)})
yield f"data: {data}\n\n"
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
yield f"data: {data}\n\n"
return
@answer_ns.route("/stream")
class Stream(Resource):
stream_model = api.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String, required=False, description="Chat history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
if "index" in data:
required_fields = ["question", "conversation_id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
save_conv = data.get("save_conversation", True)
try:
question = data["question"]
history = limit_chat_history(
json.loads(data.get("history", "[]")), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
attachment_ids = data.get("attachments", [])
index = data.get("index", None)
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_id = data.get("agent_id", None)
agent_type = settings.AGENT_NAME
decoded_token = getattr(request, "decoded_token", None)
user_sub = decoded_token.get("sub") if decoded_token else None
agent_key, is_shared_usage, shared_token = get_agent_key(agent_id, user_sub)
if agent_key:
data.update({"api_key": agent_key})
else:
agent_id = None
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
if is_shared_usage:
decoded_token = request.decoded_token
else:
decoded_token = {"sub": data_key.get("user")}
is_shared_usage = False
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
attachments = get_attachments_content(
attachment_ids, decoded_token.get("sub")
)
logger.info(
f"/stream - request_data: {data}, source: {source}, attachments: {len(attachments)}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
prompt = get_prompt(prompt_id)
if "isNoneDoc" in data and data["isNoneDoc"] is True:
chunks = 0
agent = AgentCreator.create_agent(
agent_type,
endpoint="stream",
llm_name=settings.LLM_PROVIDER,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
attachments=attachments,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
return Response(
complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=index,
should_save_conversation=save_conv,
attachment_ids=attachment_ids,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
),
mimetype="text/event-stream",
)
except ValueError:
message = "Malformed request body"
logger.error(f"/stream - error: {message}")
return Response(
error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
status_code = 400
return Response(
error_stream_generate("Unknown error occurred"),
status=status_code,
mimetype="text/event-stream",
)
def error_stream_generate(err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"
@answer_ns.route("/api/answer")
class Answer(Resource):
answer_model = api.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="The question to answer"
),
"history": fields.List(
fields.String, required=False, description="Conversation history"
),
"conversation_id": fields.String(
required=False, description="Conversation ID"
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"token_limit": fields.Integer(required=False, description="Token limit"),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide an answer based on the question and retriever")
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
history = limit_chat_history(
json.loads(data.get("history", "[]")), gpt_model=gpt_model
)
conversation_id = data.get("conversation_id")
prompt_id = data.get("prompt_id", "default")
attachment_ids = data.get("attachments", [])
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
agent_type = settings.AGENT_NAME
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
prompt_id = data_key.get("prompt_id", "default")
source = {"active_docs": data_key.get("source")}
retriever_name = data_key.get("retriever", retriever_name)
user_api_key = data["api_key"]
agent_type = data_key.get("agent_type", agent_type)
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
attachments = get_attachments_content(
attachment_ids, decoded_token.get("sub")
)
prompt = get_prompt(prompt_id)
logger.info(
f"/api/answer - request_data: {data}, source: {source}, attachments: {len(attachments)}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
agent = AgentCreator.create_agent(
agent_type,
endpoint="api/answer",
llm_name=settings.LLM_PROVIDER,
gpt_model=gpt_model,
api_key=settings.API_KEY,
user_api_key=user_api_key,
prompt=prompt,
chat_history=history,
decoded_token=decoded_token,
attachments=attachments,
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=history,
prompt=prompt,
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
response_full = ""
source_log_docs = []
tool_calls = []
stream_ended = False
thought = ""
for line in complete_stream(
question=question,
agent=agent,
retriever=retriever,
conversation_id=conversation_id,
user_api_key=user_api_key,
decoded_token=decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=False,
attachment_ids=attachment_ids,
):
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return bad_request(500, event["error"])
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return bad_request(500, "Stream ended unexpectedly.")
if data.get("isNoneDoc"):
for doc in source_log_docs:
doc["source"] = "None"
llm = LLMCreator.create_llm(
settings.LLM_PROVIDER,
api_key=settings.API_KEY,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = str(
save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
decoded_token,
api_key=user_api_key,
attachment_ids=attachment_ids,
)
)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"response": response_full,
"sources": source_log_docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(result, 200)
@answer_ns.route("/api/search")
class Search(Resource):
search_model = api.model(
"SearchModel",
{
"question": fields.String(
required=True, description="The question to search"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"api_key": fields.String(
required=False, description="API key for authentication"
),
"active_docs": fields.String(
required=False, description="Active documents for retrieval"
),
"retriever": fields.String(required=False, description="Retriever type"),
"token_limit": fields.Integer(
required=False, description="Limit for tokens"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
},
)
@api.expect(search_model)
@api.doc(
description="Search for relevant documents based on the question and retriever"
)
def post(self):
data = request.get_json()
required_fields = ["question"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
question = data["question"]
chunks = int(data.get("chunks", 2))
token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY)
retriever_name = data.get("retriever", "classic")
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key.get("chunks", 2))
source = {"active_docs": data_key.get("source")}
user_api_key = data["api_key"]
decoded_token = {"sub": data_key.get("user")}
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
user_api_key = None
decoded_token = request.decoded_token
else:
source = {}
user_api_key = None
decoded_token = request.decoded_token
if not decoded_token:
return make_response({"error": "Unauthorized"}, 401)
logger.info(
f"/api/answer - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})},
)
retriever = RetrieverCreator.create_retriever(
retriever_name,
source=source,
chat_history=[],
prompt="default",
chunks=chunks,
token_limit=token_limit,
gpt_model=gpt_model,
user_api_key=user_api_key,
decoded_token=decoded_token,
)
docs = retriever.search(question)
retriever_params = retriever.get_params()
user_logs_collection.insert_one(
{
"action": "api_search",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"question": question,
"sources": docs,
"retriever_params": retriever_params,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
)
if data.get("isNoneDoc"):
for doc in docs:
doc["source"] = "None"
except Exception as e:
logger.error(
f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return bad_request(500, str(e))
return make_response(docs, 200)
def get_attachments_content(attachment_ids, user):
"""
Retrieve content from attachment documents based on their IDs.
Args:
attachment_ids (list): List of attachment document IDs
user (str): User identifier to verify ownership
Returns:
list: List of dictionaries containing attachment content and metadata
"""
if not attachment_ids:
return []
attachments = []
for attachment_id in attachment_ids:
try:
attachment_doc = attachments_collection.find_one(
{"_id": ObjectId(attachment_id), "user": user}
)
if attachment_doc:
attachments.append(attachment_doc)
except Exception as e:
logger.error(
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
)
return attachments

View File

@@ -1,153 +0,0 @@
import logging
import traceback
from flask import make_response, request
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/api/answer")
class AnswerResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
answer_model = answer_ns.model(
"AnswerModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
),
},
)
@api.expect(answer_model)
@api.doc(description="Provide a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
)
else:
# ---- Normal mode ----
agent = processor.build_agent(data.get("question", ""))
if not processor.decoded_token:
return make_response({"error": "Unauthorized"}, 401)
if error := self.check_usage(processor.agent_config):
return error
stream = self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=None,
should_save_conversation=data.get("save_conversation", True),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
)
stream_result = self.process_response_stream(stream)
if stream_result["error"]:
return make_response({"error": stream_result["error"]}, 400)
result = {
"conversation_id": stream_result["conversation_id"],
"answer": stream_result["answer"],
"sources": stream_result["sources"],
"tool_calls": stream_result["tool_calls"],
"thought": stream_result["thought"],
}
extra_info = stream_result.get("extra")
if extra_info:
result.update(extra_info)
except Exception as e:
logger.error(
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return make_response({"error": "An error occurred processing your request"}, 500)
return make_response(result, 200)

View File

@@ -1,673 +0,0 @@
import datetime
import json
import logging
from typing import Any, Dict, Generator, List, Optional
from flask import jsonify, make_response, Response
from flask_restx import Namespace
from application.api.answer.services.continuation_service import ContinuationService
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_default_model_id,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.error import sanitize_api_error
from application.llm.llm_creator import LLMCreator
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
answer_ns = Namespace("answer", description="Answer related operations", path="/")
class BaseAnswerResource:
"""Shared base class for answer endpoints"""
def __init__(self):
self.default_model_id = get_default_model_id()
self.conversation_service = ConversationService()
def validate_request(
self, data: Dict[str, Any], require_conversation_id: bool = False
) -> Optional[Response]:
"""Common request validation.
Continuation requests (``tool_actions`` present) require
``conversation_id`` but not ``question``.
"""
if data.get("tool_actions"):
# Continuation mode — question is not required
if missing := check_required_fields(data, ["conversation_id"]):
return missing
return None
required_fields = ["question"]
if require_conversation_id:
required_fields.append("conversation_id")
if missing_fields := check_required_fields(data, required_fields):
return missing_fields
return None
@staticmethod
def _prepare_tool_calls_for_logging(
tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
) -> List[Dict[str, Any]]:
if not tool_calls:
return []
prepared = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
prepared.append({"result": str(tool_call)[:max_chars]})
continue
item = dict(tool_call)
for key in ("result", "result_full"):
value = item.get(key)
if isinstance(value, str) and len(value) > max_chars:
item[key] = value[:max_chars]
prepared.append(item)
return prepared
def check_usage(self, agent_config: Dict) -> Optional[Response]:
"""Check if there is a usage limit and if it is exceeded
Args:
agent_config: The config dict of agent instance
Returns:
None or Response if either of limits exceeded.
"""
api_key = agent_config.get("user_api_key")
if not api_key:
return None
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key."}), 401
)
limited_token_mode_raw = agent.get("limited_token_mode", False)
limited_request_mode_raw = agent.get("limited_request_mode", False)
limited_token_mode = (
limited_token_mode_raw
if isinstance(limited_token_mode_raw, bool)
else limited_token_mode_raw == "True"
)
limited_request_mode = (
limited_request_mode_raw
if isinstance(limited_request_mode_raw, bool)
else limited_request_mode_raw == "True"
)
token_limit = int(
agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
)
request_limit = int(
agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
)
end_date = datetime.datetime.now(datetime.timezone.utc)
start_date = end_date - datetime.timedelta(hours=24)
if limited_token_mode or limited_request_mode:
with db_readonly() as conn:
token_repo = TokenUsageRepository(conn)
if limited_token_mode:
daily_token_usage = token_repo.sum_tokens_in_range(
start=start_date, end=end_date, api_key=api_key,
)
else:
daily_token_usage = 0
if limited_request_mode:
daily_request_usage = token_repo.count_in_range(
start=start_date, end=end_date, api_key=api_key,
)
else:
daily_request_usage = 0
else:
daily_token_usage = 0
daily_request_usage = 0
if not limited_token_mode and not limited_request_mode:
return None
token_exceeded = (
limited_token_mode and token_limit > 0 and daily_token_usage >= token_limit
)
request_exceeded = (
limited_request_mode
and request_limit > 0
and daily_request_usage >= request_limit
)
if token_exceeded or request_exceeded:
return make_response(
jsonify(
{
"success": False,
"message": "Exceeding usage limit, please try again later.",
}
),
429,
)
return None
def complete_stream(
self,
question: str,
agent: Any,
conversation_id: Optional[str],
user_api_key: Optional[str],
decoded_token: Dict[str, Any],
isNoneDoc: bool = False,
index: Optional[int] = None,
should_save_conversation: bool = True,
attachment_ids: Optional[List[str]] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
model_id: Optional[str] = None,
model_user_id: Optional[str] = None,
_continuation: Optional[Dict] = None,
) -> Generator[str, None, None]:
"""
Generator function that streams the complete conversation response.
Args:
question: The user's question
agent: The agent instance
retriever: The retriever instance
conversation_id: Existing conversation ID
user_api_key: User's API key if any
decoded_token: Decoded JWT token
isNoneDoc: Flag for document-less responses
index: Index of message to update
should_save_conversation: Whether to persist the conversation
attachment_ids: List of attachment IDs
agent_id: ID of agent used
is_shared_usage: Flag for shared agent usage
shared_token: Token for shared agent
model_id: Model ID used for the request
retrieved_docs: Pre-fetched documents for sources (optional)
Yields:
Server-sent event strings
"""
try:
response_full, thought, source_log_docs, tool_calls = "", "", [], []
is_structured = False
schema_info = None
structured_chunks = []
query_metadata = {}
paused = False
if _continuation:
gen_iter = agent.gen_continuation(
messages=_continuation["messages"],
tools_dict=_continuation["tools_dict"],
pending_tool_calls=_continuation["pending_tool_calls"],
tool_actions=_continuation["tool_actions"],
)
else:
gen_iter = agent.gen(query=question)
for line in gen_iter:
if "metadata" in line:
query_metadata.update(line["metadata"])
elif "answer" in line:
response_full += str(line["answer"])
if line.get("structured"):
is_structured = True
schema_info = line.get("schema")
structured_chunks.append(line["answer"])
else:
data = json.dumps({"type": "answer", "answer": line["answer"]})
yield f"data: {data}\n\n"
elif "sources" in line:
truncated_sources = []
source_log_docs = line["sources"]
for source in line["sources"]:
truncated_source = source.copy()
if "text" in truncated_source:
truncated_source["text"] = (
truncated_source["text"][:100].strip() + "..."
)
truncated_sources.append(truncated_source)
if truncated_sources:
data = json.dumps(
{"type": "source", "source": truncated_sources}
)
yield f"data: {data}\n\n"
elif "tool_calls" in line:
tool_calls = line["tool_calls"]
data = json.dumps({"type": "tool_calls", "tool_calls": tool_calls})
yield f"data: {data}\n\n"
elif "thought" in line:
thought += line["thought"]
data = json.dumps({"type": "thought", "thought": line["thought"]})
yield f"data: {data}\n\n"
elif "type" in line:
if line.get("type") == "tool_calls_pending":
# Save continuation state and end the stream
paused = True
data = json.dumps(line)
yield f"data: {data}\n\n"
elif line.get("type") == "error":
sanitized_error = {
"type": "error",
"error": sanitize_api_error(line.get("error", "An error occurred"))
}
data = json.dumps(sanitized_error)
yield f"data: {data}\n\n"
else:
data = json.dumps(line)
yield f"data: {data}\n\n"
if is_structured and structured_chunks:
structured_data = {
"type": "structured_answer",
"answer": response_full,
"structured": True,
"schema": schema_info,
}
data = json.dumps(structured_data)
yield f"data: {data}\n\n"
# ---- Paused: save continuation state and end stream early ----
if paused:
continuation = getattr(agent, "_pending_continuation", None)
if continuation:
# Ensure we have a conversation_id — create a partial
# conversation if this is the first turn.
if not conversation_id and should_save_conversation:
try:
# Use model-owner scope so shared-agent
# owner-BYOM resolves to its registered plugin.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
conversation_id = (
self.conversation_service.save_conversation(
None,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
)
)
except Exception as e:
logger.error(
f"Failed to create conversation for continuation: {e}",
exc_info=True,
)
if conversation_id:
try:
cont_service = ContinuationService()
cont_service.save_state(
conversation_id=str(conversation_id),
user=decoded_token.get("sub", "local"),
messages=continuation["messages"],
pending_tool_calls=continuation["pending_tool_calls"],
tools_dict=continuation["tools_dict"],
tool_schemas=getattr(agent, "tools", []),
agent_config={
"model_id": model_id or self.default_model_id,
# Persist BYOM scope so resume doesn't
# fall back to caller's layer.
"model_user_id": model_user_id,
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
"api_key": getattr(agent, "api_key", None),
"user_api_key": user_api_key,
"agent_id": agent_id,
"agent_type": agent.__class__.__name__,
"prompt": getattr(agent, "prompt", ""),
"json_schema": getattr(agent, "json_schema", None),
"retriever_config": getattr(agent, "retriever_config", None),
},
client_tools=getattr(
agent.tool_executor, "client_tools", None
),
)
except Exception as e:
logger.error(
f"Failed to save continuation state: {str(e)}",
exc_info=True,
)
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
return
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Run under model-owner scope so title-gen LLM inside
# save_conversation uses the owner's BYOM provider/key.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (decoded_token.get("sub") if decoded_token else None),
)
if model_id
else settings.LLM_PROVIDER
)
system_api_key = get_api_key_for_provider(provider or settings.LLM_PROVIDER)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=system_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
if should_save_conversation:
conversation_id = self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
# Persist compression metadata/summary if it exists and wasn't saved mid-execution
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata: {str(e)}",
exc_info=True,
)
else:
conversation_id = None
id_data = {"type": "id", "id": str(conversation_id)}
data = json.dumps(id_data)
yield f"data: {data}\n\n"
tool_calls_for_logging = self._prepare_tool_calls_for_logging(
getattr(agent, "tool_calls", tool_calls) or tool_calls
)
log_data = {
"action": "stream_answer",
"level": "info",
"user": decoded_token.get("sub"),
"api_key": user_api_key,
"agent_id": agent_id,
"question": question,
"response": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls_for_logging,
"attachments": attachment_ids,
"timestamp": datetime.datetime.now(datetime.timezone.utc),
}
if is_structured:
log_data["structured_output"] = True
if schema_info:
log_data["schema"] = schema_info
# Clean up text fields to be no longer than 10000 characters
for key, value in log_data.items():
if isinstance(value, str) and len(value) > 10000:
log_data[key] = value[:10000]
try:
with db_session() as conn:
UserLogsRepository(conn).insert(
user_id=log_data.get("user"),
endpoint="stream_answer",
data=log_data,
)
except Exception as log_err:
logger.error(
f"Failed to persist stream_answer user log: {log_err}",
exc_info=True,
)
data = json.dumps({"type": "end"})
yield f"data: {data}\n\n"
except GeneratorExit:
logger.info(f"Stream aborted by client for question: {question[:50]}... ")
# Save partial response
if should_save_conversation and response_full:
try:
if isNoneDoc:
for doc in source_log_docs:
doc["source"] = "None"
# Mirror the normal-path provider resolution so the
# partial-save title LLM uses the model-owner's BYOM
# registration (shared-agent dispatch) rather than
# the deployment default with the instance api key.
provider = (
get_provider_from_model_id(
model_id,
user_id=model_user_id
or (
decoded_token.get("sub")
if decoded_token
else None
),
)
if model_id
else settings.LLM_PROVIDER
)
sys_api_key = get_api_key_for_provider(
provider or settings.LLM_PROVIDER
)
llm = LLMCreator.create_llm(
provider or settings.LLM_PROVIDER,
api_key=sys_api_key,
user_api_key=user_api_key,
decoded_token=decoded_token,
model_id=model_id,
agent_id=agent_id,
model_user_id=model_user_id,
)
self.conversation_service.save_conversation(
conversation_id,
question,
response_full,
thought,
source_log_docs,
tool_calls,
llm,
model_id or self.default_model_id,
decoded_token,
index=index,
api_key=user_api_key,
agent_id=agent_id,
is_shared_usage=is_shared_usage,
shared_token=shared_token,
attachment_ids=attachment_ids,
metadata=query_metadata if query_metadata else None,
)
compression_meta = getattr(agent, "compression_metadata", None)
compression_saved = getattr(agent, "compression_saved", False)
if conversation_id and compression_meta and not compression_saved:
try:
self.conversation_service.update_compression_metadata(
conversation_id, compression_meta
)
self.conversation_service.append_compression_message(
conversation_id, compression_meta
)
agent.compression_saved = True
logger.info(
f"Persisted compression metadata for conversation {conversation_id} (partial stream)"
)
except Exception as e:
logger.error(
f"Failed to persist compression metadata (partial stream): {str(e)}",
exc_info=True,
)
except Exception as e:
logger.error(
f"Error saving partial response: {str(e)}", exc_info=True
)
raise
except Exception as e:
logger.error(f"Error in stream: {str(e)}", exc_info=True)
data = json.dumps(
{
"type": "error",
"error": "Please try again later. We apologize for any inconvenience.",
}
)
yield f"data: {data}\n\n"
return
def process_response_stream(self, stream) -> Dict[str, Any]:
"""Process the stream response for non-streaming endpoint.
Returns:
Dict with keys: conversation_id, answer, sources, tool_calls,
thought, error, and optional extra.
"""
conversation_id = ""
response_full = ""
source_log_docs = []
tool_calls = []
thought = ""
stream_ended = False
is_structured = False
schema_info = None
pending_tool_calls = None
for line in stream:
try:
event_data = line.replace("data: ", "").strip()
event = json.loads(event_data)
if event["type"] == "id":
conversation_id = event["id"]
elif event["type"] == "answer":
response_full += event["answer"]
elif event["type"] == "structured_answer":
response_full = event["answer"]
is_structured = True
schema_info = event.get("schema")
elif event["type"] == "source":
source_log_docs = event["source"]
elif event["type"] == "tool_calls":
tool_calls = event["tool_calls"]
elif event["type"] == "tool_calls_pending":
pending_tool_calls = event.get("data", {}).get(
"pending_tool_calls", []
)
elif event["type"] == "thought":
thought = event["thought"]
elif event["type"] == "error":
logger.error(f"Error from stream: {event['error']}")
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": event["error"],
}
elif event["type"] == "end":
stream_ended = True
except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"Error parsing stream event: {e}, line: {line}")
continue
if not stream_ended:
logger.error("Stream ended unexpectedly without an 'end' event.")
return {
"conversation_id": None,
"answer": None,
"sources": None,
"tool_calls": None,
"thought": None,
"error": "Stream ended unexpectedly",
}
result: Dict[str, Any] = {
"conversation_id": conversation_id,
"answer": response_full,
"sources": source_log_docs,
"tool_calls": tool_calls,
"thought": thought,
"error": None,
}
if pending_tool_calls is not None:
result["extra"] = {"pending_tool_calls": pending_tool_calls}
if is_structured:
result["extra"] = {"structured": True, "schema": schema_info}
return result
def error_stream_generate(self, err_response):
data = json.dumps({"type": "error", "error": err_response})
yield f"data: {data}\n\n"

View File

@@ -1,55 +0,0 @@
import logging
from flask import make_response, request
from flask_restx import fields, Resource
from application.api.answer.routes.base import answer_ns
from application.services.search_service import (
InvalidAPIKey,
SearchFailed,
search,
)
logger = logging.getLogger(__name__)
@answer_ns.route("/api/search")
class SearchResource(Resource):
"""Fast search endpoint for retrieving relevant documents."""
search_model = answer_ns.model(
"SearchModel",
{
"question": fields.String(
required=True, description="Search query"
),
"api_key": fields.String(
required=True, description="API key for authentication"
),
"chunks": fields.Integer(
required=False, default=5, description="Number of results to return"
),
},
)
@answer_ns.expect(search_model)
@answer_ns.doc(description="Search for relevant documents based on query")
def post(self):
data = request.get_json() or {}
question = data.get("question")
api_key = data.get("api_key")
chunks = data.get("chunks", 5)
if not question:
return make_response({"error": "question is required"}, 400)
if not api_key:
return make_response({"error": "api_key is required"}, 400)
try:
return make_response(search(api_key, question, chunks), 200)
except InvalidAPIKey:
return make_response({"error": "Invalid API key"}, 401)
except SearchFailed:
logger.exception("/api/search failed")
return make_response({"error": "Search failed"}, 500)

View File

@@ -1,173 +0,0 @@
import logging
import traceback
from flask import request, Response
from flask_restx import fields, Resource
from application.api import api
from application.api.answer.routes.base import answer_ns, BaseAnswerResource
from application.api.answer.services.stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
@answer_ns.route("/stream")
class StreamResource(Resource, BaseAnswerResource):
def __init__(self, *args, **kwargs):
Resource.__init__(self, *args, **kwargs)
BaseAnswerResource.__init__(self)
stream_model = answer_ns.model(
"StreamModel",
{
"question": fields.String(
required=True, description="Question to be asked"
),
"history": fields.List(
fields.String,
required=False,
description="Conversation history (only for new conversations)",
),
"conversation_id": fields.String(
required=False,
description="Existing conversation ID (loads history)",
),
"prompt_id": fields.String(
required=False, default="default", description="Prompt ID"
),
"chunks": fields.Integer(
required=False, default=2, description="Number of chunks"
),
"retriever": fields.String(required=False, description="Retriever type"),
"api_key": fields.String(required=False, description="API key"),
"agent_id": fields.String(required=False, description="Agent ID"),
"active_docs": fields.String(
required=False, description="Active documents"
),
"isNoneDoc": fields.Boolean(
required=False, description="Flag indicating if no document is used"
),
"index": fields.Integer(
required=False, description="Index of the query to update"
),
"save_conversation": fields.Boolean(
required=False,
default=True,
description="Whether to save the conversation",
),
"model_id": fields.String(
required=False,
description="Model ID to use for this request",
),
"attachments": fields.List(
fields.String, required=False, description="List of attachment IDs"
),
"passthrough": fields.Raw(
required=False,
description="Dynamic parameters to inject into prompt template",
),
},
)
@api.expect(stream_model)
@api.doc(description="Stream a response based on the question and retriever")
def post(self):
data = request.get_json()
if error := self.validate_request(data, "index" in data):
return error
decoded_token = getattr(request, "decoded_token", None)
processor = StreamProcessor(data, decoded_token)
try:
# ---- Continuation mode ----
if data.get("tool_actions"):
(
agent,
messages,
tools_dict,
pending_tool_calls,
tool_actions,
) = processor.resume_from_tool_actions(
data["tool_actions"], data["conversation_id"]
)
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question="",
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
agent_id=processor.agent_id,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
_continuation={
"messages": messages,
"tools_dict": tools_dict,
"pending_tool_calls": pending_tool_calls,
"tool_actions": tool_actions,
},
),
mimetype="text/event-stream",
)
# ---- Normal mode ----
agent = processor.build_agent(data["question"])
if not processor.decoded_token:
return Response(
self.error_stream_generate("Unauthorized"),
status=401,
mimetype="text/event-stream",
)
if error := self.check_usage(processor.agent_config):
return error
return Response(
self.complete_stream(
question=data["question"],
agent=agent,
conversation_id=processor.conversation_id,
user_api_key=processor.agent_config.get("user_api_key"),
decoded_token=processor.decoded_token,
isNoneDoc=data.get("isNoneDoc"),
index=data.get("index"),
should_save_conversation=data.get("save_conversation", True),
attachment_ids=data.get("attachments", []),
agent_id=processor.agent_id,
is_shared_usage=processor.is_shared_usage,
shared_token=processor.shared_token,
model_id=processor.model_id,
model_user_id=processor.model_user_id,
),
mimetype="text/event-stream",
)
except ValueError as e:
message = "Malformed request body"
logger.error(
f"/stream - error: {message} - specific error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate(message),
status=400,
mimetype="text/event-stream",
)
except Exception as e:
logger.error(
f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}",
extra={"error": str(e), "traceback": traceback.format_exc()},
)
return Response(
self.error_stream_generate("Unknown error occurred"),
status=400,
mimetype="text/event-stream",
)

View File

@@ -1,20 +0,0 @@
"""
Compression module for managing conversation context compression.
"""
from application.api.answer.services.compression.orchestrator import (
CompressionOrchestrator,
)
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.types import (
CompressionResult,
CompressionMetadata,
)
__all__ = [
"CompressionOrchestrator",
"CompressionService",
"CompressionResult",
"CompressionMetadata",
]

View File

@@ -1,249 +0,0 @@
"""Message reconstruction utilities for compression."""
import json
import logging
import uuid
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
class MessageBuilder:
"""Builds message arrays from compressed context."""
@staticmethod
def build_from_compressed_context(
system_prompt: str,
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_tool_calls: bool = False,
context_type: str = "pre_request",
) -> List[Dict]:
"""
Build messages from compressed context.
Args:
system_prompt: Original system prompt
compressed_summary: Compressed summary (if any)
recent_queries: Recent uncompressed queries
include_tool_calls: Whether to include tool calls from history
context_type: Type of context ('pre_request' or 'mid_execution')
Returns:
List of message dicts ready for LLM
"""
# Append compression summary to system prompt if present
if compressed_summary:
system_prompt = MessageBuilder._append_compression_context(
system_prompt, compressed_summary, context_type
)
messages = [{"role": "system", "content": system_prompt}]
# Add recent history
for query in recent_queries:
if "prompt" in query and "response" in query:
messages.append({"role": "user", "content": query["prompt"]})
messages.append({"role": "assistant", "content": query["response"]})
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
return messages
@staticmethod
def _append_compression_context(
system_prompt: str, compressed_summary: str, context_type: str = "pre_request"
) -> str:
"""
Append compression context to system prompt.
Args:
system_prompt: Original system prompt
compressed_summary: Summary to append
context_type: Type of compression context
Returns:
Updated system prompt
"""
# Remove existing compression context if present
if "This session is being continued" in system_prompt or "Context window limit reached" in system_prompt:
parts = system_prompt.split("\n\n---\n\n")
system_prompt = parts[0]
# Build appropriate context message based on type
if context_type == "mid_execution":
context_message = (
"\n\n---\n\n"
"Context window limit reached during execution. "
"Previous conversation has been compressed to fit within limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
else: # pre_request
context_message = (
"\n\n---\n\n"
"This session is being continued from a previous conversation that "
"has been compressed to fit within context limits. "
"The conversation is summarized below:\n\n"
f"{compressed_summary}"
)
return system_prompt + context_message
@staticmethod
def rebuild_messages_after_compression(
messages: List[Dict],
compressed_summary: Optional[str],
recent_queries: List[Dict],
include_current_execution: bool = False,
include_tool_calls: bool = False,
) -> Optional[List[Dict]]:
"""
Rebuild the message list after compression so tool execution can continue.
Args:
messages: Original message list
compressed_summary: Compressed summary
recent_queries: Recent uncompressed queries
include_current_execution: Whether to preserve current execution messages
include_tool_calls: Whether to include tool calls from history
Returns:
Rebuilt message list or None if failed
"""
# Find the system message
system_message = next(
(msg for msg in messages if msg.get("role") == "system"), None
)
if not system_message:
logger.warning("No system message found in messages list")
return None
# Update system message with compressed summary
if compressed_summary:
content = system_message.get("content", "")
system_message["content"] = MessageBuilder._append_compression_context(
content, compressed_summary, "mid_execution"
)
logger.info(
"Appended compression summary to system prompt (truncated): %s",
(
compressed_summary[:500] + "..."
if len(compressed_summary) > 500
else compressed_summary
),
)
rebuilt_messages = [system_message]
# Add recent history from compressed context
for query in recent_queries:
if "prompt" in query and "response" in query:
rebuilt_messages.append({"role": "user", "content": query["prompt"]})
rebuilt_messages.append(
{"role": "assistant", "content": query["response"]}
)
# Add tool calls from history if present
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
call_id = tool_call.get("call_id") or str(uuid.uuid4())
args = tool_call.get("arguments")
args_str = (
json.dumps(args)
if isinstance(args, dict)
else (args or "{}")
)
rebuilt_messages.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": tool_call.get("action_name", ""),
"arguments": args_str,
},
}],
})
result = tool_call.get("result")
result_str = (
json.dumps(result)
if not isinstance(result, str)
else (result or "")
)
rebuilt_messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": result_str,
})
# If no recent queries (everything was compressed), add a continuation user message
if len(recent_queries) == 0 and compressed_summary:
rebuilt_messages.append({
"role": "user",
"content": "Please continue with the remaining tasks based on the context above."
})
logger.info("Added continuation user message to maintain proper turn-taking after full compression")
if include_current_execution:
# Preserve any messages that were added during the current execution cycle
recent_msg_count = 1 # system message
for query in recent_queries:
if "prompt" in query and "response" in query:
recent_msg_count += 2
if "tool_calls" in query:
recent_msg_count += len(query["tool_calls"]) * 2
if len(messages) > recent_msg_count:
current_execution_messages = messages[recent_msg_count:]
rebuilt_messages.extend(current_execution_messages)
logger.info(
f"Preserved {len(current_execution_messages)} messages from current execution cycle"
)
logger.info(
f"Messages rebuilt: {len(messages)}{len(rebuilt_messages)} messages. "
f"Ready to continue tool execution."
)
return rebuilt_messages

View File

@@ -1,270 +0,0 @@
"""High-level compression orchestration."""
import logging
from typing import Any, Dict, Optional
from application.api.answer.services.compression.service import CompressionService
from application.api.answer.services.compression.threshold_checker import (
CompressionThresholdChecker,
)
from application.api.answer.services.compression.types import CompressionResult
from application.api.answer.services.conversation_service import ConversationService
from application.core.model_utils import (
get_api_key_for_provider,
get_provider_from_model_id,
)
from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
logger = logging.getLogger(__name__)
class CompressionOrchestrator:
"""
Facade for compression operations.
Coordinates between all compression components and provides
a simple interface for callers.
"""
def __init__(
self,
conversation_service: ConversationService,
threshold_checker: Optional[CompressionThresholdChecker] = None,
):
"""
Initialize orchestrator.
Args:
conversation_service: Service for DB operations
threshold_checker: Custom threshold checker (optional)
"""
self.conversation_service = conversation_service
self.threshold_checker = threshold_checker or CompressionThresholdChecker()
def compress_if_needed(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_query_tokens: int = 500,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Check if compression is needed and perform it if so.
This is the main entry point for compression operations.
Args:
conversation_id: Conversation ID
user_id: Caller's user id — used for conversation access checks
model_id: Model being used for conversation
decoded_token: User's decoded JWT token
current_query_tokens: Estimated tokens for current query
model_user_id: BYOM-resolution scope (model owner); defaults
to ``user_id`` for built-in / caller-owned models.
Returns:
CompressionResult with summary and recent queries
"""
try:
# Conversation row is owned by the caller, not the model owner.
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Conversation {conversation_id} not found for user {user_id}"
)
return CompressionResult.failure("Conversation not found")
# Use model-owner scope so per-user BYOM context windows
# (e.g. 8k) compute the threshold against the right limit.
registry_user_id = model_user_id or user_id
if not self.threshold_checker.should_compress(
conversation,
model_id,
current_query_tokens,
user_id=registry_user_id,
):
# No compression needed, return full history
queries = conversation.get("queries", [])
return CompressionResult.success_no_compression(queries)
# Perform compression
return self._perform_compression(
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:
logger.error(
f"Error in compress_if_needed: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))
def _perform_compression(
self,
conversation_id: str,
conversation: Dict[str, Any],
model_id: str,
decoded_token: Dict[str, Any],
user_id: Optional[str] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform the actual compression operation.
Args:
conversation_id: Conversation ID
conversation: Conversation document
model_id: Model ID for conversation
decoded_token: User token
user_id: Caller's id (for conversation reload after compression)
model_user_id: BYOM-resolution scope (model owner)
Returns:
CompressionResult
"""
try:
# Determine which model to use for compression
compression_model = (
settings.COMPRESSION_MODEL_OVERRIDE
if settings.COMPRESSION_MODEL_OVERRIDE
else model_id
)
# Use model-owner scope so provider/api_key resolves to the
# owner's BYOM record (shared-agent dispatch).
caller_user_id = user_id
if caller_user_id is None and isinstance(decoded_token, dict):
caller_user_id = decoded_token.get("sub")
registry_user_id = model_user_id or caller_user_id
provider = get_provider_from_model_id(
compression_model, user_id=registry_user_id
)
api_key = get_api_key_for_provider(provider)
compression_llm = LLMCreator.create_llm(
provider,
api_key=api_key,
user_api_key=None,
decoded_token=decoded_token,
model_id=compression_model,
agent_id=conversation.get("agent_id"),
model_user_id=registry_user_id,
)
# Create compression service with DB update capability
compression_service = CompressionService(
llm=compression_llm,
model_id=compression_model,
conversation_service=self.conversation_service,
)
# Compress all queries up to the latest
queries_count = len(conversation.get("queries", []))
compress_up_to = queries_count - 1
if compress_up_to < 0:
logger.warning("No queries to compress")
return CompressionResult.success_no_compression([])
logger.info(
f"Initiating compression for conversation {conversation_id}: "
f"compressing all {queries_count} queries (0-{compress_up_to})"
)
# Perform compression and save to DB
metadata = compression_service.compress_and_save(
conversation_id, conversation, compress_up_to
)
logger.info(
f"Compression successful - ratio: {metadata.compression_ratio:.1f}x, "
f"saved {metadata.original_token_count - metadata.compressed_token_count} tokens"
)
# Reload under caller (conversation is owned by caller).
reload_user_id = caller_user_id
if reload_user_id is None and isinstance(decoded_token, dict):
reload_user_id = decoded_token.get("sub")
conversation = self.conversation_service.get_conversation(
conversation_id, user_id=reload_user_id
)
# Get compressed context
compressed_summary, recent_queries = (
compression_service.get_compressed_context(conversation)
)
return CompressionResult.success_with_compression(
compressed_summary, recent_queries, metadata
)
except Exception as e:
logger.error(f"Error performing compression: {str(e)}", exc_info=True)
return CompressionResult.failure(str(e))
def compress_mid_execution(
self,
conversation_id: str,
user_id: str,
model_id: str,
decoded_token: Dict[str, Any],
current_conversation: Optional[Dict[str, Any]] = None,
model_user_id: Optional[str] = None,
) -> CompressionResult:
"""
Perform compression during tool execution.
Args:
conversation_id: Conversation ID
user_id: Caller's user id — used for conversation access checks
model_id: Model ID
decoded_token: User token
current_conversation: Pre-loaded conversation (optional)
model_user_id: BYOM-resolution scope (model owner). For
shared-agent dispatch this is the agent owner; defaults
to ``user_id`` so built-in / caller-owned models are
unaffected.
Returns:
CompressionResult
"""
try:
# Load conversation if not provided
if current_conversation:
conversation = current_conversation
else:
conversation = self.conversation_service.get_conversation(
conversation_id, user_id
)
if not conversation:
logger.warning(
f"Could not load conversation {conversation_id} for mid-execution compression"
)
return CompressionResult.failure("Conversation not found")
# Perform compression
return self._perform_compression(
conversation_id,
conversation,
model_id,
decoded_token,
user_id=user_id,
model_user_id=model_user_id,
)
except Exception as e:
logger.error(
f"Error in mid-execution compression: {str(e)}", exc_info=True
)
return CompressionResult.failure(str(e))

View File

@@ -1,149 +0,0 @@
"""Compression prompt building logic."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class CompressionPromptBuilder:
"""Builds prompts for LLM compression calls."""
def __init__(self, version: str = "v1.0"):
"""
Initialize prompt builder.
Args:
version: Prompt template version to use
"""
self.version = version
self.system_prompt = self._load_prompt(version)
def _load_prompt(self, version: str) -> str:
"""
Load prompt template from file.
Args:
version: Version string (e.g., 'v1.0')
Returns:
Prompt template content
Raises:
FileNotFoundError: If prompt template file doesn't exist
"""
current_dir = Path(__file__).resolve().parents[4]
prompt_path = current_dir / "prompts" / "compression" / f"{version}.txt"
try:
with open(prompt_path, "r") as f:
return f.read()
except FileNotFoundError:
logger.error(f"Compression prompt template not found: {prompt_path}")
raise FileNotFoundError(
f"Compression prompt template '{version}' not found at {prompt_path}. "
f"Please ensure the template file exists."
)
def build_prompt(
self,
queries: List[Dict[str, Any]],
existing_compressions: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, str]]:
"""
Build messages for compression LLM call.
Args:
queries: List of query objects to compress
existing_compressions: List of previous compression points
Returns:
List of message dicts for LLM
"""
# Build conversation text
conversation_text = self._format_conversation(queries)
# Add existing compression context if present
existing_compression_context = ""
if existing_compressions and len(existing_compressions) > 0:
existing_compression_context = (
"\n\nIMPORTANT: This conversation has been compressed before. "
"Previous compression summaries:\n\n"
)
for i, comp in enumerate(existing_compressions):
existing_compression_context += (
f"--- Compression {i + 1} (up to message {comp.get('query_index', 'unknown')}) ---\n"
f"{comp.get('compressed_summary', '')}\n\n"
)
existing_compression_context += (
"Your task is to create a NEW summary that incorporates the context from "
"previous compressions AND the new messages below. The final summary should "
"be comprehensive and include all important information from both previous "
"compressions and new messages.\n\n"
)
user_prompt = (
f"{existing_compression_context}"
f"Here is the conversation to summarize:\n\n"
f"{conversation_text}"
)
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
return messages
def _format_conversation(self, queries: List[Dict[str, Any]]) -> str:
"""
Format conversation queries into readable text for compression.
Args:
queries: List of query objects
Returns:
Formatted conversation text
"""
conversation_lines = []
for i, query in enumerate(queries):
conversation_lines.append(f"--- Message {i + 1} ---")
conversation_lines.append(f"User: {query.get('prompt', '')}")
# Add tool calls if present
tool_calls = query.get("tool_calls", [])
if tool_calls:
conversation_lines.append("\nTool Calls:")
for tc in tool_calls:
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
arguments = tc.get("arguments", {})
result = tc.get("result", "")
if result is None:
result = ""
status = tc.get("status", "unknown")
# Include full tool result for complete compression context
conversation_lines.append(
f" - {tool_name}.{action_name}({arguments}) "
f"[{status}] → {result}"
)
# Add agent thought if present
thought = query.get("thought", "")
if thought:
conversation_lines.append(f"\nAgent Thought: {thought}")
# Add assistant response
conversation_lines.append(f"\nAssistant: {query.get('response', '')}")
# Add sources if present
sources = query.get("sources", [])
if sources:
conversation_lines.append(f"\nSources Used: {len(sources)} documents")
conversation_lines.append("") # Empty line between messages
return "\n".join(conversation_lines)

View File

@@ -1,311 +0,0 @@
"""Core compression service with simplified responsibilities."""
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from application.api.answer.services.compression.prompt_builder import (
CompressionPromptBuilder,
)
from application.api.answer.services.compression.token_counter import TokenCounter
from application.api.answer.services.compression.types import (
CompressionMetadata,
)
from application.core.settings import settings
logger = logging.getLogger(__name__)
class CompressionService:
"""
Service for compressing conversation history.
Handles DB updates.
"""
def __init__(
self,
llm,
model_id: str,
conversation_service=None,
prompt_builder: Optional[CompressionPromptBuilder] = None,
):
"""
Initialize compression service.
Args:
llm: LLM instance to use for compression
model_id: Model ID for compression
conversation_service: Service for DB operations (optional, for DB updates)
prompt_builder: Custom prompt builder (optional)
"""
self.llm = llm
self.model_id = model_id
self.conversation_service = conversation_service
self.prompt_builder = prompt_builder or CompressionPromptBuilder(
version=settings.COMPRESSION_PROMPT_VERSION
)
def compress_conversation(
self,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation history up to specified index.
Args:
conversation: Full conversation document
compress_up_to_index: Last query index to include in compression
Returns:
CompressionMetadata with compression details
Raises:
ValueError: If compress_up_to_index is invalid
"""
try:
queries = conversation.get("queries", [])
if compress_up_to_index < 0 or compress_up_to_index >= len(queries):
raise ValueError(
f"Invalid compress_up_to_index: {compress_up_to_index} "
f"(conversation has {len(queries)} queries)"
)
# Get queries to compress
queries_to_compress = queries[: compress_up_to_index + 1]
# Check if there are existing compressions
existing_compressions = conversation.get("compression_metadata", {}).get(
"compression_points", []
)
if existing_compressions:
logger.info(
f"Found {len(existing_compressions)} previous compression(s) - "
f"will incorporate into new summary"
)
# Calculate original token count
original_tokens = TokenCounter.count_query_tokens(queries_to_compress)
# Log tool call stats
self._log_tool_call_stats(queries_to_compress)
# Build compression prompt
messages = self.prompt_builder.build_prompt(
queries_to_compress, existing_compressions
)
# Call LLM to generate compression
logger.info(
f"Starting compression: {len(queries_to_compress)} queries "
f"(messages 0-{compress_up_to_index}, {original_tokens} tokens) "
f"using model {self.model_id}"
)
# See note in conversation_service.py: ``self.model_id`` is
# the registry id (UUID for BYOM); the LLM's own model_id is
# what the provider's API actually expects.
response = self.llm.gen(
model=getattr(self.llm, "model_id", None) or self.model_id,
messages=messages,
max_tokens=4000,
)
# Extract summary from response
compressed_summary = self._extract_summary(response)
# Calculate compressed token count
compressed_tokens = TokenCounter.count_message_tokens(
[{"content": compressed_summary}]
)
# Calculate compression ratio
compression_ratio = (
original_tokens / compressed_tokens if compressed_tokens > 0 else 0
)
logger.info(
f"Compression complete: {original_tokens}{compressed_tokens} tokens "
f"({compression_ratio:.1f}x compression)"
)
# Build compression metadata
compression_metadata = CompressionMetadata(
timestamp=datetime.now(timezone.utc),
query_index=compress_up_to_index,
compressed_summary=compressed_summary,
original_token_count=original_tokens,
compressed_token_count=compressed_tokens,
compression_ratio=compression_ratio,
model_used=self.model_id,
compression_prompt_version=self.prompt_builder.version,
)
return compression_metadata
except Exception as e:
logger.error(f"Error compressing conversation: {str(e)}", exc_info=True)
raise
def compress_and_save(
self,
conversation_id: str,
conversation: Dict[str, Any],
compress_up_to_index: int,
) -> CompressionMetadata:
"""
Compress conversation and save to database.
Args:
conversation_id: Conversation ID
conversation: Full conversation document
compress_up_to_index: Last query index to include
Returns:
CompressionMetadata
Raises:
ValueError: If conversation_service not provided or invalid index
"""
if not self.conversation_service:
raise ValueError(
"conversation_service required for compress_and_save operation"
)
# Perform compression
metadata = self.compress_conversation(conversation, compress_up_to_index)
# Save to database
self.conversation_service.update_compression_metadata(
conversation_id, metadata.to_dict()
)
logger.info(f"Compression metadata saved to database for {conversation_id}")
return metadata
def get_compressed_context(
self, conversation: Dict[str, Any]
) -> tuple[Optional[str], List[Dict[str, Any]]]:
"""
Get compressed summary + recent uncompressed messages.
Args:
conversation: Full conversation document
Returns:
(compressed_summary, recent_messages)
"""
try:
compression_metadata = conversation.get("compression_metadata", {})
if not compression_metadata.get("is_compressed"):
logger.debug("No compression metadata found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
compression_points = compression_metadata.get("compression_points", [])
if not compression_points:
logger.debug("No compression points found - using full history")
queries = conversation.get("queries", [])
if queries is None:
logger.error("Conversation queries is None - returning empty list")
return None, []
return None, queries
# Get the most recent compression point
latest_compression = compression_points[-1]
compressed_summary = latest_compression.get("compressed_summary")
last_compressed_index = latest_compression.get("query_index")
compressed_tokens = latest_compression.get("compressed_token_count", 0)
original_tokens = latest_compression.get("original_token_count", 0)
# Get only messages after compression point
queries = conversation.get("queries", [])
total_queries = len(queries)
recent_queries = queries[last_compressed_index + 1 :]
logger.info(
f"Using compressed context: summary ({compressed_tokens} tokens, "
f"compressed from {original_tokens}) + {len(recent_queries)} recent messages "
f"(messages {last_compressed_index + 1}-{total_queries - 1})"
)
return compressed_summary, recent_queries
except Exception as e:
logger.error(
f"Error getting compressed context: {str(e)}", exc_info=True
)
queries = conversation.get("queries", [])
if queries is None:
return None, []
return None, queries
def _extract_summary(self, llm_response: str) -> str:
"""
Extract clean summary from LLM response.
Args:
llm_response: Raw LLM response
Returns:
Cleaned summary text
"""
try:
# Try to extract content within <summary> tags
summary_match = re.search(
r"<summary>(.*?)</summary>", llm_response, re.DOTALL
)
if summary_match:
summary = summary_match.group(1).strip()
else:
# If no summary tags, remove analysis tags and use the rest
summary = re.sub(
r"<analysis>.*?</analysis>", "", llm_response, flags=re.DOTALL
).strip()
return summary
except Exception as e:
logger.warning(f"Error extracting summary: {str(e)}, using full response")
return llm_response
def _log_tool_call_stats(self, queries: List[Dict[str, Any]]) -> None:
"""Log statistics about tool calls in queries."""
total_tool_calls = 0
total_tool_result_chars = 0
tool_call_breakdown = {}
for q in queries:
for tc in q.get("tool_calls", []):
total_tool_calls += 1
tool_name = tc.get("tool_name", "unknown")
action_name = tc.get("action_name", "unknown")
key = f"{tool_name}.{action_name}"
tool_call_breakdown[key] = tool_call_breakdown.get(key, 0) + 1
# Track total tool result size
result = tc.get("result", "")
if result:
total_tool_result_chars += len(str(result))
if total_tool_calls > 0:
tool_breakdown_str = ", ".join(
f"{tool}({count})"
for tool, count in sorted(tool_call_breakdown.items())
)
tool_result_kb = total_tool_result_chars / 1024
logger.info(
f"Tool call breakdown: {tool_breakdown_str} "
f"(total result size: {tool_result_kb:.1f} KB, {total_tool_result_chars:,} chars)"
)

View File

@@ -1,110 +0,0 @@
"""Compression threshold checking logic."""
import logging
from typing import Any, Dict
from application.core.model_utils import get_token_limit
from application.core.settings import settings
from application.api.answer.services.compression.token_counter import TokenCounter
logger = logging.getLogger(__name__)
class CompressionThresholdChecker:
"""Determines if compression is needed based on token thresholds."""
def __init__(self, threshold_percentage: float = None):
"""
Initialize threshold checker.
Args:
threshold_percentage: Percentage of context to use as threshold
(defaults to settings.COMPRESSION_THRESHOLD_PERCENTAGE)
"""
self.threshold_percentage = (
threshold_percentage or settings.COMPRESSION_THRESHOLD_PERCENTAGE
)
def should_compress(
self,
conversation: Dict[str, Any],
model_id: str,
current_query_tokens: int = 500,
user_id: str | None = None,
) -> bool:
"""
Determine if compression is needed.
Args:
conversation: Full conversation document
model_id: Target model for this request
current_query_tokens: Estimated tokens for current query
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if tokens >= threshold% of context window
"""
try:
# Calculate total tokens in conversation
total_tokens = TokenCounter.count_conversation_tokens(conversation)
total_tokens += current_query_tokens
# Get context window limit for model
context_limit = get_token_limit(model_id, user_id=user_id)
# Calculate threshold
threshold = int(context_limit * self.threshold_percentage)
compression_needed = total_tokens >= threshold
percentage_used = (total_tokens / context_limit) * 100
if compression_needed:
logger.warning(
f"COMPRESSION TRIGGERED: {total_tokens} tokens / {context_limit} limit "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%)"
)
else:
logger.info(
f"Compression check: {total_tokens}/{context_limit} tokens "
f"({percentage_used:.1f}% used, threshold: {self.threshold_percentage * 100:.0f}%) - No compression needed"
)
return compression_needed
except Exception as e:
logger.error(f"Error checking compression need: {str(e)}", exc_info=True)
return False
def check_message_tokens(
self, messages: list, model_id: str, user_id: str | None = None
) -> bool:
"""
Check if message list exceeds threshold.
Args:
messages: List of message dicts
model_id: Target model
user_id: Owner — needed so per-user BYOM custom-model UUIDs
resolve when looking up the context window.
Returns:
True if at or above threshold
"""
try:
current_tokens = TokenCounter.count_message_tokens(messages)
context_limit = get_token_limit(model_id, user_id=user_id)
threshold = int(context_limit * self.threshold_percentage)
if current_tokens >= threshold:
logger.warning(
f"Message context limit approaching: {current_tokens}/{context_limit} tokens "
f"({(current_tokens/context_limit)*100:.1f}%)"
)
return True
return False
except Exception as e:
logger.error(f"Error checking message tokens: {str(e)}", exc_info=True)
return False

View File

@@ -1,133 +0,0 @@
"""Token counting utilities for compression."""
import logging
from typing import Any, Dict, List
from application.utils import num_tokens_from_string
from application.core.settings import settings
logger = logging.getLogger(__name__)
class TokenCounter:
"""Centralized token counting for conversations and messages."""
# Per-image token estimate. Provider tokenizers vary widely
# (Gemini ~258, GPT-4o 85-1500, Claude ~1500) and the actual cost
# depends on resolution/detail we can't see here. Errs slightly high
# so the threshold check stays conservative.
_IMAGE_PART_TOKEN_ESTIMATE = 1500
@staticmethod
def count_message_tokens(messages: List[Dict]) -> int:
"""
Calculate total tokens in a list of messages.
Args:
messages: List of message dicts with 'content' field
Returns:
Total token count
"""
total_tokens = 0
for message in messages:
content = message.get("content", "")
if isinstance(content, str):
total_tokens += num_tokens_from_string(content)
elif isinstance(content, list):
# Handle structured content (tool calls, image parts, etc.)
for item in content:
if isinstance(item, dict):
total_tokens += TokenCounter._count_content_part(item)
return total_tokens
@staticmethod
def _count_content_part(item: Dict) -> int:
# Image/file attachments are billed by the provider per image,
# not proportional to the inline bytes/base64 string.
# ``str(item)`` on a 1MB image inflates the count by ~10000x,
# which trips spurious compression and overflows downstream
# input limits.
item_type = item.get("type")
if "files" in item:
files = item.get("files")
count = len(files) if isinstance(files, list) and files else 1
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE * count
if "image_url" in item or item_type in {
"image",
"image_url",
"input_image",
"file",
}:
return TokenCounter._IMAGE_PART_TOKEN_ESTIMATE
return num_tokens_from_string(str(item))
@staticmethod
def count_query_tokens(
queries: List[Dict[str, Any]], include_tool_calls: bool = True
) -> int:
"""
Count tokens across multiple query objects.
Args:
queries: List of query objects from conversation
include_tool_calls: Whether to count tool call tokens
Returns:
Total token count
"""
total_tokens = 0
for query in queries:
# Count prompt and response tokens
if "prompt" in query:
total_tokens += num_tokens_from_string(query["prompt"])
if "response" in query:
total_tokens += num_tokens_from_string(query["response"])
if "thought" in query:
total_tokens += num_tokens_from_string(query.get("thought", ""))
# Count tool call tokens
if include_tool_calls and "tool_calls" in query:
for tool_call in query["tool_calls"]:
tool_call_string = (
f"Tool: {tool_call.get('tool_name')} | "
f"Action: {tool_call.get('action_name')} | "
f"Args: {tool_call.get('arguments')} | "
f"Response: {tool_call.get('result')}"
)
total_tokens += num_tokens_from_string(tool_call_string)
return total_tokens
@staticmethod
def count_conversation_tokens(
conversation: Dict[str, Any], include_system_prompt: bool = False
) -> int:
"""
Calculate total tokens in a conversation.
Args:
conversation: Conversation document
include_system_prompt: Whether to include system prompt in count
Returns:
Total token count
"""
try:
queries = conversation.get("queries", [])
total_tokens = TokenCounter.count_query_tokens(queries)
# Add system prompt tokens if requested
if include_system_prompt:
# Rough estimate for system prompt
total_tokens += settings.RESERVED_TOKENS.get("system_prompt", 500)
return total_tokens
except Exception as e:
logger.error(f"Error calculating conversation tokens: {str(e)}")
return 0

View File

@@ -1,83 +0,0 @@
"""Type definitions for compression module."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class CompressionMetadata:
"""Metadata about a compression operation."""
timestamp: datetime
query_index: int
compressed_summary: str
original_token_count: int
compressed_token_count: int
compression_ratio: float
model_used: str
compression_prompt_version: str
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for DB storage."""
return {
"timestamp": self.timestamp,
"query_index": self.query_index,
"compressed_summary": self.compressed_summary,
"original_token_count": self.original_token_count,
"compressed_token_count": self.compressed_token_count,
"compression_ratio": self.compression_ratio,
"model_used": self.model_used,
"compression_prompt_version": self.compression_prompt_version,
}
@dataclass
class CompressionResult:
"""Result of a compression operation."""
success: bool
compressed_summary: Optional[str] = None
recent_queries: List[Dict[str, Any]] = field(default_factory=list)
metadata: Optional[CompressionMetadata] = None
error: Optional[str] = None
compression_performed: bool = False
@classmethod
def success_with_compression(
cls, summary: str, queries: List[Dict], metadata: CompressionMetadata
) -> "CompressionResult":
"""Create a successful result with compression."""
return cls(
success=True,
compressed_summary=summary,
recent_queries=queries,
metadata=metadata,
compression_performed=True,
)
@classmethod
def success_no_compression(cls, queries: List[Dict]) -> "CompressionResult":
"""Create a successful result without compression needed."""
return cls(
success=True,
recent_queries=queries,
compression_performed=False,
)
@classmethod
def failure(cls, error: str) -> "CompressionResult":
"""Create a failure result."""
return cls(success=False, error=error, compression_performed=False)
def as_history(self) -> List[Dict[str, str]]:
"""
Convert recent queries to history format.
Returns:
List of prompt/response dicts
"""
return [
{"prompt": q["prompt"], "response": q["response"]}
for q in self.recent_queries
]

View File

@@ -1,157 +0,0 @@
"""Service for saving and restoring tool-call continuation state.
When a stream pauses (tool needs approval or client-side execution),
the full execution state is persisted to Postgres so the client can
resume later by sending tool_actions.
"""
import logging
from typing import Any, Dict, List, Optional
from uuid import UUID
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.repositories.pending_tool_state import (
PendingToolStateRepository,
)
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
# TTL for pending states — auto-cleaned after this period
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
def _make_serializable(obj: Any) -> Any:
"""Recursively coerce non-JSON values into JSON-safe forms.
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
our writers produce them anymore.
"""
if isinstance(obj, UUID):
return str(obj)
if isinstance(obj, dict):
return {str(k): _make_serializable(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_make_serializable(v) for v in obj]
if isinstance(obj, bytes):
return obj.decode("utf-8", errors="replace")
return obj
class ContinuationService:
"""Manages pending tool-call state in Postgres."""
def __init__(self):
# No-op constructor retained for call-site compatibility. State
# lives in Postgres now; each operation opens its own short-lived
# session rather than holding a connection on the service.
pass
def save_state(
self,
conversation_id: str,
user: str,
messages: List[Dict],
pending_tool_calls: List[Dict],
tools_dict: Dict,
tool_schemas: List[Dict],
agent_config: Dict,
client_tools: Optional[List[Dict]] = None,
) -> str:
"""Save execution state for later continuation.
``conversation_id`` may be a Postgres UUID or the legacy Mongo
``ObjectId`` string — the latter is resolved via
``conversations.legacy_mongo_id`` to find the matching row.
Args:
conversation_id: The conversation this state belongs to.
user: Owner user ID.
messages: Full messages array at the pause point.
pending_tool_calls: Tool calls awaiting client action.
tools_dict: Serializable tools configuration dict.
tool_schemas: LLM-formatted tool schemas (agent.tools).
agent_config: Config needed to recreate the agent on resume.
client_tools: Client-provided tool schemas for client-side execution.
Returns:
The string ID (conversation_id as provided) of the saved state.
"""
with db_session() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
# would raise and poison the save. Surface the mismatch so
# the caller can decide (the stream loop in routes/base.py
# already wraps this in try/except).
raise ValueError(
f"Cannot save continuation state: conversation_id "
f"{conversation_id!r} is neither a PG UUID nor a "
f"backfilled legacy Mongo id."
)
PendingToolStateRepository(conn).save_state(
pg_conv_id,
user,
messages=_make_serializable(messages),
pending_tool_calls=_make_serializable(pending_tool_calls),
tools_dict=_make_serializable(tools_dict),
tool_schemas=_make_serializable(tool_schemas),
agent_config=_make_serializable(agent_config),
client_tools=_make_serializable(client_tools) if client_tools else None,
)
logger.info(
f"Saved continuation state for conversation {conversation_id} "
f"with {len(pending_tool_calls)} pending tool call(s)"
)
return conversation_id
def load_state(
self, conversation_id: str, user: str
) -> Optional[Dict[str, Any]]:
"""Load pending continuation state.
Returns:
The state dict, or None if no pending state exists.
"""
with db_readonly() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId → no state can exist for it.
return None
doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
if not doc:
return None
return doc
def delete_state(self, conversation_id: str, user: str) -> bool:
"""Delete pending state after successful resumption.
Returns:
True if a row was deleted.
"""
with db_session() as conn:
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
if conv is not None:
pg_conv_id = conv["id"]
elif looks_like_uuid(conversation_id):
pg_conv_id = conversation_id
else:
# Unresolvable legacy ObjectId → nothing to delete.
return False
deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
if deleted:
logger.info(
f"Deleted continuation state for conversation {conversation_id}"
)
return deleted

View File

@@ -1,286 +0,0 @@
"""Conversation persistence service backed by Postgres.
Handles create / append / update / compression for conversations during
the answer-streaming path. Connections are opened per-operation rather
than held for the duration of a stream.
"""
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy import text as sql_text
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
logger = logging.getLogger(__name__)
class ConversationService:
def get_conversation(
self, conversation_id: str, user_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a conversation with owner-or-shared access control.
Returns a dict in the legacy Mongo shape — ``queries`` is a list
of message dicts (prompt/response/...) — for compatibility with
the streaming pipeline that consumes this shape.
"""
if not conversation_id or not user_id:
return None
try:
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is None:
logger.warning(
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
)
return None
messages = repo.get_messages(str(conv["id"]))
conv["queries"] = messages
conv["_id"] = str(conv["id"])
return conv
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
return None
def save_conversation(
self,
conversation_id: Optional[str],
question: str,
response: str,
thought: str,
sources: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
llm: Any,
model_id: str,
decoded_token: Dict[str, Any],
index: Optional[int] = None,
api_key: Optional[str] = None,
agent_id: Optional[str] = None,
is_shared_usage: bool = False,
shared_token: Optional[str] = None,
attachment_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Save or update a conversation in Postgres.
Returns the string conversation id (PG UUID as string, or the
caller-provided id if it was already a UUID).
"""
if decoded_token is None:
raise ValueError("Invalid or missing authentication token")
user_id = decoded_token.get("sub")
if not user_id:
raise ValueError("User ID not found in token")
current_time = datetime.now(timezone.utc)
# Trim huge inline source text to a reasonable max before persist.
for source in sources:
if "text" in source and isinstance(source["text"], str):
source["text"] = source["text"][:1000]
message_payload = {
"prompt": question,
"response": response,
"thought": thought,
"sources": sources,
"tool_calls": tool_calls,
"attachments": attachment_ids,
"model_id": model_id,
"timestamp": current_time,
}
if metadata:
message_payload["metadata"] = metadata
if conversation_id is not None and index is not None:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
repo.update_message_at(conv_pg_id, index, message_payload)
repo.truncate_after(conv_pg_id, index)
return conversation_id
elif conversation_id:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is None:
raise ValueError("Conversation not found or unauthorized")
conv_pg_id = str(conv["id"])
# append_message expects 'metadata' key either way; normalise.
append_payload = dict(message_payload)
append_payload.setdefault("metadata", metadata or {})
repo.append_message(conv_pg_id, append_payload)
return conversation_id
else:
messages_summary = [
{
"role": "system",
"content": "You are a helpful assistant that creates concise conversation titles. "
"Summarize conversations in 3 words or less using the same language as the user.",
},
{
"role": "user",
"content": "Summarise following conversation in no more than 3 words, "
"respond ONLY with the summary, use the same language as the "
"user query \n\nUser: " + question + "\n\n" + "AI: " + response,
},
]
# ``model_id`` here is the registry id (a UUID for BYOM
# records). The LLM's own ``model_id`` is the upstream name
# LLMCreator resolved at construction time — that's what
# the provider's API expects. Built-ins are unaffected.
completion = llm.gen(
model=getattr(llm, "model_id", None) or model_id,
messages=messages_summary,
max_tokens=500,
)
if not completion or not completion.strip():
completion = question[:50] if question else "New Conversation"
resolved_api_key: Optional[str] = None
resolved_agent_id: Optional[str] = None
if api_key:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if agent:
resolved_api_key = agent.get("key")
if agent_id:
resolved_agent_id = agent_id
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.create(
user_id,
completion,
agent_id=resolved_agent_id,
api_key=resolved_api_key,
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
shared_token=(
shared_token
if (resolved_agent_id and is_shared_usage)
else None
),
)
conv_pg_id = str(conv["id"])
append_payload = dict(message_payload)
append_payload.setdefault("metadata", metadata or {})
repo.append_message(conv_pg_id, append_payload)
return conv_pg_id
def update_compression_metadata(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""Persist compression flags and append a compression point.
Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
``compression_metadata`` but goes through the PG repo API.
"""
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
# conversation_id here comes from the streaming pipeline
# which has already resolved it; accept either UUID or
# legacy id for safety.
conv = repo.get_by_legacy_id(conversation_id)
conv_pg_id = (
str(conv["id"]) if conv is not None else conversation_id
)
repo.set_compression_flags(
conv_pg_id,
is_compressed=True,
last_compression_at=compression_metadata.get("timestamp"),
)
repo.append_compression_point(
conv_pg_id,
compression_metadata,
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
)
logger.info(
f"Updated compression metadata for conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error updating compression metadata: {str(e)}", exc_info=True
)
raise
def append_compression_message(
self, conversation_id: str, compression_metadata: Dict[str, Any]
) -> None:
"""Append a synthetic compression summary message to the conversation."""
try:
summary = compression_metadata.get("compressed_summary", "")
if not summary:
return
timestamp = compression_metadata.get(
"timestamp", datetime.now(timezone.utc)
)
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_by_legacy_id(conversation_id)
conv_pg_id = (
str(conv["id"]) if conv is not None else conversation_id
)
repo.append_message(conv_pg_id, {
"prompt": "[Context Compression Summary]",
"response": summary,
"thought": "",
"sources": [],
"tool_calls": [],
"attachments": [],
"model_id": compression_metadata.get("model_used"),
"timestamp": timestamp,
})
logger.info(
f"Appended compression summary to conversation {conversation_id}"
)
except Exception as e:
logger.error(
f"Error appending compression summary: {str(e)}", exc_info=True
)
def get_compression_metadata(
self, conversation_id: str
) -> Optional[Dict[str, Any]]:
"""Fetch the stored compression metadata JSONB blob for a conversation."""
try:
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_by_legacy_id(conversation_id)
if conv is None:
# Fallback to UUID lookup without user scoping — the
# caller already holds an authenticated conversation
# id from the streaming path. Gate on id shape so a
# non-UUID (legacy ObjectId that wasn't backfilled)
# doesn't reach CAST — the cast raises and spams the
# logs with a stack trace on every call.
if not looks_like_uuid(conversation_id):
return None
result = conn.execute(
sql_text(
"SELECT compression_metadata FROM conversations "
"WHERE id = CAST(:id AS uuid)"
),
{"id": conversation_id},
)
row = result.fetchone()
return row[0] if row is not None else None
return conv.get("compression_metadata") if conv else None
except Exception as e:
logger.error(
f"Error getting compression metadata: {str(e)}", exc_info=True
)
return None

View File

@@ -1,97 +0,0 @@
import logging
from typing import Any, Dict, Optional
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError
logger = logging.getLogger(__name__)
class PromptRenderer:
"""Service for rendering prompts with dynamic context using namespaces"""
def __init__(self):
self.template_engine = TemplateEngine()
self.namespace_manager = NamespaceManager()
def render_prompt(
self,
prompt_content: str,
user_id: Optional[str] = None,
request_id: Optional[str] = None,
passthrough_data: Optional[Dict[str, Any]] = None,
docs: Optional[list] = None,
docs_together: Optional[str] = None,
tools_data: Optional[Dict[str, Any]] = None,
**kwargs,
) -> str:
"""
Render prompt with full context from all namespaces.
Args:
prompt_content: Raw prompt template string
user_id: Current user identifier
request_id: Unique request identifier
passthrough_data: Parameters from web request
docs: RAG retrieved documents
docs_together: Concatenated document content
tools_data: Pre-fetched tool results organized by tool name
**kwargs: Additional parameters for namespace builders
Returns:
Rendered prompt string with all variables substituted
Raises:
TemplateRenderError: If template rendering fails
"""
if not prompt_content:
return ""
uses_template = self._uses_template_syntax(prompt_content)
if not uses_template:
return self._apply_legacy_substitutions(prompt_content, docs_together)
try:
context = self.namespace_manager.build_context(
user_id=user_id,
request_id=request_id,
passthrough_data=passthrough_data,
docs=docs,
docs_together=docs_together,
tools_data=tools_data,
**kwargs,
)
return self.template_engine.render(prompt_content, context)
except TemplateRenderError:
raise
except Exception as e:
error_msg = f"Prompt rendering failed: {str(e)}"
logger.error(error_msg)
raise TemplateRenderError(error_msg) from e
def _uses_template_syntax(self, prompt_content: str) -> bool:
"""Check if prompt uses Jinja2 template syntax"""
return "{{" in prompt_content and "}}" in prompt_content
def _apply_legacy_substitutions(
self, prompt_content: str, docs_together: Optional[str] = None
) -> str:
"""
Apply backward-compatible substitutions for old prompt format.
Handles legacy {summaries} and {query} placeholders during transition period.
"""
if docs_together:
prompt_content = prompt_content.replace("{summaries}", docs_together)
return prompt_content
def validate_template(self, prompt_content: str) -> bool:
"""Validate prompt template syntax"""
return self.template_engine.validate_template(prompt_content)
def extract_variables(self, prompt_content: str) -> set[str]:
"""Extract all variable names from prompt template"""
return self.template_engine.extract_variables(prompt_content)

File diff suppressed because it is too large Load Diff

View File

@@ -1,551 +0,0 @@
import base64
import html
import json
import uuid
from urllib.parse import urlencode
from flask import (
Blueprint,
current_app,
jsonify,
make_response,
request
)
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.tasks import (
ingest_connector_task,
)
from application.parser.connectors.connector_creator import ConnectorCreator
from application.storage.db.repositories.connector_sessions import (
ConnectorSessionsRepository,
)
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_readonly, db_session
connector = Blueprint("connector", __name__)
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
api.add_namespace(connectors_ns)
# Fixed callback status path to prevent open redirect
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
def build_callback_redirect(params: dict) -> str:
"""Build a safe redirect URL to the callback status page.
Uses a fixed path and properly URL-encodes all parameters
to prevent URL injection and open redirect vulnerabilities.
"""
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
@connectors_ns.route("/api/connectors/auth")
class ConnectorAuth(Resource):
@api.doc(description="Get connector OAuth authorization URL", params={"provider": "Connector provider (e.g., google_drive)"})
def get(self):
try:
provider = request.args.get('provider') or request.args.get('source')
if not provider:
return make_response(jsonify({"success": False, "error": "Missing provider"}), 400)
if not ConnectorCreator.is_supported(provider):
return make_response(jsonify({"success": False, "error": f"Unsupported provider: {provider}"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user_id = decoded_token.get('sub')
with db_session() as conn:
session_row = ConnectorSessionsRepository(conn).upsert(
user_id, provider, status="pending",
)
session_pg_id = str(session_row["id"])
state_dict = {
"provider": provider,
"object_id": session_pg_id,
}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
auth = ConnectorCreator.create_auth(provider)
authorization_url = auth.get_authorization_url(state=state)
return make_response(jsonify({
"success": True,
"authorization_url": authorization_url,
"state": state
}), 200)
except Exception as e:
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
@connectors_ns.route("/api/connectors/callback")
class ConnectorsCallback(Resource):
@api.doc(description="Handle OAuth callback for external connectors")
def get(self):
"""Handle OAuth callback for external connectors"""
try:
from application.parser.connectors.connector_creator import ConnectorCreator
from flask import request, redirect
authorization_code = request.args.get('code')
state = request.args.get('state')
error = request.args.get('error')
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
provider = state_dict.get("provider")
state_object_id = state_dict.get("object_id")
# Validate provider
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
return redirect(build_callback_redirect({
"status": "error",
"message": "Invalid provider"
}))
if error:
if error == "access_denied":
return redirect(build_callback_redirect({
"status": "cancelled",
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
"provider": provider
}))
else:
current_app.logger.warning(f"OAuth error in callback: {error}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
if not authorization_code:
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
try:
auth = ConnectorCreator.create_auth(provider)
token_info = auth.exchange_code_for_tokens(authorization_code)
session_token = str(uuid.uuid4())
try:
if provider == "google_drive":
credentials = auth.create_credentials_from_token_info(token_info)
service = auth.build_drive_service(credentials)
user_info = service.about().get(fields="user").execute()
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
else:
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
except Exception as e:
current_app.logger.warning(f"Could not get user info: {e}")
user_email = 'Connected User'
sanitized_token_info = auth.sanitize_token_info(token_info)
# ``object_id`` in the OAuth state is the PG session row
# UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
# issued state). Try UUID update first; fall back to
# legacy id path.
patch = {
"session_token": session_token,
"token_info": sanitized_token_info,
"user_email": user_email,
"status": "authorized",
}
with db_session() as conn:
repo = ConnectorSessionsRepository(conn)
if state_object_id:
value = str(state_object_id)
updated = False
if len(value) == 36 and "-" in value:
updated = repo.update(value, patch)
if not updated:
repo.update_by_legacy_id(value, patch)
# Redirect to success page with session token and user email
return redirect(build_callback_redirect({
"status": "success",
"message": "Authentication successful",
"provider": provider,
"session_token": session_token,
"user_email": user_email
}))
except Exception as e:
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
"provider": provider
}))
except Exception as e:
current_app.logger.error(f"Error handling connector callback: {e}")
return redirect(build_callback_redirect({
"status": "error",
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
}))
@connectors_ns.route("/api/connectors/files")
class ConnectorFiles(Resource):
@api.expect(api.model("ConnectorFilesModel", {
"provider": fields.String(required=True),
"session_token": fields.String(required=True),
"folder_id": fields.String(required=False),
"limit": fields.Integer(required=False),
"page_token": fields.String(required=False),
"search_query": fields.String(required=False),
}))
@api.doc(description="List files from a connector provider (supports pagination and search)")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
limit = data.get('limit', 10)
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
with db_readonly() as conn:
session = ConnectorSessionsRepository(conn).get_by_session_token(
session_token,
)
if not session or session.get("user_id") != user:
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
loader = ConnectorCreator.create_connector(provider, session_token)
generic_keys = {'provider', 'session_token'}
input_config = {
k: v for k, v in data.items() if k not in generic_keys
}
input_config['list_only'] = True
documents = loader.load_data(input_config)
files = []
for doc in documents[:limit]:
metadata = doc.extra_info
modified_time = metadata.get('modified_time')
if modified_time:
date_part = modified_time.split('T')[0]
time_part = modified_time.split('T')[1].split('.')[0].split('Z')[0]
formatted_time = f"{date_part} {time_part}"
else:
formatted_time = None
files.append({
'id': doc.doc_id,
'name': metadata.get('file_name', 'Unknown File'),
'type': metadata.get('mime_type', 'unknown'),
'size': metadata.get('size', None),
'modifiedTime': formatted_time,
'isFolder': metadata.get('is_folder', False)
})
next_token = getattr(loader, 'next_page_token', None)
has_more = bool(next_token)
return make_response(jsonify({
"success": True,
"files": files,
"total": len(files),
"next_page_token": next_token,
"has_more": has_more
}), 200)
except Exception as e:
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
@connectors_ns.route("/api/connectors/validate-session")
class ConnectorValidateSession(Resource):
@api.expect(api.model("ConnectorValidateSessionModel", {"provider": fields.String(required=True), "session_token": fields.String(required=True)}))
@api.doc(description="Validate connector session token and return user info and access token")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider or not session_token:
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
user = decoded_token.get('sub')
with db_readonly() as conn:
session = ConnectorSessionsRepository(conn).get_by_session_token(
session_token,
)
if not session or session.get("user_id") != user or not session.get("token_info"):
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
token_info = session["token_info"]
auth = ConnectorCreator.create_auth(provider)
is_expired = auth.is_token_expired(token_info)
if is_expired and token_info.get('refresh_token'):
try:
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
with db_session() as conn:
repo = ConnectorSessionsRepository(conn)
row = repo.get_by_session_token(session_token)
if row:
repo.update(str(row["id"]), {"token_info": sanitized_token_info})
token_info = sanitized_token_info
is_expired = False
except Exception as refresh_error:
current_app.logger.error(f"Failed to refresh token: {refresh_error}")
if is_expired:
return make_response(jsonify({
"success": False,
"expired": True,
"error": "Session token has expired. Please reconnect."
}), 401)
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {
"success": True,
"expired": False,
"user_email": session.get('user_email', 'Connected User'),
"access_token": token_info.get('access_token'),
**provider_extras,
}
return make_response(jsonify(response_data), 200)
except Exception as e:
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
@connectors_ns.route("/api/connectors/disconnect")
class ConnectorDisconnect(Resource):
@api.expect(api.model("ConnectorDisconnectModel", {"provider": fields.String(required=True), "session_token": fields.String(required=False)}))
@api.doc(description="Disconnect a connector session")
def post(self):
try:
data = request.get_json()
provider = data.get('provider')
session_token = data.get('session_token')
if not provider:
return make_response(jsonify({"success": False, "error": "provider is required"}), 400)
if session_token:
with db_session() as conn:
ConnectorSessionsRepository(conn).delete_by_session_token(
session_token,
)
return make_response(jsonify({"success": True}), 200)
except Exception as e:
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
@connectors_ns.route("/api/connectors/sync")
class ConnectorSync(Resource):
@api.expect(
api.model(
"ConnectorSyncModel",
{
"source_id": fields.String(required=True, description="Source ID to sync"),
"session_token": fields.String(required=True, description="Authentication token")
},
)
)
@api.doc(description="Sync connector source to check for modifications")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
try:
data = request.get_json()
source_id = data.get('source_id')
session_token = data.get('session_token')
if not all([source_id, session_token]):
return make_response(
jsonify({
"success": False,
"error": "source_id and session_token are required"
}),
400
)
user_id = decoded_token.get('sub')
with db_readonly() as conn:
source = SourcesRepository(conn).get_any(source_id, user_id)
if not source:
return make_response(
jsonify({
"success": False,
"error": "Source not found"
}),
404
)
# ``get_any`` already scopes by ``user_id``; an extra guard
# here would be dead code.
remote_data = source.get('remote_data') or {}
if isinstance(remote_data, str):
try:
remote_data = json.loads(remote_data)
except json.JSONDecodeError:
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
remote_data = {}
source_type = remote_data.get('provider')
if not source_type:
return make_response(
jsonify({
"success": False,
"error": "Source provider not found in remote_data"
}),
400
)
# Extract configuration from remote_data
file_ids = remote_data.get('file_ids', [])
folder_ids = remote_data.get('folder_ids', [])
recursive = remote_data.get('recursive', True)
# Start the sync task
task = ingest_connector_task.delay(
job_name=source.get('name'),
user=decoded_token.get('sub'),
source_type=source_type,
session_token=session_token,
file_ids=file_ids,
folder_ids=folder_ids,
recursive=recursive,
retriever=source.get('retriever', 'classic'),
operation_mode="sync",
doc_id=str(source.get('id') or source_id),
sync_frequency=source.get('sync_frequency', 'never')
)
return make_response(
jsonify({
"success": True,
"task_id": task.id
}),
200
)
except Exception as err:
current_app.logger.error(
f"Error syncing connector source: {err}",
exc_info=True
)
return make_response(
jsonify({
"success": False,
"error": "Failed to sync connector source"
}),
400
)
@connectors_ns.route("/api/connectors/callback-status")
class ConnectorCallbackStatus(Resource):
@api.doc(description="Return HTML page with connector authentication status")
def get(self):
"""Return HTML page with connector authentication status"""
try:
# Validate and sanitize status to a known value
status_raw = request.args.get('status', 'error')
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
# Escape all user-controlled values for HTML context
message = html.escape(request.args.get('message', ''))
provider_raw = request.args.get('provider', 'connector')
provider = html.escape(provider_raw.replace('_', ' ').title())
session_token = request.args.get('session_token', '')
user_email = html.escape(request.args.get('user_email', ''))
def safe_js_string(value: str) -> str:
"""Safely encode a string for embedding in inline JavaScript."""
js_encoded = json.dumps(value)
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
js_status = safe_js_string(status)
js_session_token = safe_js_string(session_token)
js_user_email = safe_js_string(user_email)
js_provider_type = safe_js_string(provider_raw)
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>{provider} Authentication</title>
<style>
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
.container {{ max-width: 600px; margin: 0 auto; }}
.success {{ color: #4CAF50; }}
.error {{ color: #F44336; }}
.cancelled {{ color: #FF9800; }}
</style>
<script>
window.onload = function() {{
const status = {js_status};
const sessionToken = {js_session_token};
const userEmail = {js_user_email};
const providerType = {js_provider_type};
if (status === "success" && window.opener) {{
window.opener.postMessage({{
type: providerType + '_auth_success',
session_token: sessionToken,
user_email: userEmail
}}, '*');
setTimeout(() => window.close(), 3000);
}} else if (status === "cancelled" || status === "error") {{
setTimeout(() => window.close(), 3000);
}}
}};
</script>
</head>
<body>
<div class="container">
<h2>{provider} Authentication</h2>
<div class="{status}">
<p>{message}</p>
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
</div>
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
</div>
</body>
</html>
"""
return make_response(html_content, 200, {'Content-Type': 'text/html'})
except Exception as e:
current_app.logger.error(f"Error rendering callback status page: {e}")
return make_response("Authentication error occurred", 500, {'Content-Type': 'text/html'})

View File

@@ -1,18 +1,19 @@
import os
import datetime
import json
from flask import Blueprint, request, send_from_directory, jsonify
from flask import Blueprint, request, send_from_directory
from werkzeug.utils import secure_filename
from bson.objectid import ObjectId
import logging
from application.core.mongo_db import MongoDB
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.sources import SourcesRepository
from application.storage.db.session import db_session
from application.storage.storage_creator import StorageCreator
logger = logging.getLogger(__name__)
mongo = MongoDB.get_client()
db = mongo[settings.MONGO_DB_NAME]
conversations_collection = db["conversations"]
sources_collection = db["sources"]
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -22,24 +23,6 @@ current_dir = os.path.dirname(
internal = Blueprint("internal", __name__)
@internal.before_request
def verify_internal_key():
"""Verify INTERNAL_KEY for all internal endpoint requests.
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
"""
if not settings.INTERNAL_KEY:
logger.warning(
f"Internal API request rejected from {request.remote_addr}: "
"INTERNAL_KEY is not configured"
)
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
internal_key = request.headers.get("X-Internal-Key")
if not internal_key or internal_key != settings.INTERNAL_KEY:
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
@internal.route("/api/download", methods=["get"])
def download_file():
user = secure_filename(request.args.get("user"))
@@ -54,41 +37,22 @@ def upload_index_files():
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
if "user" not in request.form:
return {"status": "no user"}
user = request.form["user"]
user = request.form["user"]
if "name" not in request.form:
return {"status": "no name"}
job_name = request.form["name"]
tokens = request.form["tokens"]
retriever = request.form["retriever"]
source_id = request.form["id"]
id = request.form["id"]
type = request.form["type"]
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
file_path = request.form.get("file_path")
directory_structure = request.form.get("directory_structure")
file_name_map = request.form.get("file_name_map")
if directory_structure:
try:
directory_structure = json.loads(directory_structure)
except Exception:
logger.error("Error parsing directory_structure")
directory_structure = {}
else:
directory_structure = {}
if file_name_map:
try:
file_name_map = json.loads(file_name_map)
except Exception:
logger.error("Error parsing file_name_map")
file_name_map = None
else:
file_name_map = None
original_file_path = request.form.get("original_file_path")
storage = StorageCreator.get_storage()
index_base_path = f"indexes/{source_id}"
index_base_path = f"indexes/{id}"
if settings.VECTOR_STORE == "faiss":
if "file_faiss" not in request.files:
logger.error("No file_faiss part")
@@ -102,55 +66,46 @@ def upload_index_files():
file_pkl = request.files["file_pkl"]
if file_pkl.filename == "":
return {"status": "no file name"}
# Save index files to storage
faiss_storage_path = f"{index_base_path}/index.faiss"
pkl_storage_path = f"{index_base_path}/index.pkl"
storage.save_file(file_faiss, faiss_storage_path)
storage.save_file(file_pkl, pkl_storage_path)
storage.save_file(file_faiss, f"{index_base_path}/index.faiss")
storage.save_file(file_pkl, f"{index_base_path}/index.pkl")
now = datetime.datetime.now(datetime.timezone.utc)
update_fields = {
"name": job_name,
"type": type,
"language": job_name,
"date": now,
"model": settings.EMBEDDINGS_NAME,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": file_path,
"directory_structure": directory_structure,
}
if file_name_map is not None:
update_fields["file_name_map"] = file_name_map
with db_session() as conn:
repo = SourcesRepository(conn)
existing = None
if looks_like_uuid(source_id):
existing = repo.get(source_id, user)
if existing is None:
existing = repo.get_by_legacy_id(source_id, user)
if existing is not None:
repo.update(str(existing["id"]), user, update_fields)
else:
repo.create(
job_name,
source_id=source_id if looks_like_uuid(source_id) else None,
user_id=user,
type=type,
tokens=tokens,
retriever=retriever,
remote_data=remote_data,
sync_frequency=sync_frequency,
file_path=file_path,
directory_structure=directory_structure,
file_name_map=file_name_map,
language=job_name,
model=settings.EMBEDDINGS_NAME,
date=now,
legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
)
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
if existing_entry:
sources_collection.update_one(
{"_id": ObjectId(id)},
{
"$set": {
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": original_file_path,
}
},
)
else:
sources_collection.insert_one(
{
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"date": datetime.datetime.now(),
"model": settings.EMBEDDINGS_NAME,
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data,
"sync_frequency": sync_frequency,
"file_path": original_file_path,
}
)
return {"status": "ok"}

View File

@@ -1,5 +0,0 @@
"""User API module - provides all user-related API endpoints"""
from .routes import user
__all__ = ["user"]

View File

@@ -1,8 +0,0 @@
"""Agents module."""
from .routes import agents_ns
from .sharing import agents_sharing_ns
from .webhooks import agents_webhooks_ns
from .folders import agents_folders_ns
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]

View File

@@ -1,366 +0,0 @@
"""
Agent folders management routes.
Provides virtual folder organization for agents (Google Drive-like structure).
"""
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource, fields
from sqlalchemy import text as _sql_text
from application.api import api
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session
agents_folders_ns = Namespace(
"agents_folders", description="Agent folder management", path="/api/agents/folders"
)
def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
"""Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
if not folder_id:
return None
if looks_like_uuid(folder_id):
row = repo.get(folder_id, user)
if row is not None:
return row
return repo.get_by_legacy_id(folder_id, user)
def _folder_error_response(message: str, err: Exception):
current_app.logger.error(f"{message}: {err}", exc_info=True)
return make_response(jsonify({"success": False, "message": message}), 400)
def _serialize_folder(f: dict) -> dict:
created_at = f.get("created_at")
updated_at = f.get("updated_at")
return {
"id": str(f["id"]),
"name": f.get("name"),
"parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
}
@agents_folders_ns.route("/")
class AgentFolders(Resource):
@api.doc(description="Get all folders for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
folders = AgentFoldersRepository(conn).list_for_user(user)
result = [_serialize_folder(f) for f in folders]
return make_response(jsonify({"folders": result}), 200)
except Exception as err:
return _folder_error_response("Failed to fetch folders", err)
@api.doc(description="Create a new folder")
@api.expect(
api.model(
"CreateFolder",
{
"name": fields.String(required=True, description="Folder name"),
"parent_id": fields.String(required=False, description="Parent folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("name"):
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
parent_id_input = data.get("parent_id")
description = data.get("description")
try:
with db_session() as conn:
repo = AgentFoldersRepository(conn)
pg_parent_id = None
if parent_id_input:
parent = _resolve_folder_id(repo, parent_id_input, user)
if not parent:
return make_response(
jsonify({"success": False, "message": "Parent folder not found"}),
404,
)
pg_parent_id = str(parent["id"])
folder = repo.create(
user, data["name"],
description=description,
parent_id=pg_parent_id,
)
return make_response(
jsonify(
{
"id": str(folder["id"]),
"name": folder["name"],
"parent_id": pg_parent_id,
}
),
201,
)
except Exception as err:
return _folder_error_response("Failed to create folder", err)
@agents_folders_ns.route("/<string:folder_id>")
class AgentFolder(Resource):
@api.doc(description="Get a specific folder with its agents")
def get(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
agents_rows = conn.execute(
_sql_text(
"SELECT id, name, description FROM agents "
"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
"ORDER BY created_at DESC"
),
{"user_id": user, "fid": pg_folder_id},
).fetchall()
agents_list = [
{
"id": str(row._mapping["id"]),
"name": row._mapping["name"],
"description": row._mapping.get("description", "") or "",
}
for row in agents_rows
]
subfolders = folders_repo.list_children(pg_folder_id, user)
subfolders_list = [
{"id": str(sf["id"]), "name": sf["name"]}
for sf in subfolders
]
return make_response(
jsonify(
{
"id": pg_folder_id,
"name": folder["name"],
"parent_id": (
str(folder["parent_id"]) if folder.get("parent_id") else None
),
"agents": agents_list,
"subfolders": subfolders_list,
}
),
200,
)
except Exception as err:
return _folder_error_response("Failed to fetch folder", err)
@api.doc(description="Update a folder")
def put(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
try:
with db_session() as conn:
repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
update_fields: dict = {}
if "name" in data:
update_fields["name"] = data["name"]
if "description" in data:
update_fields["description"] = data["description"]
if "parent_id" in data:
parent_input = data.get("parent_id")
if parent_input:
if parent_input == folder_id or parent_input == pg_folder_id:
return make_response(
jsonify(
{
"success": False,
"message": "Cannot set folder as its own parent",
}
),
400,
)
parent = _resolve_folder_id(repo, parent_input, user)
if not parent:
return make_response(
jsonify({"success": False, "message": "Parent folder not found"}),
404,
)
if str(parent["id"]) == pg_folder_id:
return make_response(
jsonify(
{
"success": False,
"message": "Cannot set folder as its own parent",
}
),
400,
)
update_fields["parent_id"] = str(parent["id"])
else:
update_fields["parent_id"] = None
if update_fields:
repo.update(pg_folder_id, user, update_fields)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to update folder", err)
@api.doc(description="Delete a folder")
def delete(self, folder_id):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_session() as conn:
repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(repo, folder_id, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
# Clear folder assignments from agents; self-FK
# ``ON DELETE SET NULL`` handles child folders.
AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
deleted = repo.delete(pg_folder_id, user)
if not deleted:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to delete folder", err)
@agents_folders_ns.route("/move_agent")
class MoveAgentToFolder(Resource):
@api.doc(description="Move an agent to a folder or remove from folder")
@api.expect(
api.model(
"MoveAgent",
{
"agent_id": fields.String(required=True, description="Agent ID to move"),
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_id"):
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
agent_id_input = data["agent_id"]
folder_id_input = data.get("folder_id")
try:
with db_session() as conn:
agents_repo = AgentsRepository(conn)
agent = agents_repo.get_any(agent_id_input, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}),
404,
)
pg_folder_id = None
if folder_id_input:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agent", err)
@agents_folders_ns.route("/bulk_move")
class BulkMoveAgents(Resource):
@api.doc(description="Move multiple agents to a folder")
@api.expect(
api.model(
"BulkMoveAgents",
{
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
"folder_id": fields.String(required=False, description="Target folder ID"),
},
)
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data or not data.get("agent_ids"):
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
agent_ids = data["agent_ids"]
folder_id_input = data.get("folder_id")
try:
with db_session() as conn:
agents_repo = AgentsRepository(conn)
pg_folder_id = None
if folder_id_input:
folders_repo = AgentFoldersRepository(conn)
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
if not folder:
return make_response(
jsonify({"success": False, "message": "Folder not found"}),
404,
)
pg_folder_id = str(folder["id"])
for agent_id_input in agent_ids:
agent = agents_repo.get_any(agent_id_input, user)
if agent is not None:
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
return make_response(jsonify({"success": True}), 200)
except Exception as err:
return _folder_error_response("Failed to move agents", err)

File diff suppressed because it is too large Load Diff

View File

@@ -1,274 +0,0 @@
"""Agent management sharing functionality."""
import datetime
import secrets
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.core.settings import settings
from application.api.user.base import resolve_tool_details
from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import generate_image_url
agents_sharing_ns = Namespace(
"agents", description="Agent management operations", path="/api"
)
def _serialize_agent_basic(agent: dict) -> dict:
"""Shape a PG agent row into the API response dict."""
source_id = agent.get("source_id")
return {
"id": str(agent["id"]),
"user": agent.get("user_id", ""),
"name": agent.get("name", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"description": agent.get("description", ""),
"source": str(source_id) if source_id else "",
"chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
"retriever": agent.get("retriever", "classic") or "classic",
"prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
"tools": agent.get("tools", []) or [],
"tool_details": resolve_tool_details(agent.get("tools", []) or []),
"agent_type": agent.get("agent_type", "") or "",
"status": agent.get("status", "") or "",
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
"created_at": agent.get("created_at", ""),
"updated_at": agent.get("updated_at", ""),
"shared": bool(agent.get("shared", False)),
"shared_token": agent.get("shared_token", "") or "",
"shared_metadata": agent.get("shared_metadata", {}) or {},
}
@agents_sharing_ns.route("/shared_agent")
class SharedAgent(Resource):
@api.doc(
params={
"token": "Shared token of the agent",
},
description="Get a shared agent by token or ID",
)
def get(self):
shared_token = request.args.get("token")
if not shared_token:
return make_response(
jsonify({"success": False, "message": "Token or ID is required"}), 400
)
try:
with db_readonly() as conn:
shared_agent = AgentsRepository(conn).find_by_shared_token(
shared_token,
)
if not shared_agent:
return make_response(
jsonify({"success": False, "message": "Shared agent not found"}),
404,
)
agent_id = str(shared_agent["id"])
data = _serialize_agent_basic(shared_agent)
if data["tools"]:
enriched_tools = []
for detail in data["tool_details"]:
enriched_tools.append(detail.get("name", ""))
data["tools"] = enriched_tools
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
owner_id = shared_agent.get("user_id")
if user_id != owner_id:
with db_session() as conn:
users_repo = UsersRepository(conn)
users_repo.upsert(user_id)
users_repo.add_shared(user_id, agent_id)
return make_response(jsonify(data), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agent: {err}")
return make_response(jsonify({"success": False}), 400)
@agents_sharing_ns.route("/shared_agents")
class SharedAgents(Resource):
@api.doc(description="Get shared agents explicitly shared with the user")
def get(self):
try:
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
with db_session() as conn:
users_repo = UsersRepository(conn)
user_doc = users_repo.upsert(user_id)
shared_with_ids = (
user_doc.get("agent_preferences", {}).get("shared_with_me", [])
if isinstance(user_doc.get("agent_preferences"), dict)
else []
)
# Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
if uuid_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM agents "
"WHERE id = ANY(CAST(:ids AS uuid[])) "
"AND shared = true"
),
{"ids": uuid_ids},
)
shared_agents = [dict(row._mapping) for row in result.fetchall()]
else:
shared_agents = []
found_ids_set = {str(agent["id"]) for agent in shared_agents}
stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
stale_ids.extend(non_uuid_ids)
if stale_ids:
users_repo.remove_shared_bulk(user_id, stale_ids)
pinned_ids = set(
user_doc.get("agent_preferences", {}).get("pinned", [])
if isinstance(user_doc.get("agent_preferences"), dict)
else []
)
list_shared_agents = []
for agent in shared_agents:
agent_id_str = str(agent["id"])
list_shared_agents.append(
{
"id": agent_id_str,
"name": agent.get("name", ""),
"description": agent.get("description", ""),
"image": (
generate_image_url(agent["image"]) if agent.get("image") else ""
),
"tools": agent.get("tools", []) or [],
"tool_details": resolve_tool_details(
agent.get("tools", []) or []
),
"agent_type": agent.get("agent_type", "") or "",
"status": agent.get("status", "") or "",
"json_schema": agent.get("json_schema"),
"limited_token_mode": agent.get("limited_token_mode", False),
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
"limited_request_mode": agent.get("limited_request_mode", False),
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
"created_at": agent.get("created_at", ""),
"updated_at": agent.get("updated_at", ""),
"pinned": agent_id_str in pinned_ids,
"shared": bool(agent.get("shared", False)),
"shared_token": agent.get("shared_token", "") or "",
"shared_metadata": agent.get("shared_metadata", {}) or {},
}
)
return make_response(jsonify(list_shared_agents), 200)
except Exception as err:
current_app.logger.error(f"Error retrieving shared agents: {err}")
return make_response(jsonify({"success": False}), 400)
@agents_sharing_ns.route("/share_agent")
class ShareAgent(Resource):
@api.expect(
api.model(
"ShareAgentModel",
{
"id": fields.String(required=True, description="ID of the agent"),
"shared": fields.Boolean(
required=True, description="Share or unshare the agent"
),
"username": fields.String(
required=False, description="Name of the user"
),
},
)
)
@api.doc(description="Share or unshare an agent")
def put(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
if not data:
return make_response(
jsonify({"success": False, "message": "Missing JSON body"}), 400
)
agent_id = data.get("id")
shared = data.get("shared")
username = data.get("username", "")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
if shared is None:
return make_response(
jsonify(
{
"success": False,
"message": "Shared parameter is required and must be true or false",
}
),
400,
)
shared_token = None
try:
with db_session() as conn:
repo = AgentsRepository(conn)
agent = repo.get_any(agent_id, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
if shared:
shared_metadata = {
"shared_by": username,
"shared_at": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
}
shared_token = secrets.token_urlsafe(32)
repo.update(
str(agent["id"]), user,
{
"shared": True,
"shared_token": shared_token,
"shared_metadata": shared_metadata,
},
)
else:
repo.update(
str(agent["id"]), user,
{
"shared": False,
"shared_token": None,
"shared_metadata": None,
},
)
except Exception as err:
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
return make_response(
jsonify({"success": True, "shared_token": shared_token}), 200
)

View File

@@ -1,120 +0,0 @@
"""Agent management webhook handlers."""
import secrets
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.api import api
from application.api.user.base import require_agent
from application.api.user.tasks import process_agent_webhook
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session
agents_webhooks_ns = Namespace(
"agents", description="Agent management operations", path="/api"
)
@agents_webhooks_ns.route("/agent_webhook")
class AgentWebhook(Resource):
@api.doc(
params={"id": "ID of the agent"},
description="Generate webhook URL for the agent",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
agent_id = request.args.get("id")
if not agent_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
with db_readonly() as conn:
agent = AgentsRepository(conn).get_any(agent_id, user)
if not agent:
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
webhook_token = agent.get("incoming_webhook_token")
if not webhook_token:
webhook_token = secrets.token_urlsafe(32)
with db_session() as conn:
AgentsRepository(conn).update(
str(agent["id"]), user,
{"incoming_webhook_token": webhook_token},
)
base_url = settings.API_URL.rstrip("/")
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
except Exception as err:
current_app.logger.error(
f"Error generating webhook URL: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Error generating webhook URL"}),
400,
)
return make_response(
jsonify({"success": True, "webhook_url": full_webhook_url}), 200
)
@agents_webhooks_ns.route("/webhooks/agents/<string:webhook_token>")
class AgentWebhookListener(Resource):
method_decorators = [require_agent]
def _enqueue_webhook_task(self, agent_id_str, payload, source_method):
if not payload:
current_app.logger.warning(
f"Webhook ({source_method}) received for agent {agent_id_str} with empty payload."
)
current_app.logger.info(
f"Incoming {source_method} webhook for agent {agent_id_str}. Enqueuing task with payload: {payload}"
)
try:
task = process_agent_webhook.delay(
agent_id=agent_id_str,
payload=payload,
)
current_app.logger.info(
f"Task {task.id} enqueued for agent {agent_id_str} ({source_method})."
)
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
except Exception as err:
current_app.logger.error(
f"Error enqueuing webhook task ({source_method}) for agent {agent_id_str}: {err}",
exc_info=True,
)
return make_response(
jsonify({"success": False, "message": "Error processing webhook"}), 500
)
@api.doc(
description="Webhook listener for agent events (POST). Expects JSON payload, which is used to trigger processing.",
)
def post(self, webhook_token, agent, agent_id_str):
payload = request.get_json()
if payload is None:
return make_response(
jsonify(
{
"success": False,
"message": "Invalid or missing JSON data in request body",
}
),
400,
)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="POST")
@api.doc(
description="Webhook listener for agent events (GET). Uses URL query parameters as payload to trigger processing.",
)
def get(self, webhook_token, agent, agent_id_str):
payload = request.args.to_dict(flat=True)
return self._enqueue_webhook_task(agent_id_str, payload, source_method="GET")

View File

@@ -1,5 +0,0 @@
"""Analytics module."""
from .routes import analytics_ns
__all__ = ["analytics_ns"]

View File

@@ -1,437 +0,0 @@
"""Analytics and reporting routes."""
import datetime
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from sqlalchemy import text as _sql_text
from application.api import api
from application.api.user.base import (
generate_date_range,
generate_hourly_range,
generate_minute_range,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.repositories.token_usage import TokenUsageRepository
from application.storage.db.repositories.user_logs import UserLogsRepository
from application.storage.db.session import db_readonly
analytics_ns = Namespace(
"analytics", description="Analytics and reporting operations", path="/api"
)
_FILTER_BUCKETS = {
"last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
"last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
"last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
"last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
}
def _range_for_filter(filter_option: str):
"""Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.
Returns ``None`` on invalid filter.
"""
if filter_option not in _FILTER_BUCKETS:
return None
end_date = datetime.datetime.now(datetime.timezone.utc)
bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]
if filter_option == "last_hour":
start_date = end_date - datetime.timedelta(hours=1)
elif filter_option == "last_24_hour":
start_date = end_date - datetime.timedelta(hours=24)
else:
days = {
"last_7_days": 6,
"last_15_days": 14,
"last_30_days": 29,
}[filter_option]
start_date = end_date - datetime.timedelta(days=days)
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
end_date = end_date.replace(
hour=23, minute=59, second=59, microsecond=999999
)
return start_date, end_date, bucket_unit, pg_fmt
def _intervals_for_filter(filter_option, start_date, end_date):
if filter_option == "last_hour":
return generate_minute_range(start_date, end_date)
if filter_option == "last_24_hour":
return generate_hourly_range(start_date, end_date)
return generate_date_range(start_date, end_date)
def _resolve_api_key(conn, api_key_id, user_id):
"""Look up the ``agents.key`` value for a given agent id.
Scoped by ``user_id`` so an authenticated caller can't probe another
user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
"""
if not api_key_id:
return None
agent = AgentsRepository(conn).get_any(api_key_id, user_id)
return (agent or {}).get("key") if agent else None
@analytics_ns.route("/get_message_analytics")
class GetMessageAnalytics(Resource):
get_message_analytics_model = api.model(
"GetMessageAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=list(_FILTER_BUCKETS.keys()),
),
},
)
@api.expect(get_message_analytics_model)
@api.doc(description="Get message analytics based on filter option")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, _bucket_unit, pg_fmt = window
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# Count messages per bucket, filtered by the conversation's
# owner (user_id) and optionally the agent api_key. The
# ``user_id`` filter is always applied post-cutover to
# prevent cross-tenant leakage on admin dashboards.
clauses = [
"c.user_id = :user_id",
"m.timestamp >= :start",
"m.timestamp <= :end",
]
params: dict = {
"user_id": user,
"start": start_date,
"end": end_date,
"fmt": pg_fmt,
}
if api_key:
clauses.append("c.api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
sql = (
"SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
"COUNT(*) AS count "
"FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
f"WHERE {where} "
"GROUP BY bucket ORDER BY bucket ASC"
)
rows = conn.execute(_sql_text(sql), params).fetchall()
intervals = _intervals_for_filter(filter_option, start_date, end_date)
daily_messages = {interval: 0 for interval in intervals}
for row in rows:
daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
except Exception as err:
current_app.logger.error(
f"Error getting message analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "messages": daily_messages}), 200
)
@analytics_ns.route("/get_token_analytics")
class GetTokenAnalytics(Resource):
get_token_analytics_model = api.model(
"GetTokenAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=list(_FILTER_BUCKETS.keys()),
),
},
)
@api.expect(get_token_analytics_model)
@api.doc(description="Get token analytics data")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, bucket_unit, _pg_fmt = window
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# ``bucketed_totals`` applies user_id / api_key filters
# directly — no need to reshape a Mongo pipeline.
rows = TokenUsageRepository(conn).bucketed_totals(
bucket_unit=bucket_unit,
user_id=user,
api_key=api_key,
timestamp_gte=start_date,
timestamp_lt=end_date,
)
intervals = _intervals_for_filter(filter_option, start_date, end_date)
daily_token_usage = {interval: 0 for interval in intervals}
for entry in rows:
daily_token_usage[entry["bucket"]] = int(
entry["prompt_tokens"] + entry["generated_tokens"]
)
except Exception as err:
current_app.logger.error(
f"Error getting token analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "token_usage": daily_token_usage}), 200
)
@analytics_ns.route("/get_feedback_analytics")
class GetFeedbackAnalytics(Resource):
get_feedback_analytics_model = api.model(
"GetFeedbackAnalyticsModel",
{
"api_key_id": fields.String(required=False, description="API Key ID"),
"filter_option": fields.String(
required=False,
description="Filter option for analytics",
default="last_30_days",
enum=list(_FILTER_BUCKETS.keys()),
),
},
)
@api.expect(get_feedback_analytics_model)
@api.doc(description="Get feedback analytics data")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
api_key_id = data.get("api_key_id")
filter_option = data.get("filter_option", "last_30_days")
window = _range_for_filter(filter_option)
if window is None:
return make_response(
jsonify({"success": False, "message": "Invalid option"}), 400
)
start_date, end_date, _bucket_unit, pg_fmt = window
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
# Feedback lives inside the ``conversation_messages.feedback``
# JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
# There is no scalar ``feedback_timestamp`` column — extract
# the timestamp from the JSONB and cast it to timestamptz for
# the range filter + bucket grouping.
clauses = [
"c.user_id = :user_id",
"m.feedback IS NOT NULL",
"(m.feedback->>'timestamp')::timestamptz >= :start",
"(m.feedback->>'timestamp')::timestamptz <= :end",
]
params: dict = {
"user_id": user,
"start": start_date,
"end": end_date,
"fmt": pg_fmt,
}
if api_key:
clauses.append("c.api_key = :api_key")
params["api_key"] = api_key
where = " AND ".join(clauses)
sql = (
"SELECT to_char("
"(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
") AS bucket, "
"SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
"SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
"FROM conversation_messages m "
"JOIN conversations c ON c.id = m.conversation_id "
f"WHERE {where} "
"GROUP BY bucket ORDER BY bucket ASC"
)
rows = conn.execute(_sql_text(sql), params).fetchall()
intervals = _intervals_for_filter(filter_option, start_date, end_date)
daily_feedback = {
interval: {"positive": 0, "negative": 0} for interval in intervals
}
for row in rows:
bucket = row._mapping["bucket"]
daily_feedback[bucket] = {
"positive": int(row._mapping["positive"] or 0),
"negative": int(row._mapping["negative"] or 0),
}
except Exception as err:
current_app.logger.error(
f"Error getting feedback analytics: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify({"success": True, "feedback": daily_feedback}), 200
)
@analytics_ns.route("/get_user_logs")
class GetUserLogs(Resource):
get_user_logs_model = api.model(
"GetUserLogsModel",
{
"page": fields.Integer(
required=False,
description="Page number for pagination",
default=1,
),
"api_key_id": fields.String(required=False, description="API Key ID"),
"page_size": fields.Integer(
required=False,
description="Number of logs per page",
default=10,
),
},
)
@api.expect(get_user_logs_model)
@api.doc(description="Get user logs with pagination")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json() or {}
page = int(data.get("page", 1))
api_key_id = data.get("api_key_id")
page_size = int(data.get("page_size", 10))
try:
with db_readonly() as conn:
api_key = _resolve_api_key(conn, api_key_id, user)
logs_repo = UserLogsRepository(conn)
if api_key:
# ``find_by_api_key`` filters on ``data->>'api_key'``
# — the PG shape of the legacy top-level ``api_key``
# filter. Paginate client-side using offset/limit.
all_rows = logs_repo.find_by_api_key(api_key)
offset = (page - 1) * page_size
window = all_rows[offset: offset + page_size + 1]
items = window
else:
items, has_more_flag = logs_repo.list_paginated(
user_id=user,
page=page,
page_size=page_size,
)
# list_paginated already trims to page_size and
# returns has_more separately.
results = [
{
"id": str(item.get("id") or item.get("_id")),
"action": (item.get("data") or {}).get("action"),
"level": (item.get("data") or {}).get("level"),
"user": item.get("user_id"),
"question": (item.get("data") or {}).get("question"),
"sources": (item.get("data") or {}).get("sources"),
"retriever_params": (item.get("data") or {}).get(
"retriever_params"
),
"timestamp": (
item["timestamp"].isoformat()
if hasattr(item.get("timestamp"), "isoformat")
else item.get("timestamp")
),
}
for item in items
]
return make_response(
jsonify(
{
"success": True,
"logs": results,
"page": page,
"page_size": page_size,
"has_more": has_more_flag,
}
),
200,
)
has_more = len(items) > page_size
items = items[:page_size]
results = [
{
"id": str(item.get("id") or item.get("_id")),
"action": (item.get("data") or {}).get("action"),
"level": (item.get("data") or {}).get("level"),
"user": item.get("user_id"),
"question": (item.get("data") or {}).get("question"),
"sources": (item.get("data") or {}).get("sources"),
"retriever_params": (item.get("data") or {}).get(
"retriever_params"
),
"timestamp": (
item["timestamp"].isoformat()
if hasattr(item.get("timestamp"), "isoformat")
else item.get("timestamp")
),
}
for item in items
]
except Exception as err:
current_app.logger.error(
f"Error getting user logs: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(
jsonify(
{
"success": True,
"logs": results,
"page": page,
"page_size": page_size,
"has_more": has_more,
}
),
200,
)

View File

@@ -1,5 +0,0 @@
"""Attachments module."""
from .routes import attachments_ns
__all__ = ["attachments_ns"]

View File

@@ -1,680 +0,0 @@
"""File attachments and media routes."""
import os
import tempfile
from pathlib import Path
import uuid
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.cache import get_redis_instance
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.stt.constants import (
SUPPORTED_AUDIO_EXTENSIONS,
SUPPORTED_AUDIO_MIME_TYPES,
)
from application.stt.upload_limits import (
AudioFileTooLargeError,
build_stt_file_size_limit_message,
enforce_audio_file_size_limit,
is_audio_filename,
)
from application.stt.live_session import (
apply_live_stt_hypothesis,
create_live_stt_session,
delete_live_stt_session,
finalize_live_stt_session,
get_live_stt_transcript_text,
load_live_stt_session,
save_live_stt_session,
)
from application.stt.stt_creator import STTCreator
from application.tts.tts_creator import TTSCreator
from application.utils import safe_filename
attachments_ns = Namespace(
"attachments", description="File attachments and media operations", path="/api"
)
def _resolve_authenticated_user():
decoded_token = getattr(request, "decoded_token", None)
api_key = request.form.get("api_key") or request.args.get("api_key")
if decoded_token:
return safe_filename(decoded_token.get("sub"))
if api_key:
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_key(api_key)
if not agent:
return make_response(
jsonify({"success": False, "message": "Invalid API key"}), 401
)
return safe_filename(agent.get("user_id"))
return None
def _get_uploaded_file_size(file) -> int:
try:
current_position = file.stream.tell()
file.stream.seek(0, os.SEEK_END)
size_bytes = file.stream.tell()
file.stream.seek(current_position)
return size_bytes
except Exception:
return 0
def _is_supported_audio_mimetype(mimetype: str) -> bool:
if not mimetype:
return True
normalized = mimetype.split(";")[0].strip().lower()
return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
if not is_audio_filename(filename):
return
size_bytes = _get_uploaded_file_size(file)
if size_bytes:
enforce_audio_file_size_limit(size_bytes)
def _get_store_attachment_user_error(exc: Exception) -> str:
if isinstance(exc, AudioFileTooLargeError):
return build_stt_file_size_limit_message()
return "Failed to process file"
def _require_live_stt_redis():
redis_client = get_redis_instance()
if redis_client:
return redis_client
return make_response(
jsonify({"success": False, "message": "Live transcription is unavailable"}),
503,
)
def _parse_bool_form_value(value: str | None) -> bool:
if value is None:
return False
return value.strip().lower() in {"1", "true", "yes", "on"}
@attachments_ns.route("/store_attachment")
class StoreAttachment(Resource):
@api.expect(
api.model(
"AttachmentModel",
{
"file": fields.Raw(required=True, description="File(s) to upload"),
"api_key": fields.String(
required=False, description="API key (optional)"
),
},
)
)
@api.doc(
description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
)
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
files = request.files.getlist("file")
if not files:
single_file = request.files.get("file")
if single_file:
files = [single_file]
if not files or all(f.filename == "" for f in files):
return make_response(
jsonify({"status": "error", "message": "Missing file(s)"}),
400,
)
user = auth_user
if not user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}), 401
)
try:
from application.api.user.tasks import store_attachment
from application.api.user.base import storage
tasks = []
errors = []
original_file_count = len(files)
for idx, file in enumerate(files):
try:
attachment_id = uuid.uuid4()
original_filename = safe_filename(os.path.basename(file.filename))
_enforce_uploaded_audio_size_limit(file, original_filename)
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
metadata = storage.save_file(file, relative_path)
file_info = {
"filename": original_filename,
"attachment_id": str(attachment_id),
"path": relative_path,
"metadata": metadata,
}
task = store_attachment.delay(file_info, user)
tasks.append({
"task_id": task.id,
"filename": original_filename,
"attachment_id": str(attachment_id),
"upload_index": idx,
})
except Exception as file_err:
current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
errors.append({
"upload_index": idx,
"filename": file.filename,
"error": _get_store_attachment_user_error(file_err),
})
if not tasks:
if errors and all(
error.get("error") == build_stt_file_size_limit_message()
for error in errors
):
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
"errors": errors,
}
),
413,
)
return make_response(
jsonify({"status": "error", "message": "No valid files to upload"}),
400,
)
if original_file_count == 1 and len(tasks) == 1:
current_app.logger.info("Returning single task_id response")
return make_response(
jsonify(
{
"success": True,
"task_id": tasks[0]["task_id"],
"message": "File uploaded successfully. Processing started.",
}
),
200,
)
else:
response_data = {
"success": True,
"tasks": tasks,
"message": f"{len(tasks)} file(s) uploaded successfully. Processing started.",
}
if errors:
response_data["errors"] = errors
response_data["message"] += f" {len(errors)} file(s) failed."
return make_response(
jsonify(response_data),
200,
)
except Exception as err:
current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
@attachments_ns.route("/stt")
class SpeechToText(Resource):
@api.expect(
api.model(
"SpeechToTextModel",
{
"file": fields.Raw(required=True, description="Audio file"),
"language": fields.String(
required=False, description="Optional transcription language hint"
),
},
)
)
@api.doc(description="Transcribe an uploaded audio file")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=request.form.get("language") or settings.STT_LANGUAGE,
timestamps=settings.STT_ENABLE_TIMESTAMPS,
diarize=settings.STT_ENABLE_DIARIZATION,
)
return make_response(jsonify({"success": True, **transcript}), 200)
except Exception as err:
current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/start")
class LiveSpeechToTextStart(Resource):
@api.doc(description="Start a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_state = create_live_stt_session(
user=auth_user,
language=payload.get("language") or settings.STT_LANGUAGE,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_state["session_id"],
"language": session_state.get("language"),
"committed_text": "",
"mutable_text": "",
"previous_hypothesis": "",
"latest_hypothesis": "",
"finalized_text": "",
"pending_text": "",
"transcript_text": "",
}
),
200,
)
@attachments_ns.route("/stt/live/chunk")
class LiveSpeechToTextChunk(Resource):
@api.expect(
api.model(
"LiveSpeechToTextChunkModel",
{
"session_id": fields.String(
required=True, description="Live transcription session ID"
),
"chunk_index": fields.Integer(
required=True, description="Sequential chunk index"
),
"is_silence": fields.Boolean(
required=False,
description="Whether the latest capture window was mostly silence",
),
"file": fields.Raw(required=True, description="Audio chunk"),
},
)
)
@api.doc(description="Transcribe a chunk for a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
session_id = request.form.get("session_id", "").strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
chunk_index_raw = request.form.get("chunk_index", "").strip()
if chunk_index_raw == "":
return make_response(
jsonify({"success": False, "message": "Missing chunk_index"}),
400,
)
try:
chunk_index = int(chunk_index_raw)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid chunk_index"}),
400,
)
is_silence = _parse_bool_form_value(request.form.get("is_silence"))
file = request.files.get("file")
if not file or file.filename == "":
return make_response(
jsonify({"success": False, "message": "Missing file"}),
400,
)
filename = safe_filename(os.path.basename(file.filename))
suffix = Path(filename).suffix.lower()
if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
return make_response(
jsonify({"success": False, "message": "Unsupported audio format"}),
400,
)
if not _is_supported_audio_mimetype(file.mimetype or ""):
return make_response(
jsonify({"success": False, "message": "Unsupported audio MIME type"}),
400,
)
try:
_enforce_uploaded_audio_size_limit(file, filename)
except AudioFileTooLargeError:
return make_response(
jsonify(
{
"success": False,
"message": build_stt_file_size_limit_message(),
}
),
413,
)
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
file.save(temp_file.name)
temp_path = Path(temp_file.name)
session_language = session_state.get("language") or settings.STT_LANGUAGE
stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
transcript = stt_instance.transcribe(
temp_path,
language=session_language,
timestamps=False,
diarize=False,
)
if not session_state.get("language") and transcript.get("language"):
session_state["language"] = transcript["language"]
try:
apply_live_stt_hypothesis(
session_state,
str(transcript.get("text", "")),
chunk_index,
is_silence=is_silence,
)
except ValueError:
current_app.logger.warning(
"Invalid live transcription chunk",
exc_info=True,
)
return make_response(
jsonify(
{
"success": False,
"message": "Invalid live transcription chunk",
}
),
409,
)
save_live_stt_session(redis_client, session_state)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"chunk_index": chunk_index,
"chunk_text": transcript.get("text", ""),
"is_silence": is_silence,
"language": session_state.get("language"),
"committed_text": session_state.get("committed_text", ""),
"mutable_text": session_state.get("mutable_text", ""),
"previous_hypothesis": session_state.get(
"previous_hypothesis", ""
),
"latest_hypothesis": session_state.get(
"latest_hypothesis", ""
),
"finalized_text": session_state.get("committed_text", ""),
"pending_text": session_state.get("mutable_text", ""),
"transcript_text": get_live_stt_transcript_text(session_state),
}
),
200,
)
except Exception as err:
current_app.logger.error(
f"Error transcribing live audio chunk: {err}", exc_info=True
)
return make_response(
jsonify({"success": False, "message": "Failed to transcribe audio"}),
400,
)
finally:
if temp_path and temp_path.exists():
temp_path.unlink()
@attachments_ns.route("/stt/live/finish")
class LiveSpeechToTextFinish(Resource):
@api.doc(description="Finish a live speech-to-text session")
def post(self):
auth_user = _resolve_authenticated_user()
if hasattr(auth_user, "status_code"):
return auth_user
if not auth_user:
return make_response(
jsonify({"success": False, "message": "Authentication required"}),
401,
)
redis_client = _require_live_stt_redis()
if hasattr(redis_client, "status_code"):
return redis_client
payload = request.get_json(silent=True) or {}
session_id = str(payload.get("session_id", "")).strip()
if not session_id:
return make_response(
jsonify({"success": False, "message": "Missing session_id"}),
400,
)
session_state = load_live_stt_session(redis_client, session_id)
if not session_state:
return make_response(
jsonify(
{
"success": False,
"message": "Live transcription session not found",
}
),
404,
)
if safe_filename(str(session_state.get("user", ""))) != auth_user:
return make_response(
jsonify({"success": False, "message": "Forbidden"}),
403,
)
final_text = finalize_live_stt_session(session_state)
delete_live_stt_session(redis_client, session_id)
return make_response(
jsonify(
{
"success": True,
"session_id": session_id,
"language": session_state.get("language"),
"text": final_text,
}
),
200,
)
@attachments_ns.route("/images/<path:image_path>")
class ServeImage(Resource):
@api.doc(description="Serve an image from storage")
def get(self, image_path):
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
try:
from application.api.user.base import storage
file_obj = storage.get_file(image_path)
extension = image_path.split(".")[-1].lower()
content_type = f"image/{extension}"
if extension == "jpg":
content_type = "image/jpeg"
response = make_response(file_obj.read())
response.headers.set("Content-Type", content_type)
response.headers.set("Cache-Control", "max-age=86400")
return response
except FileNotFoundError:
return make_response(
jsonify({"success": False, "message": "Image not found"}), 404
)
except ValueError:
return make_response(
jsonify({"success": False, "message": "Invalid image path"}), 400
)
except Exception as e:
current_app.logger.error(f"Error serving image: {e}")
return make_response(
jsonify({"success": False, "message": "Error retrieving image"}), 500
)
@attachments_ns.route("/tts")
class TextToSpeech(Resource):
tts_model = api.model(
"TextToSpeechModel",
{
"text": fields.String(
required=True, description="Text to be synthesized as audio"
),
},
)
@api.expect(tts_model)
@api.doc(description="Synthesize audio speech from text")
def post(self):
data = request.get_json()
text = data["text"]
try:
tts_instance = TTSCreator.create_tts(settings.TTS_PROVIDER)
audio_base64, detected_language = tts_instance.text_to_speech(text)
return make_response(
jsonify(
{
"success": True,
"audio_base64": audio_base64,
"lang": detected_language,
}
),
200,
)
except Exception as err:
current_app.logger.error(f"Error synthesizing audio: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)

View File

@@ -1,236 +0,0 @@
"""
Shared utilities, database connections, and helper functions for user API routes.
"""
import datetime
import os
import uuid
from functools import wraps
from typing import Optional, Tuple
from flask import current_app, jsonify, make_response, Response
from werkzeug.utils import secure_filename
from sqlalchemy import text as _sql_text
from application.core.settings import settings
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.repositories.users import UsersRepository
from application.storage.db.session import db_readonly, db_session
from application.storage.storage_creator import StorageCreator
from application.vectorstore.vector_creator import VectorCreator
storage = StorageCreator.get_storage()
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
def generate_minute_range(start_date, end_date):
"""Generate a dictionary with minute-level time ranges."""
return {
(start_date + datetime.timedelta(minutes=i)).strftime("%Y-%m-%d %H:%M:00"): 0
for i in range(int((end_date - start_date).total_seconds() // 60) + 1)
}
def generate_hourly_range(start_date, end_date):
"""Generate a dictionary with hourly time ranges."""
return {
(start_date + datetime.timedelta(hours=i)).strftime("%Y-%m-%d %H:00"): 0
for i in range(int((end_date - start_date).total_seconds() // 3600) + 1)
}
def generate_date_range(start_date, end_date):
"""Generate a dictionary with daily date ranges."""
return {
(start_date + datetime.timedelta(days=i)).strftime("%Y-%m-%d"): 0
for i in range((end_date - start_date).days + 1)
}
def ensure_user_doc(user_id):
"""
Ensure a Postgres ``users`` row exists for ``user_id``.
Returns the row as a dict with the shape legacy callers expect — in
particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
``shared_with_me`` list keys always present).
Args:
user_id: The user ID to ensure
Returns:
The user document as a dict.
"""
with db_session() as conn:
user_doc = UsersRepository(conn).upsert(user_id)
prefs = user_doc.get("agent_preferences") or {}
if not isinstance(prefs, dict):
prefs = {}
prefs.setdefault("pinned", [])
prefs.setdefault("shared_with_me", [])
user_doc["agent_preferences"] = prefs
return user_doc
def resolve_tool_details(tool_ids):
"""
Resolve tool IDs to their display details.
Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
lists are supported — each id is looked up via ``get_any``, which
resolves to whichever column matches). Unknown ids are silently
skipped.
Args:
tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).
Returns:
List of tool details with ``id``, ``name``, and ``display_name``.
"""
if not tool_ids:
return []
uuid_ids: list[str] = []
legacy_ids: list[str] = []
for tid in tool_ids:
if not tid:
continue
tid_str = str(tid)
if looks_like_uuid(tid_str):
uuid_ids.append(tid_str)
else:
legacy_ids.append(tid_str)
if not uuid_ids and not legacy_ids:
return []
rows: list[dict] = []
with db_readonly() as conn:
if uuid_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM user_tools "
"WHERE id = ANY(CAST(:ids AS uuid[]))"
),
{"ids": uuid_ids},
)
rows.extend(row_to_dict(r) for r in result.fetchall())
if legacy_ids:
result = conn.execute(
_sql_text(
"SELECT * FROM user_tools "
"WHERE legacy_mongo_id = ANY(:ids)"
),
{"ids": legacy_ids},
)
rows.extend(row_to_dict(r) for r in result.fetchall())
return [
{
"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
"name": tool.get("name", "") or "",
"display_name": (
tool.get("custom_name")
or tool.get("display_name")
or tool.get("name", "")
or ""
),
}
for tool in rows
]
def get_vector_store(source_id):
"""
Get the Vector Store for a given source ID.
Args:
source_id (str): source id of the document
Returns:
Vector store instance
"""
store = VectorCreator.create_vectorstore(
settings.VECTOR_STORE,
source_id=source_id,
embeddings_key=os.getenv("EMBEDDINGS_KEY"),
)
return store
def handle_image_upload(
request, existing_url: str, user: str, storage, base_path: str = "attachments/"
) -> Tuple[str, Optional[Response]]:
"""
Handle image file upload from request.
Args:
request: Flask request object
existing_url: Existing image URL (fallback)
user: User ID
storage: Storage instance
base_path: Base path for upload
Returns:
Tuple of (image_url, error_response)
"""
image_url = existing_url
if "image" in request.files:
file = request.files["image"]
if file.filename != "":
filename = secure_filename(file.filename)
upload_path = f"{settings.UPLOAD_FOLDER.rstrip('/')}/{user}/{base_path.rstrip('/')}/{uuid.uuid4()}_{filename}"
try:
storage.save_file(file, upload_path, storage_class="STANDARD")
image_url = upload_path
except Exception as e:
current_app.logger.error(f"Error uploading image: {e}")
return None, make_response(
jsonify({"success": False, "message": "Image upload failed"}),
400,
)
return image_url, None
def require_agent(func):
"""
Decorator to require valid agent webhook token.
Args:
func: Function to decorate
Returns:
Wrapped function
"""
@wraps(func)
def wrapper(*args, **kwargs):
from application.storage.db.repositories.agents import AgentsRepository
webhook_token = kwargs.get("webhook_token")
if not webhook_token:
return make_response(
jsonify({"success": False, "message": "Webhook token missing"}), 400
)
with db_readonly() as conn:
agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
if not agent:
current_app.logger.warning(
f"Webhook attempt with invalid token: {webhook_token}"
)
return make_response(
jsonify({"success": False, "message": "Agent not found"}), 404
)
kwargs["agent"] = agent
kwargs["agent_id_str"] = str(agent["id"])
return func(*args, **kwargs)
return wrapper

View File

@@ -1,5 +0,0 @@
"""Conversation management module."""
from .routes import conversations_ns
__all__ = ["conversations_ns"]

View File

@@ -1,303 +0,0 @@
"""Conversation management routes."""
import datetime
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.storage.db.repositories.attachments import AttachmentsRepository
from application.storage.db.repositories.conversations import ConversationsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
conversations_ns = Namespace(
"conversations", description="Conversation management operations", path="/api"
)
@conversations_ns.route("/delete_conversation")
class DeleteConversation(Resource):
@api.doc(
description="Deletes a conversation by ID",
params={"id": "The ID of the conversation to delete"},
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
user_id = decoded_token["sub"]
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(conversation_id, user_id)
if conv is not None:
repo.delete(str(conv["id"]), user_id)
except Exception as err:
current_app.logger.error(
f"Error deleting conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/delete_all_conversations")
class DeleteAllConversations(Resource):
@api.doc(
description="Deletes all conversations for a specific user",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
with db_session() as conn:
ConversationsRepository(conn).delete_all_for_user(user_id)
except Exception as err:
current_app.logger.error(
f"Error deleting all conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/get_conversations")
class GetConversations(Resource):
@api.doc(
description="Retrieve a list of the latest 30 conversations (excluding API key conversations)",
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
try:
with db_readonly() as conn:
conversations = ConversationsRepository(conn).list_for_user(
user_id, limit=30
)
list_conversations = [
{
"id": str(conversation["id"]),
"name": conversation["name"],
"agent_id": (
str(conversation["agent_id"])
if conversation.get("agent_id")
else None
),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
for conversation in conversations
]
except Exception as err:
current_app.logger.error(
f"Error retrieving conversations: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_conversations), 200)
@conversations_ns.route("/get_single_conversation")
class GetSingleConversation(Resource):
@api.doc(
description="Retrieve a single conversation by ID",
params={"id": "The conversation ID"},
)
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
conversation_id = request.args.get("id")
if not conversation_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
user_id = decoded_token.get("sub")
try:
with db_readonly() as conn:
repo = ConversationsRepository(conn)
conversation = repo.get_any(conversation_id, user_id)
if not conversation:
return make_response(jsonify({"status": "not found"}), 404)
conv_pg_id = str(conversation["id"])
messages = repo.get_messages(conv_pg_id)
# Resolve attachment details (id, fileName) for each message.
attachments_repo = AttachmentsRepository(conn)
queries = []
for msg in messages:
query = {
"prompt": msg.get("prompt"),
"response": msg.get("response"),
"thought": msg.get("thought"),
"sources": msg.get("sources") or [],
"tool_calls": msg.get("tool_calls") or [],
"timestamp": msg.get("timestamp"),
"model_id": msg.get("model_id"),
}
if msg.get("metadata"):
query["metadata"] = msg["metadata"]
# Feedback on conversation_messages is a JSONB blob with
# shape {"text": <str>, "timestamp": <iso>}. The legacy
# frontend consumed a flat scalar feedback string, so
# unwrap the ``text`` field for compat.
feedback = msg.get("feedback")
if feedback is not None:
if isinstance(feedback, dict):
query["feedback"] = feedback.get("text")
if feedback.get("timestamp"):
query["feedback_timestamp"] = feedback["timestamp"]
else:
query["feedback"] = feedback
attachments = msg.get("attachments") or []
if attachments:
attachment_details = []
for attachment_id in attachments:
try:
att = attachments_repo.get_any(
str(attachment_id), user_id
)
if att:
attachment_details.append(
{
"id": str(att["id"]),
"fileName": att.get(
"filename", "Unknown file"
),
}
)
except Exception as e:
current_app.logger.error(
f"Error retrieving attachment {attachment_id}: {e}",
exc_info=True,
)
query["attachments"] = attachment_details
queries.append(query)
except Exception as err:
current_app.logger.error(
f"Error retrieving conversation: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
data = {
"queries": queries,
"agent_id": (
str(conversation["agent_id"]) if conversation.get("agent_id") else None
),
"is_shared_usage": conversation.get("is_shared_usage", False),
"shared_token": conversation.get("shared_token", None),
}
return make_response(jsonify(data), 200)
@conversations_ns.route("/update_conversation_name")
class UpdateConversationName(Resource):
@api.expect(
api.model(
"UpdateConversationModel",
{
"id": fields.String(required=True, description="Conversation ID"),
"name": fields.String(
required=True, description="New name of the conversation"
),
},
)
)
@api.doc(
description="Updates the name of a conversation",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["id", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user_id = decoded_token.get("sub")
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(data["id"], user_id)
if conv is not None:
repo.rename(str(conv["id"]), user_id, data["name"])
except Exception as err:
current_app.logger.error(
f"Error updating conversation name: {err}", exc_info=True
)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@conversations_ns.route("/feedback")
class SubmitFeedback(Resource):
@api.expect(
api.model(
"FeedbackModel",
{
"question": fields.String(
required=False, description="The user question"
),
"answer": fields.String(required=False, description="The AI answer"),
"feedback": fields.String(required=True, description="User feedback"),
"question_index": fields.Integer(
required=True,
description="The question number in that particular conversation",
),
"conversation_id": fields.String(
required=True, description="id of the particular conversation"
),
"api_key": fields.String(description="Optional API key"),
},
)
)
@api.doc(
description="Submit feedback for a conversation",
)
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["feedback", "conversation_id", "question_index"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user_id = decoded_token.get("sub")
feedback_value = data["feedback"]
question_index = int(data["question_index"])
# Normalize string feedback to lowercase so analytics queries
# (which match 'like'/'dislike') count rows correctly. Tolerate
# legacy uppercase clients on ingest. Non-string values pass through.
if isinstance(feedback_value, str):
feedback_value = feedback_value.lower()
feedback_payload = (
None
if feedback_value is None
else {
"text": feedback_value,
"timestamp": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
}
)
try:
with db_session() as conn:
repo = ConversationsRepository(conn)
conv = repo.get_any(data["conversation_id"], user_id)
if conv is None:
return make_response(
jsonify({"success": False, "message": "Not found"}), 404
)
repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
except Exception as err:
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)

View File

@@ -1,3 +0,0 @@
from .routes import models_ns
__all__ = ["models_ns"]

View File

@@ -1,521 +0,0 @@
"""Model routes.
- ``GET /api/models`` — list available models for the current user.
Combines the built-in catalog with the user's BYOM records.
- ``GET/POST/PATCH/DELETE /api/user/models[/<id>]`` — CRUD for the
user's own OpenAI-compatible model registrations (BYOM).
- ``POST /api/user/models/<id>/test`` — sanity-check the upstream
endpoint with a tiny request.
Every BYOM endpoint is user-scoped at the repository layer
(every query filters on ``user_id`` from ``request.decoded_token``).
"""
from __future__ import annotations
import logging
import requests
from flask import current_app, jsonify, make_response, request
from flask_restx import Namespace, Resource
from application.api import api
from application.core.model_registry import ModelRegistry
from application.security.safe_url import (
UnsafeUserUrlError,
pinned_post,
validate_user_base_url,
)
from application.storage.db.repositories.user_custom_models import (
UserCustomModelsRepository,
)
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
logger = logging.getLogger(__name__)
models_ns = Namespace("models", description="Available models", path="/api")
_CONTEXT_WINDOW_MIN = 1_000
_CONTEXT_WINDOW_MAX = 10_000_000
def _user_id_or_401():
decoded_token = request.decoded_token
if not decoded_token:
return None, make_response(jsonify({"success": False}), 401)
user_id = decoded_token.get("sub")
if not user_id:
return None, make_response(jsonify({"success": False}), 401)
return user_id, None
def _normalize_capabilities(raw) -> dict:
"""Coerce + bound the user-supplied capabilities payload."""
raw = raw or {}
out = {}
if "supports_tools" in raw:
out["supports_tools"] = bool(raw["supports_tools"])
if "supports_structured_output" in raw:
out["supports_structured_output"] = bool(raw["supports_structured_output"])
if "supports_streaming" in raw:
out["supports_streaming"] = bool(raw["supports_streaming"])
if "attachments" in raw:
atts = raw["attachments"] or []
if not isinstance(atts, list):
raise ValueError("'capabilities.attachments' must be a list")
coerced = [str(a) for a in atts]
# Reject unknown aliases at the API boundary so bad payloads
# never reach the registry layer (where lenient expansion just
# drops them). Raw MIME types (containing ``/``) pass through
# unchanged for parity with the built-in YAML schema.
from application.core.model_yaml import builtin_attachment_aliases
aliases = builtin_attachment_aliases()
for entry in coerced:
if "/" in entry:
continue
if entry not in aliases:
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
raise ValueError(
f"unknown attachment alias '{entry}' in "
f"'capabilities.attachments'. Valid aliases: {valid}, "
f"or use a raw MIME type like 'image/png'."
)
out["attachments"] = coerced
if "context_window" in raw:
try:
cw = int(raw["context_window"])
except (TypeError, ValueError):
raise ValueError("'capabilities.context_window' must be an integer")
if not (_CONTEXT_WINDOW_MIN <= cw <= _CONTEXT_WINDOW_MAX):
raise ValueError(
f"'capabilities.context_window' must be between "
f"{_CONTEXT_WINDOW_MIN} and {_CONTEXT_WINDOW_MAX}"
)
out["context_window"] = cw
return out
def _row_to_response(row: dict) -> dict:
"""Wire-format projection — never includes the API key."""
return {
"id": str(row["id"]),
"upstream_model_id": row["upstream_model_id"],
"display_name": row["display_name"],
"description": row.get("description") or "",
"base_url": row["base_url"],
"capabilities": row.get("capabilities") or {},
"enabled": bool(row.get("enabled", True)),
"source": "user",
}
@models_ns.route("/models")
class ModelsListResource(Resource):
def get(self):
"""Get list of available models with their capabilities.
When the request is authenticated, the response includes the
user's own BYOM registrations alongside the built-in catalog.
"""
try:
user_id = None
decoded_token = getattr(request, "decoded_token", None)
if decoded_token:
user_id = decoded_token.get("sub")
registry = ModelRegistry.get_instance()
models = registry.get_enabled_models(user_id=user_id)
response = {
"models": [model.to_dict() for model in models],
"default_model_id": registry.default_model_id,
"count": len(models),
}
except Exception as err:
current_app.logger.error(f"Error fetching models: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 500)
return make_response(jsonify(response), 200)
@models_ns.route("/user/models")
class UserModelsCollectionResource(Resource):
@api.doc(description="List the current user's BYOM custom models")
def get(self):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
rows = UserCustomModelsRepository(conn).list_for_user(user_id)
return make_response(
jsonify({"models": [_row_to_response(r) for r in rows]}), 200
)
except Exception as e:
current_app.logger.error(
f"Error listing user custom models: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
@api.doc(description="Register a new BYOM custom model")
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data,
["upstream_model_id", "display_name", "base_url", "api_key"],
)
if missing:
return missing
# SECURITY: reject blank api_key — would leak instance API key
# to the user-supplied base_url via LLMCreator fallback.
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
"api_key",
):
value = data.get(required_nonblank)
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' must be a non-empty string",
}
),
400,
)
# SSRF guard at create time. Re-runs at dispatch time (LLMCreator)
# as defense in depth against DNS rebinding and pre-guard rows.
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
capabilities = _normalize_capabilities(data.get("capabilities"))
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
try:
with db_session() as conn:
row = UserCustomModelsRepository(conn).create(
user_id=user_id,
upstream_model_id=data["upstream_model_id"],
display_name=data["display_name"],
description=data.get("description") or "",
base_url=data["base_url"],
api_key_plaintext=data["api_key"],
capabilities=capabilities,
enabled=bool(data.get("enabled", True)),
)
except Exception as e:
current_app.logger.error(
f"Error creating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify(_row_to_response(row)), 201)
@models_ns.route("/user/models/<string:model_id>")
class UserModelResource(Resource):
@api.doc(description="Get one BYOM custom model")
def get(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error fetching user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if row is None:
return make_response(jsonify({"success": False}), 404)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Update a BYOM custom model (partial)")
def patch(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Reject present-but-blank values for fields where blank doesn't
# mean "no change". (The api_key special case — blank means "keep
# existing" — is handled below.)
for required_nonblank in (
"upstream_model_id",
"display_name",
"base_url",
):
if required_nonblank in data:
value = data[required_nonblank]
if not isinstance(value, str) or not value.strip():
return make_response(
jsonify(
{
"success": False,
"error": f"'{required_nonblank}' cannot be blank",
}
),
400,
)
if "base_url" in data and data["base_url"]:
try:
validate_user_base_url(data["base_url"])
except UnsafeUserUrlError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
update_fields: dict = {}
for k in (
"upstream_model_id",
"display_name",
"description",
"base_url",
"enabled",
):
if k in data:
update_fields[k] = data[k]
if "capabilities" in data:
try:
update_fields["capabilities"] = _normalize_capabilities(
data["capabilities"]
)
except ValueError as e:
return make_response(
jsonify({"success": False, "error": str(e)}), 400
)
# PATCH semantics: blank/missing api_key → keep the existing
# ciphertext; non-empty api_key → re-encrypt and replace.
if data.get("api_key"):
update_fields["api_key_plaintext"] = data["api_key"]
if not update_fields:
return make_response(
jsonify({"success": False, "error": "no updatable fields"}), 400
)
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).update(
model_id, user_id, update_fields
)
except Exception as e:
current_app.logger.error(
f"Error updating user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
with db_readonly() as conn:
row = UserCustomModelsRepository(conn).get(model_id, user_id)
return make_response(jsonify(_row_to_response(row)), 200)
@api.doc(description="Delete a BYOM custom model")
def delete(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
try:
with db_session() as conn:
ok = UserCustomModelsRepository(conn).delete(model_id, user_id)
except Exception as e:
current_app.logger.error(
f"Error deleting user custom model: {e}", exc_info=True
)
return make_response(jsonify({"success": False}), 500)
if not ok:
return make_response(jsonify({"success": False}), 404)
ModelRegistry.invalidate_user(user_id)
return make_response(jsonify({"success": True}), 200)
def _run_connection_test(
base_url: str, api_key: str, upstream_model_id: str
):
"""Send a 1-token chat-completion to verify a BYOM endpoint.
Returns ``(body, http_status)``. Upstream errors return 200 with
``ok=False`` so the UI can render inline errors; only local SSRF
rejection returns 400.
"""
url = base_url.rstrip("/") + "/chat/completions"
payload = {
"model": upstream_model_id,
"messages": [{"role": "user", "content": "hi"}],
"max_tokens": 1,
"stream": False,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
try:
# pinned_post closes the DNS-rebinding window. Redirects off
# because 3xx could bounce to an internal address (the SSRF
# guard only validates the supplied URL).
resp = pinned_post(
url,
json=payload,
headers=headers,
timeout=5,
allow_redirects=False,
)
except UnsafeUserUrlError as e:
return {"ok": False, "error": str(e)}, 400
except requests.RequestException as e:
return {"ok": False, "error": f"connection error: {e}"}, 200
if 300 <= resp.status_code < 400:
return (
{
"ok": False,
"error": (
f"upstream returned HTTP {resp.status_code} "
"redirect; refusing to follow"
),
},
200,
)
if resp.status_code >= 400:
# Cap and only reflect JSON to avoid body-exfil via non-API responses.
content_type = (resp.headers.get("Content-Type") or "").lower()
if "application/json" in content_type:
text = (resp.text or "")[:500]
error_msg = f"upstream returned HTTP {resp.status_code}: {text}"
else:
error_msg = f"upstream returned HTTP {resp.status_code}"
return {"ok": False, "error": error_msg}, 200
return {"ok": True}, 200
@models_ns.route("/user/models/test")
class UserModelTestPayloadResource(Resource):
@api.doc(
description=(
"Test an arbitrary BYOM payload (display_name / model id / "
"base_url / api_key) without saving. Used by the UI's 'Test "
"connection' button so the user can validate before they "
"Save. Same SSRF guard, same 1-token request, same 5s "
"timeout as the by-id variant."
)
)
def post(self):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
missing = check_required_fields(
data, ["base_url", "api_key", "upstream_model_id"]
)
if missing:
return missing
body, status = _run_connection_test(
data["base_url"], data["api_key"], data["upstream_model_id"]
)
return make_response(jsonify(body), status)
@models_ns.route("/user/models/<string:model_id>/test")
class UserModelTestResource(Resource):
@api.doc(
description=(
"Test a saved BYOM record. Defaults to the stored "
"base_url / upstream_model_id / encrypted api_key, but "
"any of those can be overridden via the request body so "
"the UI can test in-flight edits before saving. Used by "
"the 'Test connection' button in edit mode."
)
)
def post(self, model_id):
user_id, err = _user_id_or_401()
if err:
return err
data = request.get_json() or {}
# Per-field overrides; blank/missing falls back to stored value.
override_base_url = (data.get("base_url") or "").strip() or None
override_upstream_model_id = (
data.get("upstream_model_id") or ""
).strip() or None
override_api_key = (data.get("api_key") or "").strip() or None
try:
with db_readonly() as conn:
repo = UserCustomModelsRepository(conn)
row = repo.get(model_id, user_id)
if row is None:
return make_response(jsonify({"success": False}), 404)
stored_api_key = (
repo._decrypt_api_key(
row.get("api_key_encrypted", ""), user_id
)
if not override_api_key
else None
)
except Exception as e:
current_app.logger.error(
f"Error loading user custom model for test: {e}", exc_info=True
)
return make_response(
jsonify({"ok": False, "error": "internal error loading model"}),
500,
)
api_key = override_api_key or stored_api_key
if not api_key:
return make_response(
jsonify(
{
"ok": False,
"error": (
"Stored API key could not be decrypted. The "
"encryption secret may have rotated. Re-save "
"the model with the API key to recover."
),
}
),
400,
)
base_url = override_base_url or row["base_url"]
upstream_model_id = (
override_upstream_model_id or row["upstream_model_id"]
)
body, status = _run_connection_test(
base_url, api_key, upstream_model_id
)
return make_response(jsonify(body), status)

View File

@@ -1,5 +0,0 @@
"""Prompts module."""
from .routes import prompts_ns
__all__ = ["prompts_ns"]

View File

@@ -1,202 +0,0 @@
"""Prompt management routes."""
import os
from flask import current_app, jsonify, make_response, request
from flask_restx import fields, Namespace, Resource
from application.api import api
from application.api.user.base import current_dir
from application.storage.db.repositories.prompts import PromptsRepository
from application.storage.db.session import db_readonly, db_session
from application.utils import check_required_fields
prompts_ns = Namespace(
"prompts", description="Prompt management operations", path="/api"
)
@prompts_ns.route("/create_prompt")
class CreatePrompt(Resource):
create_prompt_model = api.model(
"CreatePromptModel",
{
"content": fields.String(
required=True, description="Content of the prompt"
),
"name": fields.String(required=True, description="Name of the prompt"),
},
)
@api.expect(create_prompt_model)
@api.doc(description="Create a new prompt")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
data = request.get_json()
required_fields = ["content", "name"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
user = decoded_token.get("sub")
try:
with db_session() as conn:
prompt = PromptsRepository(conn).create(user, data["name"], data["content"])
new_id = str(prompt["id"])
except Exception as err:
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"id": new_id}), 200)
@prompts_ns.route("/get_prompts")
class GetPrompts(Resource):
@api.doc(description="Get all prompts for the user")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
try:
with db_readonly() as conn:
prompts = PromptsRepository(conn).list_for_user(user)
list_prompts = [
{"id": "default", "name": "default", "type": "public"},
{"id": "creative", "name": "creative", "type": "public"},
{"id": "strict", "name": "strict", "type": "public"},
]
for prompt in prompts:
list_prompts.append(
{
"id": str(prompt["id"]),
"name": prompt["name"],
"type": "private",
}
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompts: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify(list_prompts), 200)
@prompts_ns.route("/get_single_prompt")
class GetSinglePrompt(Resource):
@api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID")
def get(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
prompt_id = request.args.get("id")
if not prompt_id:
return make_response(
jsonify({"success": False, "message": "ID is required"}), 400
)
try:
if prompt_id == "default":
with open(
os.path.join(current_dir, "prompts", "chat_combine_default.txt"),
"r",
) as f:
chat_combine_template = f.read()
return make_response(jsonify({"content": chat_combine_template}), 200)
elif prompt_id == "creative":
with open(
os.path.join(current_dir, "prompts", "chat_combine_creative.txt"),
"r",
) as f:
chat_reduce_creative = f.read()
return make_response(jsonify({"content": chat_reduce_creative}), 200)
elif prompt_id == "strict":
with open(
os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r"
) as f:
chat_reduce_strict = f.read()
return make_response(jsonify({"content": chat_reduce_strict}), 200)
with db_readonly() as conn:
prompt = PromptsRepository(conn).get_any(prompt_id, user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}), 404
)
except Exception as err:
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"content": prompt["content"]}), 200)
@prompts_ns.route("/delete_prompt")
class DeletePrompt(Resource):
delete_prompt_model = api.model(
"DeletePromptModel",
{"id": fields.String(required=True, description="Prompt ID to delete")},
)
@api.expect(delete_prompt_model)
@api.doc(description="Delete a prompt by ID")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
with db_session() as conn:
repo = PromptsRepository(conn)
prompt = repo.get_any(data["id"], user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}),
404,
)
repo.delete(str(prompt["id"]), user)
except Exception as err:
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)
@prompts_ns.route("/update_prompt")
class UpdatePrompt(Resource):
update_prompt_model = api.model(
"UpdatePromptModel",
{
"id": fields.String(required=True, description="Prompt ID to update"),
"name": fields.String(required=True, description="New name of the prompt"),
"content": fields.String(
required=True, description="New content of the prompt"
),
},
)
@api.expect(update_prompt_model)
@api.doc(description="Update an existing prompt")
def post(self):
decoded_token = request.decoded_token
if not decoded_token:
return make_response(jsonify({"success": False}), 401)
user = decoded_token.get("sub")
data = request.get_json()
required_fields = ["id", "name", "content"]
missing_fields = check_required_fields(data, required_fields)
if missing_fields:
return missing_fields
try:
with db_session() as conn:
repo = PromptsRepository(conn)
prompt = repo.get_any(data["id"], user)
if not prompt:
return make_response(
jsonify({"success": False, "message": "Prompt not found"}),
404,
)
repo.update(str(prompt["id"]), user, data["name"], data["content"])
except Exception as err:
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
return make_response(jsonify({"success": False}), 400)
return make_response(jsonify({"success": True}), 200)

Some files were not shown because too many files have changed in this diff Show More