mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-07 06:30:03 +00:00
Compare commits
1 Commits
feat-model
...
chore/bump
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec800eaf80 |
@@ -3,14 +3,6 @@ LLM_NAME=docsgpt
|
||||
VITE_API_STREAMING=true
|
||||
INTERNAL_KEY=<internal key for worker-to-backend authentication>
|
||||
|
||||
# Provider-specific API keys (optional - use these to enable multiple providers)
|
||||
# OPENAI_API_KEY=<your-openai-api-key>
|
||||
# ANTHROPIC_API_KEY=<your-anthropic-api-key>
|
||||
# GOOGLE_API_KEY=<your-google-api-key>
|
||||
# GROQ_API_KEY=<your-groq-api-key>
|
||||
# NOVITA_API_KEY=<your-novita-api-key>
|
||||
# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
|
||||
|
||||
# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
|
||||
# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
|
||||
EMBEDDINGS_BASE_URL=
|
||||
@@ -34,6 +26,3 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
||||
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
||||
|
||||
|
||||
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||
|
||||
99
.github/INCIDENT_RESPONSE.md
vendored
99
.github/INCIDENT_RESPONSE.md
vendored
@@ -1,99 +0,0 @@
|
||||
# DocsGPT Incident Response Plan (IRP)
|
||||
|
||||
This playbook describes how maintainers respond to confirmed or suspected security incidents.
|
||||
|
||||
- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
|
||||
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)
|
||||
|
||||
## Severity
|
||||
|
||||
| Severity | Definition | Typical examples |
|
||||
|---|---|---|
|
||||
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
|
||||
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | key leakage; prompt injection enabling cross-tenant access. |
|
||||
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
|
||||
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |
|
||||
|
||||
|
||||
## Response workflow
|
||||
|
||||
### 1) Triage (target: initial response within 48 hours)
|
||||
|
||||
1. Acknowledge report.
|
||||
2. Validate on latest release and `main`.
|
||||
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
|
||||
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
|
||||
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.
|
||||
|
||||
### 2) Investigation
|
||||
|
||||
1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
|
||||
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
|
||||
3. Request a CVE through GHSA for **Medium+** issues.
|
||||
|
||||
### 3) Containment, fix, and disclosure
|
||||
|
||||
1. Implement and test fix in private security workflow (GHSA private fork/branch).
|
||||
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
|
||||
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
|
||||
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
|
||||
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
|
||||
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.
|
||||
|
||||
### 4) Post-incident
|
||||
|
||||
1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
|
||||
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
|
||||
3. Track follow-up hardening actions with owners/dates.
|
||||
4. Update this IRP and related runbooks as needed.
|
||||
|
||||
## Scenario playbooks
|
||||
|
||||
### Supply-chain compromise
|
||||
|
||||
1. Freeze releases and investigate blast radius.
|
||||
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
|
||||
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
|
||||
4. Publish advisory with exact affected versions and required user actions.
|
||||
|
||||
### Data exposure
|
||||
|
||||
1. Determine scope (users, documents, keys, logs, time window).
|
||||
2. Disable affected path or hotfix immediately for managed cloud.
|
||||
3. Notify affected users with concrete remediation steps (for example, rotate keys).
|
||||
4. Continue through standard fix/disclosure workflow.
|
||||
|
||||
### Critical regression with security impact
|
||||
|
||||
1. Identify introducing change (`git bisect` if needed).
|
||||
2. Publish workaround within 24 hours (for example, pin to known-good version).
|
||||
3. Ship patch release with regression test and close incident with public summary.
|
||||
|
||||
## AI-specific guidance
|
||||
|
||||
Treat confirmed AI-specific abuse as security incidents:
|
||||
|
||||
- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
|
||||
- Cross-tenant retrieval/isolation failure -> **High**
|
||||
- API key disclosure in output -> **High**
|
||||
|
||||
## Secret rotation quick reference
|
||||
|
||||
| Secret | Standard rotation action |
|
||||
|---|---|
|
||||
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
|
||||
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
|
||||
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
|
||||
| Database credentials | Rotate in DB platform; redeploy with new secrets |
|
||||
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
|
||||
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
|
||||
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |
|
||||
|
||||
## Maintenance
|
||||
|
||||
Review this document:
|
||||
|
||||
- after every **Critical/High** incident, and
|
||||
- at least annually.
|
||||
|
||||
Changes should be proposed via pull request to `main`.
|
||||
144
.github/THREAT_MODEL.md
vendored
144
.github/THREAT_MODEL.md
vendored
@@ -1,144 +0,0 @@
|
||||
# DocsGPT Public Threat Model
|
||||
|
||||
**Classification:** Public
|
||||
**Last updated:** 2026-04-15
|
||||
**Applies to:** Open-source and self-hosted DocsGPT deployments
|
||||
|
||||
## 1) Overview
|
||||
|
||||
DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.
|
||||
|
||||
Core components:
|
||||
- Backend API (`application/`)
|
||||
- Workers/ingestion (`application/worker.py` and related modules)
|
||||
- Datastores (MongoDB/Redis/vector stores)
|
||||
- Frontend (`frontend/`)
|
||||
- Optional extensions/integrations (`extensions/`)
|
||||
|
||||
## 2) Scope and assumptions
|
||||
|
||||
In scope:
|
||||
- Application-level threats in this repository.
|
||||
- Local and internet-exposed self-hosted deployments.
|
||||
|
||||
Assumptions:
|
||||
- Internet-facing instances enable auth and use strong secrets.
|
||||
- Datastores/internal services are not publicly exposed.
|
||||
|
||||
Out of scope:
|
||||
- Cloud hardware/provider compromise.
|
||||
- Security guarantees of external LLM vendors.
|
||||
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).
|
||||
|
||||
## 3) Security objectives
|
||||
|
||||
- Protect document/conversation confidentiality.
|
||||
- Preserve integrity of prompts, agents, tools, and indexed data.
|
||||
- Maintain API/worker availability.
|
||||
- Enforce tenant isolation in authenticated deployments.
|
||||
|
||||
## 4) Assets
|
||||
|
||||
- Documents, attachments, chunks/embeddings, summaries.
|
||||
- Conversations, agents, workflows, prompt templates.
|
||||
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
|
||||
- Operational capacity (worker throughput, queue depth, model quota/cost).
|
||||
|
||||
## 5) Trust boundaries and untrusted input
|
||||
|
||||
Trust boundaries:
|
||||
- Internet ↔ Frontend
|
||||
- Frontend ↔ Backend API
|
||||
- Backend ↔ Workers/internal APIs
|
||||
- Backend/workers ↔ Datastores
|
||||
- Backend ↔ External LLM/connectors/remote URLs
|
||||
|
||||
Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.
|
||||
|
||||
## 6) Main attack surfaces
|
||||
|
||||
1. Auth/authz paths and sharing tokens.
|
||||
2. File upload + parsing pipeline.
|
||||
3. Remote URL fetching and connectors (SSRF risk).
|
||||
4. Agent/tool execution from LLM output.
|
||||
5. Template/workflow rendering.
|
||||
6. Frontend rendering + token storage.
|
||||
7. Internal service endpoints (`INTERNAL_KEY`).
|
||||
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).
|
||||
|
||||
## 7) Key threats and expected mitigations
|
||||
|
||||
### A. Auth/authz misconfiguration
|
||||
- Threat: weak/no auth or leaked tokens leads to broad data access.
|
||||
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.
|
||||
|
||||
### B. Untrusted file ingestion
|
||||
- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
|
||||
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.
|
||||
|
||||
### C. SSRF/outbound abuse
|
||||
- Threat: URL loaders/tools access private/internal/metadata endpoints.
|
||||
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists.
|
||||
|
||||
### D. Prompt injection + tool abuse
|
||||
- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
|
||||
- Threat: never rely on the model to "choose correctly" under adversarial input.
|
||||
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.
|
||||
|
||||
### E. Dangerous tool capability chaining (SQL/API/MCP)
|
||||
- Threat: write-capable SQL credentials allow destructive queries.
|
||||
- Threat: API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
|
||||
- Threat: remote MCP tools may expose privileged operations.
|
||||
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging.
|
||||
|
||||
### F. Frontend/XSS + token theft
|
||||
- Threat: XSS can steal local tokens and call APIs.
|
||||
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.
|
||||
|
||||
### G. Internal endpoint exposure
|
||||
- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
|
||||
- Mitigations: fail closed, require strong random keys, keep internal APIs private.
|
||||
|
||||
### H. DoS and cost abuse
|
||||
- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
|
||||
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.
|
||||
|
||||
## 8) Example attacker stories
|
||||
|
||||
- Internet-exposed deployment runs with weak/no auth and receives unauthorized data access/abuse.
|
||||
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
|
||||
- Crafted archive attempts path traversal during extraction.
|
||||
- Malicious URL/redirect chain targets internal services.
|
||||
- Poisoned document causes data exfiltration through tool calls.
|
||||
- Over-privileged SQL/API/MCP tool performs destructive side effects.
|
||||
|
||||
## 9) Severity calibration
|
||||
|
||||
- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
|
||||
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
|
||||
- **Medium:** DoS/cost amplification and non-critical information disclosure.
|
||||
- **Low:** minor hardening gaps with limited impact.
|
||||
|
||||
## 10) Baseline controls for public deployments
|
||||
|
||||
1. Enforce authentication and secure defaults.
|
||||
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys).
|
||||
3. Restrict CORS and front API with a hardened proxy.
|
||||
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
|
||||
5. Enforce URL+redirect SSRF protections and egress restrictions.
|
||||
6. Apply upload/archive/parsing hardening.
|
||||
7. Require least-privilege tool credentials and auditable tool execution.
|
||||
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
|
||||
9. Keep dependencies/images patched and scanned.
|
||||
10. Validate multi-tenant isolation with explicit tests.
|
||||
|
||||
## 11) Maintenance
|
||||
|
||||
Review this model after major auth, ingestion, connector, tool, or workflow changes.
|
||||
|
||||
## References
|
||||
|
||||
- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
|
||||
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
|
||||
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
|
||||
- [DocsGPT SECURITY.md](../SECURITY.md)
|
||||
@@ -1,80 +1,46 @@
|
||||
Agentic
|
||||
Anthropic's
|
||||
api
|
||||
APIs
|
||||
Atlassian
|
||||
automations
|
||||
autoescaping
|
||||
Autoescaping
|
||||
backfill
|
||||
backfills
|
||||
bool
|
||||
boolean
|
||||
brave_web_search
|
||||
chatbot
|
||||
Ollama
|
||||
Qdrant
|
||||
Milvus
|
||||
Chatwoot
|
||||
config
|
||||
configs
|
||||
CSVs
|
||||
dev
|
||||
diarization
|
||||
Docling
|
||||
docsgpt
|
||||
docstrings
|
||||
Entra
|
||||
env
|
||||
enqueues
|
||||
EOL
|
||||
ESLint
|
||||
feedbacks
|
||||
Figma
|
||||
GPUs
|
||||
Nextra
|
||||
VSCode
|
||||
npm
|
||||
LLMs
|
||||
APIs
|
||||
Groq
|
||||
hardcode
|
||||
hardcoding
|
||||
Idempotency
|
||||
SGLang
|
||||
LMDeploy
|
||||
OAuth
|
||||
Vite
|
||||
LLM
|
||||
JSONPath
|
||||
UIs
|
||||
configs
|
||||
uncomment
|
||||
qdrant
|
||||
vectorstore
|
||||
docsgpt
|
||||
llm
|
||||
GPUs
|
||||
kubectl
|
||||
Lightsail
|
||||
llama_cpp
|
||||
llm
|
||||
LLM
|
||||
LLMs
|
||||
LMDeploy
|
||||
Milvus
|
||||
Mixtral
|
||||
namespace
|
||||
namespaces
|
||||
needs_auth
|
||||
Nextra
|
||||
Novita
|
||||
npm
|
||||
OAuth
|
||||
Ollama
|
||||
opencode
|
||||
parsable
|
||||
passthrough
|
||||
PDFs
|
||||
pgvector
|
||||
Postgres
|
||||
enqueues
|
||||
chatbot
|
||||
VSCode's
|
||||
Shareability
|
||||
feedbacks
|
||||
automations
|
||||
Premade
|
||||
Pydantic
|
||||
pytest
|
||||
Qdrant
|
||||
qdrant
|
||||
Signup
|
||||
Repo
|
||||
repo
|
||||
Sanitization
|
||||
SDKs
|
||||
SGLang
|
||||
Shareability
|
||||
Signup
|
||||
Supabase
|
||||
UIs
|
||||
uncomment
|
||||
env
|
||||
URl
|
||||
vectorstore
|
||||
Vite
|
||||
VSCode
|
||||
VSCode's
|
||||
widget's
|
||||
agentic
|
||||
llama_cpp
|
||||
parsable
|
||||
SDKs
|
||||
boolean
|
||||
bool
|
||||
hardcode
|
||||
EOL
|
||||
|
||||
22
.github/workflows/vale.yml
vendored
22
.github/workflows/vale.yml
vendored
@@ -11,6 +11,7 @@ on:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
vale:
|
||||
@@ -19,16 +20,11 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Vale
|
||||
run: |
|
||||
curl -fsSL -o vale.tar.gz \
|
||||
https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
|
||||
tar -xzf vale.tar.gz
|
||||
sudo mv vale /usr/local/bin/vale
|
||||
vale --version
|
||||
|
||||
- name: Sync Vale packages
|
||||
run: vale sync
|
||||
|
||||
- name: Run Vale
|
||||
run: vale --minAlertLevel=error docs
|
||||
- name: Vale linter
|
||||
uses: errata-ai/vale-action@v2
|
||||
with:
|
||||
files: docs
|
||||
fail_on_error: false
|
||||
version: 3.0.5
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
25
.github/workflows/zizmor.yml
vendored
25
.github/workflows/zizmor.yml
vendored
@@ -1,25 +0,0 @@
|
||||
name: GitHub Actions Security Analysis
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["master"]
|
||||
pull_request:
|
||||
branches: ["**"]
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
zizmor:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Run zizmor 🌈
|
||||
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -108,8 +108,6 @@ celerybeat.pid
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
|
||||
CLAUDE.md
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
@@ -183,14 +181,5 @@ application/vectors/
|
||||
|
||||
node_modules/
|
||||
.vscode/settings.json
|
||||
.vscode/sftp.json
|
||||
/models/
|
||||
model/
|
||||
|
||||
# E2E test artifacts
|
||||
.e2e-tmp/
|
||||
/tmp/docsgpt-e2e/
|
||||
tests/e2e/node_modules/
|
||||
tests/e2e/playwright-report/
|
||||
tests/e2e/test-results/
|
||||
tests/e2e/.e2e-last-run.json
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
MinAlertLevel = warning
|
||||
StylesPath = .github/styles
|
||||
Vocab = DocsGPT
|
||||
|
||||
[*.{md,mdx}]
|
||||
BasedOnStyles = DocsGPT
|
||||
|
||||
|
||||
28
AGENTS.md
28
AGENTS.md
@@ -10,15 +10,9 @@
|
||||
For feature work, do **not** assume the environment needs to be recreated.
|
||||
|
||||
- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
|
||||
- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
|
||||
- Check whether MongoDB is already running.
|
||||
- Check whether Redis is already running.
|
||||
- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.
|
||||
|
||||
> MongoDB is **not** required for the default install. It is only needed if
|
||||
> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
|
||||
> or is running the one-shot `scripts/db/backfill.py` to migrate existing
|
||||
> user data from the legacy Mongo-based install. In those cases, `pymongo`
|
||||
> is available as an optional extra, not a core dependency.
|
||||
- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
|
||||
|
||||
## Normal local development commands
|
||||
|
||||
@@ -37,22 +31,6 @@ Run the Flask API (if needed):
|
||||
flask --app application/app.py run --host=0.0.0.0 --port=7091
|
||||
```
|
||||
|
||||
That's the fast inner-loop option — quick startup, the Werkzeug interactive
|
||||
debugger still works, and it hot-reloads on source changes. It serves the
|
||||
Flask routes only (`/api/*`, `/stream`, etc.).
|
||||
|
||||
If you need to exercise the full ASGI stack — the `/mcp` FastMCP endpoint,
|
||||
or to match the production runtime exactly — run the ASGI composition under
|
||||
uvicorn instead:
|
||||
|
||||
```bash
|
||||
uvicorn application.asgi:asgi_app --host 0.0.0.0 --port 7091 --reload
|
||||
```
|
||||
|
||||
Production uses `gunicorn -k uvicorn_worker.UvicornWorker` against the same
|
||||
`application.asgi:asgi_app` target; see `application/Dockerfile` for the
|
||||
full flag set.
|
||||
|
||||
Run the Celery worker in a separate terminal (if needed):
|
||||
|
||||
```bash
|
||||
@@ -115,7 +93,7 @@ vale .
|
||||
- `frontend/`: Vite + React + TypeScript application.
|
||||
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
|
||||
- `docs/`: separate documentation site built with Next.js/Nextra.
|
||||
- `extensions/`: integrations and widgets — currently the Chatwoot webhook bridge and the React widget (published to npm as `docsgpt`). The Discord bot, Slack bot, and Chrome extension have been moved to their own repos under `arc53/`.
|
||||
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
|
||||
- `deployment/`: Docker Compose variants and Kubernetes manifests.
|
||||
|
||||
## Coding rules
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
|
||||
<div align="center">
|
||||
<br>
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
|
||||
<img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
|
||||
</div>
|
||||
<h3 align="left">
|
||||
<strong>Key Features:</strong>
|
||||
|
||||
18
SECURITY.md
18
SECURITY.md
@@ -2,21 +2,13 @@
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
|
||||
Supported Versions:
|
||||
|
||||
Currently, we support security patches by committing changes and bumping the version published on Github.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Preferred method: use GitHub's private vulnerability reporting flow:
|
||||
https://github.com/arc53/DocsGPT/security
|
||||
Found a vulnerability? Please email us:
|
||||
|
||||
Then click **Report a vulnerability**.
|
||||
|
||||
|
||||
Alternatively, email us at: security@arc53.com
|
||||
|
||||
We aim to acknowledge reports within 48 hours.
|
||||
|
||||
## Incident Handling
|
||||
|
||||
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
|
||||
security@arc53.com
|
||||
|
||||
|
||||
@@ -88,15 +88,5 @@ EXPOSE 7091
|
||||
# Switch to non-root user
|
||||
USER appuser
|
||||
|
||||
CMD ["gunicorn", \
|
||||
"-w", "1", \
|
||||
"-k", "uvicorn_worker.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:7091", \
|
||||
"--timeout", "180", \
|
||||
"--graceful-timeout", "120", \
|
||||
"--keep-alive", "5", \
|
||||
"--worker-tmp-dir", "/dev/shm", \
|
||||
"--max-requests", "1000", \
|
||||
"--max-requests-jitter", "100", \
|
||||
"--config", "application/gunicorn_conf.py", \
|
||||
"application.asgi:asgi_app"]
|
||||
# Start Gunicorn
|
||||
CMD ["gunicorn", "-w", "1", "--timeout", "120", "--bind", "0.0.0.0:7091", "--preload", "application.wsgi:app"]
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generator, List, Optional
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.core.json_schema_utils import (
|
||||
@@ -10,7 +9,6 @@ from application.core.json_schema_utils import (
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.settings import settings
|
||||
from application.llm.handlers.base import ToolCall
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.logging import build_stack_data, log_activity, LogContext
|
||||
@@ -115,153 +113,6 @@ class BaseAgent(ABC):
|
||||
) -> Generator[Dict, None, None]:
|
||||
pass
|
||||
|
||||
def gen_continuation(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
tools_dict: Dict,
|
||||
pending_tool_calls: List[Dict],
|
||||
tool_actions: List[Dict],
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Resume generation after tool actions are resolved.
|
||||
|
||||
Processes the client-provided *tool_actions* (approvals, denials,
|
||||
or client-side results), appends the resulting messages, then
|
||||
hands back to the LLM to continue the conversation.
|
||||
|
||||
Args:
|
||||
messages: The saved messages array from the pause point.
|
||||
tools_dict: The saved tools dictionary.
|
||||
pending_tool_calls: The pending tool call descriptors from the pause.
|
||||
tool_actions: Client-provided actions resolving the pending calls.
|
||||
"""
|
||||
self._prepare_tools(tools_dict)
|
||||
|
||||
actions_by_id = {a["call_id"]: a for a in tool_actions}
|
||||
|
||||
# Build a single assistant message containing all tool calls so
|
||||
# the message history matches the format LLM providers expect
|
||||
# (one assistant message with N tool_calls, followed by N tool results).
|
||||
tc_objects: List[Dict[str, Any]] = []
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
args = pending["arguments"]
|
||||
args_str = (
|
||||
json.dumps(args) if isinstance(args, dict) else (args or "{}")
|
||||
)
|
||||
tc_obj: Dict[str, Any] = {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": pending["name"],
|
||||
"arguments": args_str,
|
||||
},
|
||||
}
|
||||
if pending.get("thought_signature"):
|
||||
tc_obj["thought_signature"] = pending["thought_signature"]
|
||||
tc_objects.append(tc_obj)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tc_objects,
|
||||
})
|
||||
|
||||
# Now process each pending call and append tool result messages
|
||||
for pending in pending_tool_calls:
|
||||
call_id = pending["call_id"]
|
||||
args = pending["arguments"]
|
||||
action = actions_by_id.get(call_id)
|
||||
if not action:
|
||||
action = {
|
||||
"call_id": call_id,
|
||||
"decision": "denied",
|
||||
"comment": "No response provided",
|
||||
}
|
||||
|
||||
if action.get("decision") == "approved":
|
||||
# Execute the tool server-side
|
||||
tc = ToolCall(
|
||||
id=call_id,
|
||||
name=pending["name"],
|
||||
arguments=(
|
||||
json.dumps(args) if isinstance(args, dict) else args
|
||||
),
|
||||
)
|
||||
tool_gen = self._execute_tool_action(tools_dict, tc)
|
||||
tool_response = None
|
||||
while True:
|
||||
try:
|
||||
event = next(tool_gen)
|
||||
yield event
|
||||
except StopIteration as e:
|
||||
tool_response, _ = e.value
|
||||
break
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, tool_response)
|
||||
)
|
||||
|
||||
elif action.get("decision") == "denied":
|
||||
comment = action.get("comment", "")
|
||||
denial = (
|
||||
f"Tool execution denied by user. Reason: {comment}"
|
||||
if comment
|
||||
else "Tool execution denied by user."
|
||||
)
|
||||
tc = ToolCall(
|
||||
id=call_id, name=pending["name"], arguments=args
|
||||
)
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, denial)
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pending.get("tool_name", "unknown"),
|
||||
"call_id": call_id,
|
||||
"action_name": pending.get("llm_name", pending["name"]),
|
||||
"arguments": args,
|
||||
"status": "denied",
|
||||
},
|
||||
}
|
||||
|
||||
elif "result" in action:
|
||||
result = action["result"]
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
tc = ToolCall(
|
||||
id=call_id, name=pending["name"], arguments=args
|
||||
)
|
||||
messages.append(
|
||||
self.llm_handler.create_tool_message(tc, result_str)
|
||||
)
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pending.get("tool_name", "unknown"),
|
||||
"call_id": call_id,
|
||||
"action_name": pending.get("llm_name", pending["name"]),
|
||||
"arguments": args,
|
||||
"result": (
|
||||
result_str[:50] + "..."
|
||||
if len(result_str) > 50
|
||||
else result_str
|
||||
),
|
||||
"status": "completed",
|
||||
},
|
||||
}
|
||||
|
||||
# Resume the LLM loop with the updated messages
|
||||
llm_response = self._llm_gen(messages)
|
||||
yield from self._handle_response(
|
||||
llm_response, tools_dict, messages, None
|
||||
)
|
||||
|
||||
yield {"sources": self.retrieved_docs}
|
||||
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||
|
||||
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||
|
||||
@property
|
||||
@@ -416,35 +267,28 @@ class BaseAgent(ABC):
|
||||
if "tool_calls" in i:
|
||||
for tool_call in i["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
messages.append({"role": "user", "content": query})
|
||||
return messages
|
||||
|
||||
|
||||
@@ -593,22 +593,16 @@ class ResearchAgent(BaseAgent):
|
||||
)
|
||||
result = result_str
|
||||
|
||||
import json as _json
|
||||
|
||||
args_str = (
|
||||
_json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else call.arguments
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_content]}
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": call.name, "arguments": args_str},
|
||||
}],
|
||||
})
|
||||
tool_message = self.llm_handler.create_tool_message(call, result)
|
||||
messages.append(tool_message)
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import logging
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.security.encryption import decrypt_credentials
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,166 +31,63 @@ class ToolExecutor:
|
||||
self.tool_calls: List[Dict] = []
|
||||
self._loaded_tools: Dict[str, object] = {}
|
||||
self.conversation_id: Optional[str] = None
|
||||
self.client_tools: Optional[List[Dict]] = None
|
||||
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||
|
||||
def get_tools(self) -> Dict[str, Dict]:
|
||||
"""Load tool configs from DB based on user context.
|
||||
|
||||
If *client_tools* have been set on this executor, they are
|
||||
automatically merged into the returned dict.
|
||||
"""
|
||||
"""Load tool configs from DB based on user context."""
|
||||
if self.user_api_key:
|
||||
tools = self._get_tools_by_api_key(self.user_api_key)
|
||||
else:
|
||||
tools = self._get_user_tools(self.user or "local")
|
||||
if self.client_tools:
|
||||
self.merge_client_tools(tools, self.client_tools)
|
||||
return tools
|
||||
return self._get_tools_by_api_key(self.user_api_key)
|
||||
return self._get_user_tools(self.user or "local")
|
||||
|
||||
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
||||
# Per-operation session: the answer pipeline spans a long-lived
|
||||
# generator; wrapping it in a single connection would pin a PG
|
||||
# conn for the whole stream. Open, fetch, close.
|
||||
with db_readonly() as conn:
|
||||
agent_data = AgentsRepository(conn).find_by_key(api_key)
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
if not tool_ids:
|
||||
return {}
|
||||
tools_repo = UserToolsRepository(conn)
|
||||
tools: List[Dict] = []
|
||||
owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
|
||||
for tid in tool_ids:
|
||||
row = None
|
||||
if owner:
|
||||
row = tools_repo.get_any(str(tid), owner)
|
||||
if row is not None:
|
||||
tools.append(row)
|
||||
return {str(tool["id"]): tool for tool in tools} if tools else {}
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
agents_collection = db["agents"]
|
||||
tools_collection = db["user_tools"]
|
||||
|
||||
agent_data = agents_collection.find_one({"key": api_key})
|
||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||
|
||||
tools = (
|
||||
tools_collection.find(
|
||||
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
|
||||
)
|
||||
if tool_ids
|
||||
else []
|
||||
)
|
||||
tools = list(tools)
|
||||
return {str(tool["_id"]): tool for tool in tools} if tools else {}
|
||||
|
||||
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
|
||||
with db_readonly() as conn:
|
||||
user_tools = UserToolsRepository(conn).list_active_for_user(user)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
user_tools_collection = db["user_tools"]
|
||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
||||
user_tools = list(user_tools)
|
||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||
|
||||
def merge_client_tools(
|
||||
self, tools_dict: Dict, client_tools: List[Dict]
|
||||
) -> Dict:
|
||||
"""Merge client-provided tool definitions into tools_dict.
|
||||
|
||||
Client tools use the standard function-calling format::
|
||||
|
||||
[{"type": "function", "function": {"name": "get_weather",
|
||||
"description": "...", "parameters": {...}}}]
|
||||
|
||||
They are stored in *tools_dict* with ``client_side: True`` so that
|
||||
:meth:`check_pause` returns a pause signal instead of trying to
|
||||
execute them server-side.
|
||||
|
||||
Args:
|
||||
tools_dict: The mutable server tools dict (will be modified in place).
|
||||
client_tools: List of tool definitions in function-calling format.
|
||||
|
||||
Returns:
|
||||
The updated *tools_dict* (same reference, for convenience).
|
||||
"""
|
||||
for i, ct in enumerate(client_tools):
|
||||
func = ct.get("function", ct) # tolerate bare {"name":..} too
|
||||
name = func.get("name", f"clienttool{i}")
|
||||
tool_id = f"ct{i}"
|
||||
|
||||
tools_dict[tool_id] = {
|
||||
"name": name,
|
||||
"client_side": True,
|
||||
"actions": [
|
||||
{
|
||||
"name": name,
|
||||
"description": func.get("description", ""),
|
||||
"active": True,
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
],
|
||||
}
|
||||
return tools_dict
|
||||
|
||||
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
|
||||
"""Convert tool configs to LLM function schemas.
|
||||
|
||||
Action names are kept clean for the LLM:
|
||||
- Unique action names appear as-is (e.g. ``get_weather``).
|
||||
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
|
||||
``search_2``).
|
||||
|
||||
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
|
||||
can be routed back to the correct ``(tool_id, action_name)`` without
|
||||
brittle string splitting.
|
||||
"""
|
||||
# Pass 1: collect entries and count action name occurrences
|
||||
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
|
||||
name_counts: Counter = Counter()
|
||||
|
||||
for tool_id, tool in tools_dict.items():
|
||||
is_api = tool["name"] == "api_tool"
|
||||
is_client = tool.get("client_side", False)
|
||||
|
||||
if is_api and "actions" not in tool.get("config", {}):
|
||||
continue
|
||||
if not is_api and "actions" not in tool:
|
||||
continue
|
||||
|
||||
actions = (
|
||||
tool["config"]["actions"].values()
|
||||
if is_api
|
||||
else tool["actions"]
|
||||
)
|
||||
|
||||
for action in actions:
|
||||
if not action.get("active", True):
|
||||
continue
|
||||
entries.append((tool_id, action["name"], action, is_client))
|
||||
name_counts[action["name"]] += 1
|
||||
|
||||
# Pass 2: assign LLM-visible names and build mappings
|
||||
self._name_to_tool = {}
|
||||
self._tool_to_name = {}
|
||||
collision_counters: Dict[str, int] = {}
|
||||
all_llm_names: set = set()
|
||||
|
||||
result = []
|
||||
for tool_id, action_name, action, is_client in entries:
|
||||
if name_counts[action_name] == 1:
|
||||
llm_name = action_name
|
||||
else:
|
||||
counter = collision_counters.get(action_name, 1)
|
||||
candidate = f"{action_name}_{counter}"
|
||||
# Skip if candidate collides with a unique action name
|
||||
while candidate in all_llm_names or (
|
||||
candidate in name_counts and name_counts[candidate] == 1
|
||||
):
|
||||
counter += 1
|
||||
candidate = f"{action_name}_{counter}"
|
||||
collision_counters[action_name] = counter + 1
|
||||
llm_name = candidate
|
||||
|
||||
all_llm_names.add(llm_name)
|
||||
self._name_to_tool[llm_name] = (tool_id, action_name)
|
||||
self._tool_to_name[(tool_id, action_name)] = llm_name
|
||||
|
||||
if is_client:
|
||||
params = action.get("parameters", {})
|
||||
else:
|
||||
params = self._build_tool_parameters(action)
|
||||
|
||||
result.append({
|
||||
"""Convert tool configs to LLM function schemas."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": llm_name,
|
||||
"description": action.get("description", ""),
|
||||
"parameters": params,
|
||||
"name": f"{action['name']}_{tool_id}",
|
||||
"description": action["description"],
|
||||
"parameters": self._build_tool_parameters(action),
|
||||
},
|
||||
})
|
||||
return result
|
||||
}
|
||||
for tool_id, tool in tools_dict.items()
|
||||
if (
|
||||
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
|
||||
or (tool["name"] != "api_tool" and "actions" in tool)
|
||||
)
|
||||
for action in (
|
||||
tool["config"]["actions"].values()
|
||||
if tool["name"] == "api_tool"
|
||||
else tool["actions"]
|
||||
)
|
||||
if action.get("active", True)
|
||||
]
|
||||
|
||||
def _build_tool_parameters(self, action: Dict) -> Dict:
|
||||
params = {"type": "object", "properties": {}, "required": []}
|
||||
@@ -207,81 +104,23 @@ class ToolExecutor:
|
||||
params["required"].append(k)
|
||||
return params
|
||||
|
||||
def check_pause(
|
||||
self, tools_dict: Dict, call, llm_class_name: str
|
||||
) -> Optional[Dict]:
|
||||
"""Check if a tool call requires pausing for approval or client execution.
|
||||
|
||||
Returns a dict describing the pending action if pause is needed, None otherwise.
|
||||
"""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
llm_name = getattr(call, "name", "")
|
||||
|
||||
if tool_id is None or action_name is None or tool_id not in tools_dict:
|
||||
return None # Will be handled as error by execute()
|
||||
|
||||
tool_data = tools_dict[tool_id]
|
||||
|
||||
# Client-side tools
|
||||
if tool_data.get("client_side"):
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"pause_type": "requires_client_execution",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
|
||||
# Approval required
|
||||
if tool_data["name"] == "api_tool":
|
||||
action_data = tool_data.get("config", {}).get("actions", {}).get(
|
||||
action_name, {}
|
||||
)
|
||||
else:
|
||||
action_data = next(
|
||||
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
|
||||
{},
|
||||
)
|
||||
|
||||
if action_data.get("require_approval"):
|
||||
return {
|
||||
"call_id": call_id,
|
||||
"name": llm_name,
|
||||
"tool_name": tool_data.get("name", "unknown"),
|
||||
"tool_id": tool_id,
|
||||
"action_name": action_name,
|
||||
"llm_name": llm_name,
|
||||
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||
"pause_type": "awaiting_approval",
|
||||
"thought_signature": getattr(call, "thought_signature", None),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
||||
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
||||
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||
parser = ToolActionParser(llm_class_name)
|
||||
tool_id, action_name, call_args = parser.parse_args(call)
|
||||
llm_name = getattr(call, "name", "unknown")
|
||||
|
||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||
|
||||
if tool_id is None or action_name is None:
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
||||
logger.error(error_message)
|
||||
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"action_name": getattr(call, "name", "unknown"),
|
||||
"arguments": call_args or {},
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
|
||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
@@ -294,7 +133,7 @@ class ToolExecutor:
|
||||
tool_call_data = {
|
||||
"tool_name": "unknown",
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||
}
|
||||
@@ -305,7 +144,7 @@ class ToolExecutor:
|
||||
tool_call_data = {
|
||||
"tool_name": tools_dict[tool_id]["name"],
|
||||
"call_id": call_id,
|
||||
"action_name": llm_name,
|
||||
"action_name": f"{action_name}_{tool_id}",
|
||||
"arguments": call_args,
|
||||
}
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||
@@ -346,21 +185,7 @@ class ToolExecutor:
|
||||
target_dict[param] = value
|
||||
|
||||
# Load tool (with caching)
|
||||
tool = self._get_or_load_tool(
|
||||
tool_data, tool_id, action_name,
|
||||
headers=headers, query_params=query_params,
|
||||
)
|
||||
|
||||
if tool is None:
|
||||
error_message = (
|
||||
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
|
||||
"missing 'id' on tool row."
|
||||
)
|
||||
logger.error(error_message)
|
||||
tool_call_data["result"] = error_message
|
||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||
self.tool_calls.append(tool_call_data)
|
||||
return error_message, call_id
|
||||
tool = self._get_or_load_tool(tool_data, tool_id, action_name)
|
||||
|
||||
resolved_arguments = (
|
||||
{"query_params": query_params, "headers": headers, "body": body}
|
||||
@@ -413,10 +238,7 @@ class ToolExecutor:
|
||||
|
||||
return result, call_id
|
||||
|
||||
def _get_or_load_tool(
|
||||
self, tool_data: Dict, tool_id: str, action_name: str,
|
||||
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
|
||||
):
|
||||
def _get_or_load_tool(self, tool_data: Dict, tool_id: str, action_name: str):
|
||||
"""Load a tool, using cache when possible."""
|
||||
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
|
||||
if cache_key in self._loaded_tools:
|
||||
@@ -429,8 +251,8 @@ class ToolExecutor:
|
||||
tool_config = {
|
||||
"url": action_config["url"],
|
||||
"method": action_config["method"],
|
||||
"headers": headers or {},
|
||||
"query_params": query_params or {},
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
}
|
||||
if "body_content_type" in action_config:
|
||||
tool_config["body_content_type"] = action_config.get(
|
||||
@@ -448,16 +270,7 @@ class ToolExecutor:
|
||||
tool_config.update(decrypted)
|
||||
tool_config["auth_credentials"] = decrypted
|
||||
tool_config.pop("encrypted_credentials", None)
|
||||
row_id = tool_data.get("id")
|
||||
if not row_id:
|
||||
logger.error(
|
||||
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
|
||||
"skipping load to avoid binding a non-UUID downstream.",
|
||||
tool_data.get("name"),
|
||||
tool_id,
|
||||
)
|
||||
return None
|
||||
tool_config["tool_id"] = str(row_id)
|
||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
||||
if self.conversation_id:
|
||||
tool_config["conversation_id"] = self.conversation_id
|
||||
if tool_data["name"] == "mcp_tool":
|
||||
|
||||
@@ -2,8 +2,6 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
internal: bool = False
|
||||
|
||||
@abstractmethod
|
||||
def execute_action(self, action_name: str, **kwargs):
|
||||
pass
|
||||
|
||||
@@ -73,7 +73,7 @@ class BraveSearchTool(Tool):
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
@@ -118,7 +118,7 @@ class BraveSearchTool(Tool):
|
||||
"X-Subscription-Token": self.token,
|
||||
}
|
||||
|
||||
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||
response = requests.get(url, params=params, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
|
||||
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
|
||||
returns price in USD.
|
||||
"""
|
||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||
response = requests.get(url, timeout=100)
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if currency.upper() in data:
|
||||
|
||||
@@ -20,8 +20,6 @@ class InternalSearchTool(Tool):
|
||||
- list_files action: browse the file/folder structure
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.retrieved_docs: List[Dict] = []
|
||||
@@ -48,7 +46,7 @@ class InternalSearchTool(Tool):
|
||||
return self._retriever
|
||||
|
||||
def _get_directory_structure(self) -> Optional[Dict]:
|
||||
"""Load directory structure from Postgres for the configured sources."""
|
||||
"""Load directory structure from MongoDB for the configured sources."""
|
||||
if self._dir_structure_loaded:
|
||||
return self._directory_structure
|
||||
|
||||
@@ -59,39 +57,35 @@ class InternalSearchTool(Tool):
|
||||
return None
|
||||
|
||||
try:
|
||||
# Per-operation session: this tool runs inside the answer
|
||||
# generator hot path, so we open a short-lived read
|
||||
# connection for the batch lookup and release immediately.
|
||||
from application.storage.db.repositories.sources import (
|
||||
SourcesRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
from bson.objectid import ObjectId
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
decoded_token = self.config.get("decoded_token") or {}
|
||||
user_id = decoded_token.get("sub") if decoded_token else None
|
||||
|
||||
merged_structure = {}
|
||||
with db_readonly() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
|
||||
if not source_doc:
|
||||
continue
|
||||
dir_str = source_doc.get("directory_structure")
|
||||
if dir_str:
|
||||
if isinstance(dir_str, str):
|
||||
dir_str = json.loads(dir_str)
|
||||
source_name = source_doc.get("name", doc_id)
|
||||
if len(active_docs) > 1:
|
||||
merged_structure[source_name] = dir_str
|
||||
else:
|
||||
merged_structure = dir_str
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(doc_id)}
|
||||
)
|
||||
if not source_doc:
|
||||
continue
|
||||
dir_str = source_doc.get("directory_structure")
|
||||
if dir_str:
|
||||
if isinstance(dir_str, str):
|
||||
dir_str = json.loads(dir_str)
|
||||
source_name = source_doc.get("name", doc_id)
|
||||
if len(active_docs) > 1:
|
||||
merged_structure[source_name] = dir_str
|
||||
else:
|
||||
merged_structure = dir_str
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
|
||||
|
||||
self._directory_structure = merged_structure if merged_structure else None
|
||||
except Exception as e:
|
||||
@@ -361,48 +355,32 @@ INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
|
||||
|
||||
|
||||
def sources_have_directory_structure(source: Dict) -> bool:
|
||||
"""Check if any of the active sources have a ``directory_structure`` row."""
|
||||
"""Check if any of the active sources have directory_structure in MongoDB."""
|
||||
active_docs = source.get("active_docs", [])
|
||||
if not active_docs:
|
||||
return False
|
||||
|
||||
try:
|
||||
# TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
|
||||
# scoping, but callers in the agent build path don't always
|
||||
# thread the decoded token through here. Use a direct
|
||||
# short-lived SQL lookup instead of the repo until the call
|
||||
# sites are updated to propagate user context.
|
||||
from sqlalchemy import text as _text
|
||||
from bson.objectid import ObjectId
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.storage.db.session import db_readonly
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
if isinstance(active_docs, str):
|
||||
active_docs = [active_docs]
|
||||
|
||||
with db_readonly() as conn:
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
value = str(doc_id)
|
||||
if len(value) == 36 and "-" in value:
|
||||
row = conn.execute(
|
||||
_text(
|
||||
"SELECT directory_structure FROM sources "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": value},
|
||||
).fetchone()
|
||||
else:
|
||||
row = conn.execute(
|
||||
_text(
|
||||
"SELECT directory_structure FROM sources "
|
||||
"WHERE legacy_mongo_id = :lid"
|
||||
),
|
||||
{"lid": value},
|
||||
).fetchone()
|
||||
if row is not None and row[0]:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
for doc_id in active_docs:
|
||||
try:
|
||||
source_doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(doc_id)},
|
||||
{"directory_structure": 1},
|
||||
)
|
||||
if source_doc and source_doc.get("directory_structure"):
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not check directory structure: {e}")
|
||||
|
||||
|
||||
@@ -22,12 +22,15 @@ from redis import Redis
|
||||
from application.agents.tools.base import Tool
|
||||
from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.security.encryption import decrypt_credentials
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
_mcp_clients_cache = {}
|
||||
|
||||
|
||||
@@ -58,8 +61,7 @@ class MCPTool(Tool):
|
||||
"""
|
||||
self.config = config
|
||||
self.user_id = user_id
|
||||
raw_url = config.get("server_url", "")
|
||||
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.transport_type = config.get("transport_type", "auto")
|
||||
self.auth_type = config.get("auth_type", "none")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
@@ -85,18 +87,6 @@ class MCPTool(Tool):
|
||||
if self.server_url and self.auth_type != "oauth":
|
||||
self._setup_client()
|
||||
|
||||
@staticmethod
|
||||
def _validate_server_url(server_url: str) -> str:
|
||||
"""Validate server_url to prevent SSRF to internal networks.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL points to a private/internal address.
|
||||
"""
|
||||
try:
|
||||
return validate_url(server_url)
|
||||
except SSRFError as exc:
|
||||
raise ValueError(f"Invalid MCP server URL: {exc}") from exc
|
||||
|
||||
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
|
||||
if configured_redirect_uri:
|
||||
return configured_redirect_uri.rstrip("/")
|
||||
@@ -118,9 +108,8 @@ class MCPTool(Tool):
|
||||
auth_key = ""
|
||||
if self.auth_type == "oauth":
|
||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
|
||||
auth_key = (
|
||||
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
|
||||
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
|
||||
)
|
||||
elif self.auth_type in ["bearer"]:
|
||||
token = self.auth_credentials.get(
|
||||
@@ -157,6 +146,7 @@ class MCPTool(Tool):
|
||||
scopes=self.oauth_scopes,
|
||||
redis_client=redis_client,
|
||||
redirect_uri=self.redirect_uri,
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
else:
|
||||
@@ -166,6 +156,7 @@ class MCPTool(Tool):
|
||||
redis_client=redis_client,
|
||||
redirect_uri=self.redirect_uri,
|
||||
task_id=self.oauth_task_id,
|
||||
db=db,
|
||||
user_id=self.user_id,
|
||||
)
|
||||
elif self.auth_type == "bearer":
|
||||
@@ -485,7 +476,7 @@ class MCPTool(Tool):
|
||||
|
||||
def _test_oauth_connection(self) -> Dict:
|
||||
storage = DBTokenStorage(
|
||||
server_url=self.server_url, user_id=self.user_id,
|
||||
server_url=self.server_url, user_id=self.user_id, db_client=db
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
@@ -677,6 +668,7 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
scopes: str | list[str] | None = None,
|
||||
client_name: str = "DocsGPT-MCP",
|
||||
user_id=None,
|
||||
db=None,
|
||||
additional_client_metadata: dict[str, Any] | None = None,
|
||||
skip_redirect_validation: bool = False,
|
||||
):
|
||||
@@ -685,6 +677,7 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
self.redis_prefix = redis_prefix
|
||||
self.task_id = task_id
|
||||
self.user_id = user_id
|
||||
self.db = db
|
||||
|
||||
parsed_url = urlparse(mcp_url)
|
||||
self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
@@ -703,6 +696,7 @@ class DocsGPTOAuth(OAuthClientProvider):
|
||||
storage = DBTokenStorage(
|
||||
server_url=self.server_base_url,
|
||||
user_id=self.user_id,
|
||||
db_client=self.db,
|
||||
expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
|
||||
)
|
||||
|
||||
@@ -844,95 +838,54 @@ class DBTokenStorage(TokenStorage):
|
||||
self,
|
||||
server_url: str,
|
||||
user_id: str,
|
||||
db_client,
|
||||
expected_redirect_uri: Optional[str] = None,
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.user_id = user_id
|
||||
self.db_client = db_client
|
||||
self.expected_redirect_uri = expected_redirect_uri
|
||||
self.collection = db_client["connector_sessions"]
|
||||
|
||||
@staticmethod
|
||||
def get_base_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
def _pg_provider(self) -> str:
|
||||
return f"mcp:{self.get_base_url(self.server_url)}"
|
||||
|
||||
def _fetch_session_data(self) -> dict:
|
||||
"""Read the JSONB ``session_data`` blob for this MCP server row."""
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
base_url = self.get_base_url(self.server_url)
|
||||
with db_readonly() as conn:
|
||||
row = ConnectorSessionsRepository(conn).get_by_user_and_server_url(
|
||||
self.user_id, base_url,
|
||||
)
|
||||
if not row:
|
||||
return {}
|
||||
data = row.get("session_data") or {}
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
except ValueError:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
def get_db_key(self) -> dict:
|
||||
return {
|
||||
"server_url": self.get_base_url(self.server_url),
|
||||
"user_id": self.user_id,
|
||||
}
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
data = await asyncio.to_thread(self._fetch_session_data)
|
||||
if not data or "tokens" not in data:
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "tokens" not in doc:
|
||||
return None
|
||||
try:
|
||||
return OAuthToken.model_validate(data["tokens"])
|
||||
return OAuthToken.model_validate(doc["tokens"])
|
||||
except ValidationError as e:
|
||||
logger.error("Could not load tokens: %s", e)
|
||||
return None
|
||||
|
||||
def _merge(self, patch: dict) -> None:
|
||||
"""Shallow-merge ``patch`` into this row's ``session_data``.
|
||||
|
||||
Threads ``server_url`` through to the repository so it lands in
|
||||
the scalar column — ``get_by_user_and_server_url`` needs that to
|
||||
resolve the row (``NULL = 'https://...'`` is UNKNOWN in SQL).
|
||||
"""
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
base_url = self.get_base_url(self.server_url)
|
||||
with db_session() as conn:
|
||||
ConnectorSessionsRepository(conn).merge_session_data(
|
||||
self.user_id, self._pg_provider(), base_url, patch,
|
||||
)
|
||||
|
||||
def _delete(self) -> None:
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
with db_session() as conn:
|
||||
ConnectorSessionsRepository(conn).delete(
|
||||
self.user_id, self._pg_provider(),
|
||||
)
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
base_url = self.get_base_url(self.server_url)
|
||||
token_dump = tokens.model_dump()
|
||||
await asyncio.to_thread(self._merge, {"tokens": token_dump})
|
||||
logger.info("Saved tokens for %s", base_url)
|
||||
await asyncio.to_thread(
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"tokens": tokens.model_dump()}},
|
||||
True,
|
||||
)
|
||||
logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
data = await asyncio.to_thread(self._fetch_session_data)
|
||||
base_url = self.get_base_url(self.server_url)
|
||||
if not data or "client_info" not in data:
|
||||
logger.debug("No client_info in DB for %s", base_url)
|
||||
doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
|
||||
if not doc or "client_info" not in doc:
|
||||
logger.debug(
|
||||
"No client_info in DB for %s", self.get_base_url(self.server_url)
|
||||
)
|
||||
return None
|
||||
try:
|
||||
client_info = OAuthClientInformationFull.model_validate(data["client_info"])
|
||||
client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
|
||||
if self.expected_redirect_uri:
|
||||
stored_uris = [
|
||||
str(uri).rstrip("/") for uri in client_info.redirect_uris
|
||||
@@ -941,16 +894,14 @@ class DBTokenStorage(TokenStorage):
|
||||
if expected_uri not in stored_uris:
|
||||
logger.warning(
|
||||
"Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
|
||||
base_url,
|
||||
self.get_base_url(self.server_url),
|
||||
expected_uri,
|
||||
stored_uris,
|
||||
)
|
||||
# Drop ``tokens`` and ``client_info`` from the JSONB
|
||||
# blob via merge_session_data's ``None``-drops-key
|
||||
# semantics — preserves the row + any other keys.
|
||||
await asyncio.to_thread(
|
||||
self._merge,
|
||||
{"tokens": None, "client_info": None},
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$unset": {"client_info": "", "tokens": ""}},
|
||||
)
|
||||
return None
|
||||
return client_info
|
||||
@@ -965,37 +916,22 @@ class DBTokenStorage(TokenStorage):
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
serialized_info = self._serialize_client_info(client_info.model_dump())
|
||||
base_url = self.get_base_url(self.server_url)
|
||||
await asyncio.to_thread(
|
||||
self._merge, {"client_info": serialized_info},
|
||||
self.collection.update_one,
|
||||
self.get_db_key(),
|
||||
{"$set": {"client_info": serialized_info}},
|
||||
True,
|
||||
)
|
||||
logger.info("Saved client info for %s", base_url)
|
||||
logger.info("Saved client info for %s", self.get_base_url(self.server_url))
|
||||
|
||||
async def clear(self) -> None:
|
||||
await asyncio.to_thread(self._delete)
|
||||
await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
|
||||
logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))
|
||||
|
||||
@classmethod
|
||||
async def clear_all(cls, db_client=None) -> None:
|
||||
"""Delete every MCP-tagged connector session row.
|
||||
|
||||
``db_client`` retained for call-site compatibility but unused —
|
||||
storage is Postgres-only now.
|
||||
"""
|
||||
from sqlalchemy import text
|
||||
|
||||
from application.storage.db.session import db_session
|
||||
|
||||
def _delete_all() -> None:
|
||||
with db_session() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"DELETE FROM connector_sessions "
|
||||
"WHERE provider LIKE 'mcp:%'"
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_delete_all)
|
||||
async def clear_all(cls, db_client) -> None:
|
||||
collection = db_client["connector_sessions"]
|
||||
await asyncio.to_thread(collection.delete_many, {})
|
||||
logger.info("Cleared all OAuth client cache data.")
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.storage.db.repositories.memories import MemoriesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class MemoryTool(Tool):
|
||||
@@ -29,7 +27,7 @@ class MemoryTool(Tool):
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the UUID string from user_tools.id.
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
@@ -39,35 +37,8 @@ class MemoryTool(Tool):
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
def _pg_enabled(self) -> bool:
|
||||
"""Return True if this MemoryTool's tool_id is a real ``user_tools.id``.
|
||||
|
||||
The ``memories`` PG table has a UUID foreign key to ``user_tools``.
|
||||
The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
|
||||
has no row in ``user_tools``, so any storage operation would fail
|
||||
the foreign-key check. After the Postgres cutover Postgres is the
|
||||
only store, so for the sentinel case there is nowhere to read or
|
||||
write — operations become no-ops and the tool returns an
|
||||
explanatory error to the caller.
|
||||
"""
|
||||
tool_id = getattr(self, "tool_id", None)
|
||||
if not tool_id or not isinstance(tool_id, str):
|
||||
return False
|
||||
if tool_id.startswith("default_"):
|
||||
logger.debug(
|
||||
"Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
|
||||
tool_id,
|
||||
)
|
||||
return False
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
|
||||
if not looks_like_uuid(tool_id):
|
||||
logger.debug(
|
||||
"Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
|
||||
tool_id,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["memories"]
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
@@ -85,12 +56,6 @@ class MemoryTool(Tool):
|
||||
if not self.user_id:
|
||||
return "Error: MemoryTool requires a valid user_id."
|
||||
|
||||
if not self._pg_enabled():
|
||||
return (
|
||||
"Error: MemoryTool is not configured with a persistent tool_id; "
|
||||
"memory storage is unavailable for this session."
|
||||
)
|
||||
|
||||
if action_name == "view":
|
||||
return self._view(
|
||||
kwargs.get("path", "/"),
|
||||
@@ -317,10 +282,14 @@ class MemoryTool(Tool):
|
||||
# Ensure path ends with / for proper prefix matching
|
||||
search_path = path if path.endswith("/") else path + "/"
|
||||
|
||||
with db_readonly() as conn:
|
||||
docs = MemoriesRepository(conn).list_by_prefix(
|
||||
self.user_id, self.tool_id, search_path
|
||||
)
|
||||
# Find all files that start with this directory path
|
||||
query = {
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
}
|
||||
|
||||
docs = list(self.collection.find(query, {"path": 1}))
|
||||
|
||||
if not docs:
|
||||
return f"Directory: {path}\n(empty)"
|
||||
@@ -341,10 +310,7 @@ class MemoryTool(Tool):
|
||||
|
||||
def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
|
||||
"""View file contents with optional line range."""
|
||||
with db_readonly() as conn:
|
||||
doc = MemoriesRepository(conn).get_by_path(
|
||||
self.user_id, self.tool_id, path
|
||||
)
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {path}"
|
||||
@@ -378,10 +344,16 @@ class MemoryTool(Tool):
|
||||
if validated_path == "/" or validated_path.endswith("/"):
|
||||
return "Error: Cannot create a file at directory path."
|
||||
|
||||
with db_session() as conn:
|
||||
MemoriesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, validated_path, file_text
|
||||
)
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": file_text,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
return f"File created: {validated_path}"
|
||||
|
||||
@@ -394,29 +366,30 @@ class MemoryTool(Tool):
|
||||
if not old_str:
|
||||
return "Error: old_str is required."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
current_content = str(doc["content"])
|
||||
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
# Check if old_str exists (case-insensitive)
|
||||
if old_str.lower() not in current_content.lower():
|
||||
return f"Error: String '{old_str}' not found in file."
|
||||
|
||||
# Case-insensitive replace
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(
|
||||
regex_module.escape(old_str),
|
||||
new_str,
|
||||
current_content,
|
||||
flags=regex_module.IGNORECASE,
|
||||
)
|
||||
# Replace the string (case-insensitive)
|
||||
import re as regex_module
|
||||
updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
|
||||
|
||||
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"File updated: {validated_path}"
|
||||
|
||||
@@ -429,25 +402,31 @@ class MemoryTool(Tool):
|
||||
if not insert_text:
|
||||
return "Error: insert_text is required."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
|
||||
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
if not doc or not doc.get("content"):
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
current_content = str(doc["content"])
|
||||
lines = current_content.split("\n")
|
||||
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
# Convert to 0-indexed
|
||||
index = insert_line - 1
|
||||
if index < 0 or index > len(lines):
|
||||
return f"Error: Invalid line number. File has {len(lines)} lines."
|
||||
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
lines.insert(index, insert_text)
|
||||
updated_content = "\n".join(lines)
|
||||
|
||||
repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
|
||||
self.collection.update_one(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
|
||||
{
|
||||
"$set": {
|
||||
"content": updated_content,
|
||||
"updated_at": datetime.now()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return f"Text inserted at line {insert_line} in {validated_path}"
|
||||
|
||||
@@ -459,36 +438,39 @@ class MemoryTool(Tool):
|
||||
|
||||
if validated_path == "/":
|
||||
# Delete all files for this user and tool
|
||||
with db_session() as conn:
|
||||
deleted = MemoriesRepository(conn).delete_all(
|
||||
self.user_id, self.tool_id
|
||||
)
|
||||
return f"Deleted {deleted} file(s) from memory."
|
||||
result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
return f"Deleted {result.deleted_count} file(s) from memory."
|
||||
|
||||
# Check if it's a directory (ends with /)
|
||||
if validated_path.endswith("/"):
|
||||
with db_session() as conn:
|
||||
deleted = MemoriesRepository(conn).delete_by_prefix(
|
||||
self.user_id, self.tool_id, validated_path
|
||||
)
|
||||
return f"Deleted directory and {deleted} file(s)."
|
||||
# Delete all files in directory
|
||||
result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_path)}"}
|
||||
})
|
||||
return f"Deleted directory and {result.deleted_count} file(s)."
|
||||
|
||||
# Try as directory first (without trailing slash)
|
||||
# Try to delete as directory first (without trailing slash)
|
||||
# Check if any files start with this path + /
|
||||
search_path = validated_path + "/"
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
directory_deleted = repo.delete_by_prefix(
|
||||
self.user_id, self.tool_id, search_path
|
||||
)
|
||||
if directory_deleted > 0:
|
||||
return f"Deleted directory and {directory_deleted} file(s)."
|
||||
directory_result = self.collection.delete_many({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(search_path)}"}
|
||||
})
|
||||
|
||||
# Otherwise delete a single file
|
||||
file_deleted = repo.delete_by_path(
|
||||
self.user_id, self.tool_id, validated_path
|
||||
)
|
||||
if directory_result.deleted_count > 0:
|
||||
return f"Deleted directory and {directory_result.deleted_count} file(s)."
|
||||
|
||||
if file_deleted:
|
||||
# Delete single file
|
||||
result = self.collection.delete_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_path
|
||||
})
|
||||
|
||||
if result.deleted_count:
|
||||
return f"Deleted: {validated_path}"
|
||||
return f"Error: File not found: {validated_path}"
|
||||
|
||||
@@ -503,46 +485,62 @@ class MemoryTool(Tool):
|
||||
if validated_old == "/" or validated_new == "/":
|
||||
return "Error: Cannot rename root directory."
|
||||
|
||||
# Directory rename: do all path updates inside one transaction so
|
||||
# the rename is atomic from the caller's perspective.
|
||||
# Check if renaming a directory
|
||||
if validated_old.endswith("/"):
|
||||
# Ensure validated_new also ends with / for proper path replacement
|
||||
if not validated_new.endswith("/"):
|
||||
validated_new = validated_new + "/"
|
||||
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
docs = repo.list_by_prefix(
|
||||
self.user_id, self.tool_id, validated_old
|
||||
# Find all files in the old directory
|
||||
docs = list(self.collection.find({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": {"$regex": f"^{re.escape(validated_old)}"}
|
||||
}))
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
# Update paths for all files
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(validated_old, validated_new, 1)
|
||||
|
||||
self.collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {"path": new_file_path, "updated_at": datetime.now()}}
|
||||
)
|
||||
|
||||
if not docs:
|
||||
return f"Error: Directory not found: {validated_old}"
|
||||
|
||||
for doc in docs:
|
||||
old_file_path = doc["path"]
|
||||
new_file_path = old_file_path.replace(
|
||||
validated_old, validated_new, 1
|
||||
)
|
||||
repo.update_path(
|
||||
self.user_id, self.tool_id, old_file_path, new_file_path
|
||||
)
|
||||
|
||||
return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
|
||||
|
||||
# Single-file rename: lookup, collision check, and update in one txn.
|
||||
with db_session() as conn:
|
||||
repo = MemoriesRepository(conn)
|
||||
doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
# Rename single file
|
||||
doc = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_old
|
||||
})
|
||||
|
||||
existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
if not doc:
|
||||
return f"Error: File not found: {validated_old}"
|
||||
|
||||
repo.update_path(
|
||||
self.user_id, self.tool_id, validated_old, validated_new
|
||||
)
|
||||
# Check if new path already exists
|
||||
existing = self.collection.find_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new
|
||||
})
|
||||
|
||||
if existing:
|
||||
return f"Error: File already exists at {validated_new}"
|
||||
|
||||
# Delete the old document and create a new one with the new path
|
||||
self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
|
||||
self.collection.insert_one({
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"path": validated_new,
|
||||
"content": doc.get("content", ""),
|
||||
"updated_at": datetime.now()
|
||||
})
|
||||
|
||||
return f"Renamed: {validated_old} -> {validated_new}"
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.storage.db.repositories.notes import NotesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
# Stable synthetic title used in the Postgres ``notes.title`` column.
|
||||
# The notes tool stores one note per (user_id, tool_id); there is no
|
||||
# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
|
||||
# constant alongside the actual note body in ``content``.
|
||||
_NOTE_TITLE = "note"
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class NotesTool(Tool):
|
||||
@@ -31,6 +25,7 @@ class NotesTool(Tool):
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
@@ -40,25 +35,11 @@ class NotesTool(Tool):
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["notes"]
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
def _pg_enabled(self) -> bool:
|
||||
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
|
||||
|
||||
``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
|
||||
``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
|
||||
fallback is neither a UUID nor a ``user_tools`` row, so any DB
|
||||
operation would crash. Mirror MemoryTool's guard and no-op.
|
||||
"""
|
||||
tool_id = getattr(self, "tool_id", None)
|
||||
if not tool_id or not isinstance(tool_id, str):
|
||||
return False
|
||||
if tool_id.startswith("default_"):
|
||||
return False
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
|
||||
return looks_like_uuid(tool_id)
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
@@ -73,13 +54,7 @@ class NotesTool(Tool):
|
||||
A human-readable string result.
|
||||
"""
|
||||
if not self.user_id:
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
if not self._pg_enabled():
|
||||
return (
|
||||
"Error: NotesTool is not configured with a persistent "
|
||||
"tool_id; note storage is unavailable for this session."
|
||||
)
|
||||
return "Error: NotesTool requires a valid user_id."
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
@@ -160,45 +135,37 @@ class NotesTool(Tool):
|
||||
# -----------------------------
|
||||
# Internal helpers (single-note)
|
||||
# -----------------------------
|
||||
def _fetch_note(self) -> Optional[dict]:
|
||||
"""Read the note row for this (user, tool) from Postgres."""
|
||||
with db_readonly() as conn:
|
||||
return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)
|
||||
|
||||
def _get_note(self) -> str:
|
||||
doc = self._fetch_note()
|
||||
# ``content`` is the PG column; expose as ``note`` to callers via the
|
||||
# textual return value. Frontends that read the artifact via the
|
||||
# repo dict get ``content`` (PG-native) plus the artifact id below.
|
||||
body = (doc or {}).get("content")
|
||||
if not doc or not body:
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
if doc.get("id") is not None:
|
||||
self._last_artifact_id = str(doc.get("id"))
|
||||
return str(body)
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
return str(doc["note"])
|
||||
|
||||
def _overwrite_note(self, content: str) -> str:
|
||||
content = (content or "").strip()
|
||||
if not content:
|
||||
return "Note content required."
|
||||
with db_session() as conn:
|
||||
row = NotesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, _NOTE_TITLE, content
|
||||
)
|
||||
if row and row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
result = self.collection.find_one_and_update(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": content, "updated_at": datetime.utcnow()}},
|
||||
upsert=True,
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Note saved."
|
||||
|
||||
def _str_replace(self, old_str: str, new_str: str) -> str:
|
||||
if not old_str:
|
||||
return "old_str is required."
|
||||
|
||||
doc = self._fetch_note()
|
||||
existing = (doc or {}).get("content")
|
||||
if not doc or not existing:
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(existing)
|
||||
current_note = str(doc["note"])
|
||||
|
||||
# Case-insensitive search
|
||||
if old_str.lower() not in current_note.lower():
|
||||
@@ -208,24 +175,24 @@ class NotesTool(Tool):
|
||||
import re
|
||||
updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
|
||||
|
||||
with db_session() as conn:
|
||||
row = NotesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
|
||||
)
|
||||
if row and row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
result = self.collection.find_one_and_update(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Note updated."
|
||||
|
||||
def _insert(self, line_number: int, text: str) -> str:
|
||||
if not text:
|
||||
return "Text is required."
|
||||
|
||||
doc = self._fetch_note()
|
||||
existing = (doc or {}).get("content")
|
||||
if not doc or not existing:
|
||||
doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
|
||||
if not doc or not doc.get("note"):
|
||||
return "No note found."
|
||||
|
||||
current_note = str(existing)
|
||||
current_note = str(doc["note"])
|
||||
lines = current_note.split("\n")
|
||||
|
||||
# Convert to 0-indexed and validate
|
||||
@@ -236,23 +203,21 @@ class NotesTool(Tool):
|
||||
lines.insert(index, text)
|
||||
updated_note = "\n".join(lines)
|
||||
|
||||
with db_session() as conn:
|
||||
row = NotesRepository(conn).upsert(
|
||||
self.user_id, self.tool_id, _NOTE_TITLE, updated_note
|
||||
)
|
||||
if row and row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
result = self.collection.find_one_and_update(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
||||
{"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
|
||||
return_document=True,
|
||||
)
|
||||
if result and result.get("_id") is not None:
|
||||
self._last_artifact_id = str(result.get("_id"))
|
||||
return "Text inserted."
|
||||
|
||||
def _delete_note(self) -> str:
|
||||
# Capture the id (for artifact tracking) before deleting.
|
||||
existing = self._fetch_note()
|
||||
if not existing:
|
||||
doc = self.collection.find_one_and_delete(
|
||||
{"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
)
|
||||
if not doc:
|
||||
return "No note found to delete."
|
||||
with db_session() as conn:
|
||||
deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
|
||||
if not deleted:
|
||||
return "No note found to delete."
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
return "Note deleted."
|
||||
|
||||
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
|
||||
if self.token:
|
||||
headers["Authorization"] = f"Basic {self.token}"
|
||||
data = message.encode("utf-8")
|
||||
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||
response = requests.post(url, headers=headers, data=data)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
import psycopg
|
||||
import psycopg2
|
||||
|
||||
from application.agents.tools.base import Tool
|
||||
|
||||
@@ -33,7 +33,7 @@ class PostgresTool(Tool):
|
||||
"""
|
||||
conn = None
|
||||
try:
|
||||
conn = psycopg.connect(self.connection_string)
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
cur.execute(sql_query)
|
||||
conn.commit()
|
||||
@@ -60,7 +60,7 @@ class PostgresTool(Tool):
|
||||
"response_data": response_data,
|
||||
}
|
||||
|
||||
except psycopg.Error as e:
|
||||
except psycopg2.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
logger.error("PostgreSQL execute_sql error: %s", e)
|
||||
return {
|
||||
@@ -78,7 +78,7 @@ class PostgresTool(Tool):
|
||||
"""
|
||||
conn = None
|
||||
try:
|
||||
conn = psycopg.connect(self.connection_string)
|
||||
conn = psycopg2.connect(self.connection_string)
|
||||
cur = conn.cursor()
|
||||
|
||||
cur.execute(
|
||||
@@ -120,7 +120,7 @@ class PostgresTool(Tool):
|
||||
"schema": schema_data,
|
||||
}
|
||||
|
||||
except psycopg.Error as e:
|
||||
except psycopg2.Error as e:
|
||||
error_message = f"Database error: {e}"
|
||||
logger.error("PostgreSQL get_schema error: %s", e)
|
||||
return {
|
||||
|
||||
@@ -31,14 +31,14 @@ class TelegramTool(Tool):
|
||||
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||
payload = {"chat_id": chat_id, "text": text}
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Message sent"}
|
||||
|
||||
def _send_image(self, image_url, chat_id):
|
||||
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||
payload = {"chat_id": chat_id, "photo": image_url}
|
||||
response = requests.post(url, data=payload, timeout=100)
|
||||
response = requests.post(url, data=payload)
|
||||
return {"status_code": response.status_code, "message": "Image sent"}
|
||||
|
||||
def get_actions_metadata(self):
|
||||
|
||||
@@ -36,8 +36,6 @@ class ThinkTool(Tool):
|
||||
The reasoning content is captured in tool_call data for transparency.
|
||||
"""
|
||||
|
||||
internal = True
|
||||
|
||||
def __init__(self, config=None):
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,19 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
from .base import Tool
|
||||
from application.storage.db.repositories.todos import TodosRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
def _status_from_completed(completed: Any) -> str:
|
||||
"""Translate the PG ``completed`` boolean to the legacy status string.
|
||||
|
||||
The frontend (and prior LLM-facing tool output) expects
|
||||
``"open"`` / ``"completed"``. Keeping that contract at the tool
|
||||
boundary insulates callers from the schema change.
|
||||
"""
|
||||
return "completed" if bool(completed) else "open"
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
|
||||
class TodoListTool(Tool):
|
||||
@@ -34,6 +25,7 @@ class TodoListTool(Tool):
|
||||
self.user_id: Optional[str] = user_id
|
||||
|
||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
||||
if tool_config and "tool_id" in tool_config:
|
||||
self.tool_id = tool_config["tool_id"]
|
||||
elif user_id:
|
||||
@@ -43,27 +35,11 @@ class TodoListTool(Tool):
|
||||
# Last resort fallback (shouldn't happen in normal use)
|
||||
self.tool_id = str(uuid.uuid4())
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
self.collection = db["todos"]
|
||||
|
||||
self._last_artifact_id: Optional[str] = None
|
||||
|
||||
def _pg_enabled(self) -> bool:
|
||||
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
|
||||
|
||||
The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
|
||||
the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
|
||||
``default_{uid}`` fallback is neither a UUID nor a row in
|
||||
``user_tools`` — binding it would crash ``invalid input syntax for
|
||||
type uuid`` and even if it didn't the FK would reject it. Mirror
|
||||
the MemoryTool guard and no-op in that case.
|
||||
"""
|
||||
tool_id = getattr(self, "tool_id", None)
|
||||
if not tool_id or not isinstance(tool_id, str):
|
||||
return False
|
||||
if tool_id.startswith("default_"):
|
||||
return False
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
|
||||
return looks_like_uuid(tool_id)
|
||||
|
||||
# -----------------------------
|
||||
# Action implementations
|
||||
# -----------------------------
|
||||
@@ -80,12 +56,6 @@ class TodoListTool(Tool):
|
||||
if not self.user_id:
|
||||
return "Error: TodoListTool requires a valid user_id."
|
||||
|
||||
if not self._pg_enabled():
|
||||
return (
|
||||
"Error: TodoListTool is not configured with a persistent "
|
||||
"tool_id; todo storage is unavailable for this session."
|
||||
)
|
||||
|
||||
self._last_artifact_id = None
|
||||
|
||||
if action_name == "list":
|
||||
@@ -221,10 +191,28 @@ class TodoListTool(Tool):
|
||||
|
||||
return None
|
||||
|
||||
def _get_next_todo_id(self) -> int:
|
||||
"""Get the next sequential todo_id for this user and tool.
|
||||
|
||||
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
||||
With 5-10 todos max, scanning is negligible.
|
||||
"""
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query, {"todo_id": 1}))
|
||||
|
||||
# Find the maximum todo_id
|
||||
max_id = 0
|
||||
for todo in todos:
|
||||
todo_id = self._coerce_todo_id(todo.get("todo_id"))
|
||||
if todo_id is not None:
|
||||
max_id = max(max_id, todo_id)
|
||||
|
||||
return max_id + 1
|
||||
|
||||
def _list(self) -> str:
|
||||
"""List all todos for the user."""
|
||||
with db_readonly() as conn:
|
||||
todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id}
|
||||
todos = list(self.collection.find(query))
|
||||
|
||||
if not todos:
|
||||
return "No todos found."
|
||||
@@ -233,7 +221,7 @@ class TodoListTool(Tool):
|
||||
for doc in todos:
|
||||
todo_id = doc.get("todo_id")
|
||||
title = doc.get("title", "Untitled")
|
||||
status = _status_from_completed(doc.get("completed"))
|
||||
status = doc.get("status", "open")
|
||||
|
||||
line = f"[{todo_id}] {title} ({status})"
|
||||
result_lines.append(line)
|
||||
@@ -241,23 +229,27 @@ class TodoListTool(Tool):
|
||||
return "\n".join(result_lines)
|
||||
|
||||
def _create(self, title: str) -> str:
|
||||
"""Create a new todo item.
|
||||
|
||||
``TodosRepository.create`` allocates the per-tool monotonic
|
||||
``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
|
||||
scoped to ``tool_id``), so we no longer need a separate read-then-
|
||||
write step here.
|
||||
"""
|
||||
"""Create a new todo item."""
|
||||
title = (title or "").strip()
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
with db_session() as conn:
|
||||
row = TodosRepository(conn).create(self.user_id, self.tool_id, title)
|
||||
now = datetime.now()
|
||||
todo_id = self._get_next_todo_id()
|
||||
|
||||
todo_id = row.get("todo_id")
|
||||
if row.get("id") is not None:
|
||||
self._last_artifact_id = str(row.get("id"))
|
||||
doc = {
|
||||
"todo_id": todo_id,
|
||||
"user_id": self.user_id,
|
||||
"tool_id": self.tool_id,
|
||||
"title": title,
|
||||
"status": "open",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
insert_result = self.collection.insert_one(doc)
|
||||
inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
|
||||
if inserted_id is not None:
|
||||
self._last_artifact_id = str(inserted_id)
|
||||
return f"Todo created with ID {todo_id}: {title}"
|
||||
|
||||
def _get(self, todo_id: Optional[Any]) -> str:
|
||||
@@ -266,21 +258,21 @@ class TodoListTool(Tool):
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
with db_readonly() as conn:
|
||||
doc = TodosRepository(conn).get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one(query)
|
||||
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if doc.get("id") is not None:
|
||||
self._last_artifact_id = str(doc.get("id"))
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
title = doc.get("title", "Untitled")
|
||||
status = _status_from_completed(doc.get("completed"))
|
||||
status = doc.get("status", "open")
|
||||
|
||||
return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||
|
||||
return result
|
||||
|
||||
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||
"""Update a todo's title by ID."""
|
||||
@@ -292,19 +284,16 @@ class TodoListTool(Tool):
|
||||
if not title:
|
||||
return "Error: Title is required."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = TodosRepository(conn)
|
||||
existing = repo.get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
if not existing:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
repo.update_title_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id, title
|
||||
)
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"title": title, "updated_at": datetime.now()}},
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||
|
||||
@@ -314,17 +303,16 @@ class TodoListTool(Tool):
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = TodosRepository(conn)
|
||||
existing = repo.get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
if not existing:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_update(
|
||||
query,
|
||||
{"$set": {"status": "completed", "updated_at": datetime.now()}},
|
||||
)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} marked as completed."
|
||||
|
||||
@@ -334,18 +322,12 @@ class TodoListTool(Tool):
|
||||
if parsed_todo_id is None:
|
||||
return "Error: todo_id must be a positive integer."
|
||||
|
||||
with db_session() as conn:
|
||||
repo = TodosRepository(conn)
|
||||
existing = repo.get_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
if not existing:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
repo.delete_by_tool_and_todo_id(
|
||||
self.user_id, self.tool_id, parsed_todo_id
|
||||
)
|
||||
query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
|
||||
doc = self.collection.find_one_and_delete(query)
|
||||
if not doc:
|
||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||
|
||||
if existing.get("id") is not None:
|
||||
self._last_artifact_id = str(existing.get("id"))
|
||||
if doc.get("_id") is not None:
|
||||
self._last_artifact_id = str(doc.get("_id"))
|
||||
|
||||
return f"Todo {parsed_todo_id} deleted."
|
||||
|
||||
@@ -5,9 +5,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolActionParser:
|
||||
def __init__(self, llm_type, name_mapping=None):
|
||||
def __init__(self, llm_type):
|
||||
self.llm_type = llm_type
|
||||
self.name_mapping = name_mapping
|
||||
self.parsers = {
|
||||
"OpenAILLM": self._parse_openai_llm,
|
||||
"GoogleLLM": self._parse_google_llm,
|
||||
@@ -17,33 +16,22 @@ class ToolActionParser:
|
||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||
return parser(call)
|
||||
|
||||
def _resolve_via_mapping(self, call_name):
|
||||
"""Look up (tool_id, action_name) from the name mapping if available."""
|
||||
if self.name_mapping and call_name in self.name_mapping:
|
||||
return self.name_mapping[call_name]
|
||||
return None
|
||||
|
||||
def _parse_openai_llm(self, call):
|
||||
try:
|
||||
call_args = json.loads(call.arguments)
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
return resolved[0], resolved[1], call_args
|
||||
|
||||
# Fallback: legacy split on "_" for backward compatibility
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. "
|
||||
"Could not resolve via mapping or legacy parsing."
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
@@ -57,24 +45,19 @@ class ToolActionParser:
|
||||
def _parse_google_llm(self, call):
|
||||
try:
|
||||
call_args = call.arguments
|
||||
|
||||
resolved = self._resolve_via_mapping(call.name)
|
||||
if resolved:
|
||||
return resolved[0], resolved[1], call_args
|
||||
|
||||
# Fallback: legacy split on "_" for backward compatibility
|
||||
tool_parts = call.name.split("_")
|
||||
|
||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
||||
if len(tool_parts) < 2:
|
||||
logger.warning(
|
||||
f"Invalid tool name format: {call.name}. "
|
||||
"Could not resolve via mapping or legacy parsing."
|
||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
||||
)
|
||||
return None, None, None
|
||||
|
||||
tool_id = tool_parts[-1]
|
||||
action_name = "_".join(tool_parts[:-1])
|
||||
|
||||
# Validate that tool_id looks like a numerical ID
|
||||
if not tool_id.isdigit():
|
||||
logger.warning(
|
||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||
|
||||
@@ -19,7 +19,7 @@ class ToolManager:
|
||||
continue
|
||||
module = importlib.import_module(f"application.agents.tools.{name}")
|
||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
|
||||
if issubclass(obj, Tool) and obj is not Tool:
|
||||
tool_config = self.config.get(name, {})
|
||||
self.tools[name] = obj(tool_config)
|
||||
|
||||
|
||||
@@ -12,13 +12,9 @@ from application.agents.workflows.schemas import (
|
||||
WorkflowRun,
|
||||
)
|
||||
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.logging import log_activity, LogContext
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -107,8 +103,10 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
def _load_from_database(self) -> Optional[WorkflowGraph]:
|
||||
try:
|
||||
if not self.workflow_id:
|
||||
logger.error("Missing workflow ID for load")
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
|
||||
logger.error(f"Invalid workflow ID: {self.workflow_id}")
|
||||
return None
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
@@ -119,61 +117,61 @@ class WorkflowAgent(BaseAgent):
|
||||
)
|
||||
return None
|
||||
|
||||
with db_readonly() as conn:
|
||||
wf_repo = WorkflowsRepository(conn)
|
||||
if looks_like_uuid(self.workflow_id):
|
||||
workflow_row = wf_repo.get(self.workflow_id, owner_id)
|
||||
else:
|
||||
workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
|
||||
if workflow_row is None:
|
||||
logger.error(
|
||||
f"Workflow {self.workflow_id} not found or inaccessible "
|
||||
f"for user {owner_id}"
|
||||
)
|
||||
return None
|
||||
pg_workflow_id = str(workflow_row["id"])
|
||||
graph_version = workflow_row.get("current_graph_version", 1)
|
||||
try:
|
||||
graph_version = int(graph_version)
|
||||
if graph_version <= 0:
|
||||
graph_version = 1
|
||||
except (ValueError, TypeError):
|
||||
graph_version = 1
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
node_rows = WorkflowNodesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
edge_rows = WorkflowEdgesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
workflows_coll = db["workflows"]
|
||||
workflow_nodes_coll = db["workflow_nodes"]
|
||||
workflow_edges_coll = db["workflow_edges"]
|
||||
|
||||
workflow = Workflow(
|
||||
name=workflow_row.get("name"),
|
||||
description=workflow_row.get("description"),
|
||||
workflow_doc = workflows_coll.find_one(
|
||||
{"_id": ObjectId(self.workflow_id), "user": owner_id}
|
||||
)
|
||||
nodes = [
|
||||
WorkflowNode(
|
||||
id=n["node_id"],
|
||||
workflow_id=pg_workflow_id,
|
||||
type=n["node_type"],
|
||||
title=n.get("title") or "Node",
|
||||
description=n.get("description"),
|
||||
position=n.get("position") or {"x": 0, "y": 0},
|
||||
config=n.get("config") or {},
|
||||
if not workflow_doc:
|
||||
logger.error(
|
||||
f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
|
||||
)
|
||||
for n in node_rows
|
||||
]
|
||||
edges = [
|
||||
WorkflowEdge(
|
||||
id=e["edge_id"],
|
||||
workflow_id=pg_workflow_id,
|
||||
source=e.get("source_id"),
|
||||
target=e.get("target_id"),
|
||||
sourceHandle=e.get("source_handle"),
|
||||
targetHandle=e.get("target_handle"),
|
||||
return None
|
||||
workflow = Workflow(**workflow_doc)
|
||||
graph_version = workflow_doc.get("current_graph_version", 1)
|
||||
try:
|
||||
graph_version = int(graph_version)
|
||||
if graph_version <= 0:
|
||||
graph_version = 1
|
||||
except (ValueError, TypeError):
|
||||
graph_version = 1
|
||||
|
||||
nodes_docs = list(
|
||||
workflow_nodes_coll.find(
|
||||
{"workflow_id": self.workflow_id, "graph_version": graph_version}
|
||||
)
|
||||
for e in edge_rows
|
||||
]
|
||||
)
|
||||
if not nodes_docs and graph_version == 1:
|
||||
nodes_docs = list(
|
||||
workflow_nodes_coll.find(
|
||||
{
|
||||
"workflow_id": self.workflow_id,
|
||||
"graph_version": {"$exists": False},
|
||||
}
|
||||
)
|
||||
)
|
||||
nodes = [WorkflowNode(**doc) for doc in nodes_docs]
|
||||
|
||||
edges_docs = list(
|
||||
workflow_edges_coll.find(
|
||||
{"workflow_id": self.workflow_id, "graph_version": graph_version}
|
||||
)
|
||||
)
|
||||
if not edges_docs and graph_version == 1:
|
||||
edges_docs = list(
|
||||
workflow_edges_coll.find(
|
||||
{
|
||||
"workflow_id": self.workflow_id,
|
||||
"graph_version": {"$exists": False},
|
||||
}
|
||||
)
|
||||
)
|
||||
edges = [WorkflowEdge(**doc) for doc in edges_docs]
|
||||
|
||||
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||
except Exception as e:
|
||||
@@ -183,13 +181,13 @@ class WorkflowAgent(BaseAgent):
|
||||
def _save_workflow_run(self, query: str) -> None:
|
||||
if not self._engine:
|
||||
return
|
||||
owner_id = self.workflow_owner
|
||||
if not owner_id and isinstance(self.decoded_token, dict):
|
||||
owner_id = self.decoded_token.get("sub")
|
||||
try:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
workflow_runs_coll = db["workflow_runs"]
|
||||
|
||||
run = WorkflowRun(
|
||||
workflow_id=self.workflow_id or "unknown",
|
||||
user=owner_id,
|
||||
status=self._determine_run_status(),
|
||||
inputs={"query": query},
|
||||
outputs=self._serialize_state(self._engine.state),
|
||||
@@ -198,28 +196,7 @@ class WorkflowAgent(BaseAgent):
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
if not self.workflow_id or not owner_id:
|
||||
return
|
||||
with db_session() as conn:
|
||||
wf_repo = WorkflowsRepository(conn)
|
||||
if looks_like_uuid(self.workflow_id):
|
||||
workflow_row = wf_repo.get(self.workflow_id, owner_id)
|
||||
else:
|
||||
workflow_row = wf_repo.get_by_legacy_id(
|
||||
self.workflow_id, owner_id,
|
||||
)
|
||||
if workflow_row is None:
|
||||
return
|
||||
WorkflowRunsRepository(conn).create(
|
||||
str(workflow_row["id"]),
|
||||
owner_id,
|
||||
run.status.value,
|
||||
inputs=run.inputs,
|
||||
result=run.outputs,
|
||||
steps=[step.model_dump(mode="json") for step in run.steps],
|
||||
started_at=run.created_at,
|
||||
ended_at=run.completed_at,
|
||||
)
|
||||
workflow_runs_coll.insert_one(run.to_mongo_doc())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save workflow run: {e}")
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
@@ -80,7 +81,24 @@ class WorkflowEdgeCreate(BaseModel):
|
||||
|
||||
|
||||
class WorkflowEdge(WorkflowEdgeCreate):
|
||||
pass
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"source_id": self.source_id,
|
||||
"target_id": self.target_id,
|
||||
"source_handle": self.source_handle,
|
||||
"target_handle": self.target_handle,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowNodeCreate(BaseModel):
|
||||
@@ -102,7 +120,25 @@ class WorkflowNodeCreate(BaseModel):
|
||||
|
||||
|
||||
class WorkflowNode(WorkflowNodeCreate):
|
||||
pass
|
||||
mongo_id: Optional[str] = Field(None, alias="_id")
|
||||
|
||||
@field_validator("mongo_id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"type": self.type.value,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"position": self.position.model_dump(),
|
||||
"config": self.config,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowCreate(BaseModel):
|
||||
@@ -113,10 +149,26 @@ class WorkflowCreate(BaseModel):
|
||||
|
||||
|
||||
class Workflow(WorkflowCreate):
|
||||
id: Optional[str] = None
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"user": self.user,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowGraph(BaseModel):
|
||||
workflow: Workflow
|
||||
@@ -157,12 +209,29 @@ class WorkflowRunCreate(BaseModel):
|
||||
|
||||
class WorkflowRun(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
id: Optional[str] = None
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
workflow_id: str
|
||||
user: Optional[str] = None
|
||||
status: ExecutionStatus = ExecutionStatus.PENDING
|
||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||
outputs: Dict[str, Any] = Field(default_factory=dict)
|
||||
steps: List[NodeExecutionLog] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def convert_objectid(cls, v: Any) -> Optional[str]:
|
||||
if isinstance(v, ObjectId):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"workflow_id": self.workflow_id,
|
||||
"status": self.status.value,
|
||||
"inputs": self.inputs,
|
||||
"outputs": self.outputs,
|
||||
"steps": [step.model_dump() for step in self.steps],
|
||||
"created_at": self.created_at,
|
||||
"completed_at": self.completed_at,
|
||||
}
|
||||
|
||||
@@ -200,9 +200,6 @@ class WorkflowEngine:
|
||||
|
||||
node_config = AgentNodeConfig(**node.config.get("config", node.config))
|
||||
|
||||
if node_config.sources:
|
||||
self._retrieve_node_sources(node_config)
|
||||
|
||||
if node_config.prompt_template:
|
||||
formatted_prompt = self._format_template(node_config.prompt_template)
|
||||
else:
|
||||
@@ -458,29 +455,6 @@ class WorkflowEngine:
|
||||
docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
|
||||
return docs, docs_together
|
||||
|
||||
def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
|
||||
"""Retrieve documents from the node's sources for template resolution."""
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
|
||||
query = self.state.get("query", "")
|
||||
if not query:
|
||||
return
|
||||
|
||||
try:
|
||||
retriever = RetrieverCreator.create_retriever(
|
||||
node_config.retriever or "classic",
|
||||
source={"active_docs": node_config.sources},
|
||||
chat_history=[],
|
||||
prompt="",
|
||||
chunks=int(node_config.chunks) if node_config.chunks else 2,
|
||||
decoded_token=self.agent.decoded_token,
|
||||
)
|
||||
docs = retriever.search(query)
|
||||
if docs:
|
||||
self.agent.retrieved_docs = docs
|
||||
except Exception:
|
||||
logger.exception("Failed to retrieve docs for workflow node")
|
||||
|
||||
def get_execution_summary(self) -> List[NodeExecutionLog]:
|
||||
return [
|
||||
NodeExecutionLog(
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
# Alembic configuration for the DocsGPT user-data Postgres database.
|
||||
#
|
||||
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
|
||||
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
|
||||
# source serves the running app and migrations. To run from the project
|
||||
# root::
|
||||
#
|
||||
# alembic -c application/alembic.ini upgrade head
|
||||
|
||||
[alembic]
|
||||
script_location = %(here)s/alembic
|
||||
prepend_sys_path = ..
|
||||
version_path_separator = os
|
||||
|
||||
# sqlalchemy.url is intentionally left blank — env.py supplies it.
|
||||
sqlalchemy.url =
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Alembic environment for the DocsGPT user-data Postgres database.
|
||||
|
||||
The URL is pulled from ``application.core.settings`` rather than
|
||||
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
|
||||
running app and ``alembic`` CLI invocations.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
|
||||
# Make the project root importable regardless of cwd. env.py lives at
|
||||
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(_PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from sqlalchemy import engine_from_config, pool # noqa: E402
|
||||
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.models import metadata as target_metadata # noqa: E402
|
||||
|
||||
config = context.config
|
||||
|
||||
# Populate the runtime URL from settings.
|
||||
if settings.POSTGRES_URI:
|
||||
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
if not url:
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode against a live connection."""
|
||||
if not config.get_main_option("sqlalchemy.url"):
|
||||
raise RuntimeError(
|
||||
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||
"psycopg3 URI such as "
|
||||
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||
)
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
future=True,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -1,26 +0,0 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -1,927 +0,0 @@
|
||||
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||
|
||||
Revision ID: 0001_initial
|
||||
Revises:
|
||||
Create Date: 2026-04-13
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0001_initial"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Extensions
|
||||
# ------------------------------------------------------------------
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Trigger functions
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION set_updated_at() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = now();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION ensure_user_exists() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
IF NEW.user_id IS NOT NULL THEN
|
||||
INSERT INTO users (user_id) VALUES (NEW.user_id)
|
||||
ON CONFLICT (user_id) DO NOTHING;
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
UPDATE conversation_messages
|
||||
SET attachments = array_remove(attachments, OLD.id)
|
||||
WHERE OLD.id = ANY(attachments);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
UPDATE agents
|
||||
SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
|
||||
WHERE OLD.id = ANY(extra_source_ids);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
agent_id_text text := OLD.id::text;
|
||||
BEGIN
|
||||
UPDATE users
|
||||
SET agent_preferences = jsonb_set(
|
||||
jsonb_set(
|
||||
agent_preferences,
|
||||
'{pinned}',
|
||||
COALESCE((
|
||||
SELECT jsonb_agg(e)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
|
||||
) e
|
||||
WHERE (e #>> '{}') <> agent_id_text
|
||||
), '[]'::jsonb)
|
||||
),
|
||||
'{shared_with_me}',
|
||||
COALESCE((
|
||||
SELECT jsonb_agg(e)
|
||||
FROM jsonb_array_elements(
|
||||
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
|
||||
) e
|
||||
WHERE (e #>> '{}') <> agent_id_text
|
||||
), '[]'::jsonb)
|
||||
)
|
||||
WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
|
||||
OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
|
||||
RETURN OLD;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
IF NEW.user_id IS NULL THEN
|
||||
SELECT user_id INTO NEW.user_id
|
||||
FROM conversations
|
||||
WHERE id = NEW.conversation_id;
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tables
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL UNIQUE,
|
||||
agent_preferences JSONB NOT NULL
|
||||
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE prompts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_tools (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
custom_name TEXT,
|
||||
display_name TEXT,
|
||||
description TEXT,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
actions JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
status BOOLEAN NOT NULL DEFAULT true,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE token_usage (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
agent_id UUID,
|
||||
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
|
||||
"CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE user_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
endpoint TEXT,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
data JSONB,
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE stack_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
activity_id TEXT NOT NULL,
|
||||
endpoint TEXT,
|
||||
level TEXT,
|
||||
user_id TEXT,
|
||||
api_key TEXT,
|
||||
query TEXT,
|
||||
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agent_folders (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE sources (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
language TEXT,
|
||||
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
model TEXT,
|
||||
type TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
retriever TEXT,
|
||||
sync_frequency TEXT,
|
||||
tokens TEXT,
|
||||
file_path TEXT,
|
||||
remote_data JSONB,
|
||||
directory_structure JSONB,
|
||||
file_name_map JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE agents (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
agent_type TEXT,
|
||||
status TEXT NOT NULL,
|
||||
key CITEXT UNIQUE,
|
||||
image TEXT,
|
||||
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||
chunks INTEGER,
|
||||
retriever TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
json_schema JSONB,
|
||||
models JSONB,
|
||||
default_model_id TEXT,
|
||||
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||
workflow_id UUID,
|
||||
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
token_limit INTEGER,
|
||||
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
request_limit INTEGER,
|
||||
allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
|
||||
shared BOOLEAN NOT NULL DEFAULT false,
|
||||
shared_token CITEXT UNIQUE,
|
||||
shared_metadata JSONB,
|
||||
incoming_webhook_token CITEXT UNIQUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
last_used_at TIMESTAMPTZ,
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
|
||||
"FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE attachments (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
filename TEXT NOT NULL,
|
||||
upload_path TEXT NOT NULL,
|
||||
mime_type TEXT,
|
||||
size BIGINT,
|
||||
content TEXT,
|
||||
token_count INTEGER,
|
||||
openai_file_id TEXT,
|
||||
google_file_uri TEXT,
|
||||
metadata JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE memories (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
path TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE todos (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
todo_id INTEGER,
|
||||
title TEXT NOT NULL,
|
||||
completed BOOLEAN NOT NULL DEFAULT false,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE notes (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE connector_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
server_url TEXT,
|
||||
session_token TEXT UNIQUE,
|
||||
user_email TEXT,
|
||||
status TEXT,
|
||||
token_info JSONB,
|
||||
session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
expires_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
|
||||
name TEXT,
|
||||
api_key TEXT,
|
||||
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
|
||||
shared_token TEXT,
|
||||
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
|
||||
compression_metadata JSONB,
|
||||
legacy_mongo_id TEXT,
|
||||
CONSTRAINT conversations_api_key_nonempty_chk
|
||||
CHECK (api_key IS NULL OR api_key <> '')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE conversation_messages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
position INTEGER NOT NULL,
|
||||
prompt TEXT,
|
||||
response TEXT,
|
||||
thought TEXT,
|
||||
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
|
||||
model_id TEXT,
|
||||
message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
feedback JSONB,
|
||||
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
user_id TEXT NOT NULL,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE shared_conversations (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
is_promptable BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
uuid UUID NOT NULL,
|
||||
first_n_queries INTEGER NOT NULL DEFAULT 0,
|
||||
api_key TEXT,
|
||||
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||
chunks INTEGER,
|
||||
CONSTRAINT shared_conversations_api_key_nonempty_chk
|
||||
CHECK (api_key IS NULL OR api_key <> '')
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE pending_tool_state (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
messages JSONB NOT NULL,
|
||||
pending_tool_calls JSONB NOT NULL,
|
||||
tools_dict JSONB NOT NULL,
|
||||
tool_schemas JSONB NOT NULL,
|
||||
agent_config JSONB NOT NULL,
|
||||
client_tools JSONB,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
expires_at TIMESTAMPTZ NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflows (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id TEXT NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
current_graph_version INTEGER NOT NULL DEFAULT 1,
|
||||
legacy_mongo_id TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
# Backfill the agents.workflow_id FK now that workflows exists.
|
||||
# The column was created without a FK (forward reference to a table
|
||||
# that hadn't been declared yet); add the constraint here so workflow
|
||||
# deletion still cascades through to agent unset.
|
||||
op.execute(
|
||||
"ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
|
||||
"FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_nodes (
|
||||
id UUID DEFAULT gen_random_uuid() NOT NULL,
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
node_type TEXT NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
node_id TEXT NOT NULL,
|
||||
title TEXT,
|
||||
description TEXT,
|
||||
position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
|
||||
legacy_mongo_id TEXT,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT workflow_nodes_id_wf_ver_key
|
||||
UNIQUE (id, workflow_id, graph_version)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_edges (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
graph_version INTEGER NOT NULL,
|
||||
from_node_id UUID NOT NULL,
|
||||
to_node_id UUID NOT NULL,
|
||||
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
edge_id TEXT NOT NULL,
|
||||
source_handle TEXT,
|
||||
target_handle TEXT,
|
||||
CONSTRAINT workflow_edges_from_node_fk
|
||||
FOREIGN KEY (from_node_id, workflow_id, graph_version)
|
||||
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
|
||||
CONSTRAINT workflow_edges_to_node_fk
|
||||
FOREIGN KEY (to_node_id, workflow_id, graph_version)
|
||||
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE workflow_runs (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||
user_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
ended_at TIMESTAMPTZ,
|
||||
result JSONB,
|
||||
inputs JSONB,
|
||||
steps JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||
legacy_mongo_id TEXT,
|
||||
CONSTRAINT workflow_runs_status_chk
|
||||
CHECK (status IN ('pending', 'running', 'completed', 'failed'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Indexes
|
||||
# ------------------------------------------------------------------
|
||||
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
|
||||
|
||||
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
|
||||
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
|
||||
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
|
||||
op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
|
||||
op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
|
||||
op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
|
||||
"ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
|
||||
"ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
# MCP and OAuth connectors share the ``provider`` slot, so the
|
||||
# dedup key is ``(user_id, server_url, provider)``: MCP rows
|
||||
# differentiate by server_url (one per MCP server), OAuth rows
|
||||
# have server_url = NULL and differentiate by provider alone.
|
||||
# COALESCE lets NULL server_url participate in the constraint.
|
||||
"CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
|
||||
"ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX connector_sessions_expiry_idx "
|
||||
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX connector_sessions_server_url_idx "
|
||||
"ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
|
||||
"ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
|
||||
"ON conversation_messages (conversation_id, position);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX conversation_messages_user_ts_idx "
|
||||
"ON conversation_messages (user_id, timestamp DESC);"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
|
||||
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversations_shared_token_uidx "
|
||||
"ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX conversations_api_key_date_idx "
|
||||
"ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
|
||||
"ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX memories_user_tool_path_uidx "
|
||||
"ON memories (user_id, tool_id, path);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
|
||||
"ON memories (user_id, path) WHERE tool_id IS NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX memories_path_prefix_idx "
|
||||
"ON memories (user_id, tool_id, path text_pattern_ops);"
|
||||
)
|
||||
op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")
|
||||
|
||||
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
|
||||
"ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
|
||||
"ON pending_tool_state (conversation_id, user_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
|
||||
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
|
||||
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
|
||||
op.execute(
|
||||
"CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
|
||||
"ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
|
||||
"ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
|
||||
"ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
|
||||
"ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
|
||||
op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")
|
||||
|
||||
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
|
||||
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
|
||||
"ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
|
||||
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
|
||||
"ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
|
||||
"ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
|
||||
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
|
||||
"ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
|
||||
"ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
|
||||
|
||||
op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
|
||||
op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
|
||||
"ON workflow_edges (workflow_id, graph_version, edge_id);"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
|
||||
"ON workflow_nodes (workflow_id, graph_version, node_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
|
||||
"ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
|
||||
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
|
||||
op.execute(
|
||||
"CREATE INDEX workflow_runs_status_started_idx "
|
||||
"ON workflow_runs (status, started_at DESC);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
|
||||
"ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
|
||||
"ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# user_id foreign keys (deferrable so backfills can stage rows)
|
||||
# ------------------------------------------------------------------
|
||||
user_fk_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"attachments",
|
||||
"connector_sessions",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"pending_tool_state",
|
||||
"prompts",
|
||||
"shared_conversations",
|
||||
"sources",
|
||||
"stack_logs",
|
||||
"todos",
|
||||
"token_usage",
|
||||
"user_logs",
|
||||
"user_tools",
|
||||
"workflow_runs",
|
||||
"workflows",
|
||||
)
|
||||
for table in user_fk_tables:
|
||||
op.execute(
|
||||
f"ALTER TABLE {table} "
|
||||
f"ADD CONSTRAINT {table}_user_id_fk "
|
||||
f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
|
||||
f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Triggers
|
||||
# ------------------------------------------------------------------
|
||||
updated_at_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"prompts",
|
||||
"sources",
|
||||
"todos",
|
||||
"user_tools",
|
||||
"users",
|
||||
"workflows",
|
||||
)
|
||||
for table in updated_at_tables:
|
||||
op.execute(
|
||||
f"CREATE TRIGGER {table}_set_updated_at "
|
||||
f"BEFORE UPDATE ON {table} "
|
||||
f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
|
||||
f"EXECUTE FUNCTION set_updated_at();"
|
||||
)
|
||||
|
||||
ensure_user_tables = (
|
||||
"agent_folders",
|
||||
"agents",
|
||||
"attachments",
|
||||
"connector_sessions",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"memories",
|
||||
"notes",
|
||||
"pending_tool_state",
|
||||
"prompts",
|
||||
"shared_conversations",
|
||||
"sources",
|
||||
"stack_logs",
|
||||
"todos",
|
||||
"token_usage",
|
||||
"user_logs",
|
||||
"user_tools",
|
||||
"workflow_runs",
|
||||
"workflows",
|
||||
)
|
||||
for table in ensure_user_tables:
|
||||
op.execute(
|
||||
f"CREATE TRIGGER {table}_ensure_user "
|
||||
f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
|
||||
f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER conversation_messages_fill_user "
|
||||
"BEFORE INSERT ON conversation_messages "
|
||||
"FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE TRIGGER attachments_cleanup_message_refs "
|
||||
"AFTER DELETE ON attachments "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER agents_cleanup_user_prefs "
|
||||
"AFTER DELETE ON agents "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE TRIGGER sources_cleanup_agent_extra_refs "
|
||||
"AFTER DELETE ON sources "
|
||||
"FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Seed sentinel __system__ user (system/template sources attribute here)
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"INSERT INTO users (user_id) VALUES ('__system__') "
|
||||
"ON CONFLICT (user_id) DO NOTHING;"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Nuclear downgrade: drop everything this migration created. The
|
||||
# ordering drops FK-bearing children before parents; CASCADE would
|
||||
# also work but explicit ordering is easier to reason about in code
|
||||
# review.
|
||||
tables_in_drop_order = (
|
||||
"workflow_edges",
|
||||
"workflow_runs",
|
||||
"workflow_nodes",
|
||||
"workflows",
|
||||
"pending_tool_state",
|
||||
"shared_conversations",
|
||||
"conversation_messages",
|
||||
"conversations",
|
||||
"connector_sessions",
|
||||
"notes",
|
||||
"todos",
|
||||
"memories",
|
||||
"attachments",
|
||||
"agents",
|
||||
"sources",
|
||||
"agent_folders",
|
||||
"stack_logs",
|
||||
"user_logs",
|
||||
"token_usage",
|
||||
"user_tools",
|
||||
"prompts",
|
||||
"users",
|
||||
)
|
||||
for table in tables_in_drop_order:
|
||||
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
|
||||
|
||||
for fn in (
|
||||
"conversation_messages_fill_user_id",
|
||||
"cleanup_user_agent_prefs",
|
||||
"cleanup_agent_extra_source_refs",
|
||||
"cleanup_message_attachment_refs",
|
||||
"ensure_user_exists",
|
||||
"set_updated_at",
|
||||
):
|
||||
op.execute(f"DROP FUNCTION IF EXISTS {fn}();")
|
||||
@@ -1,37 +0,0 @@
|
||||
"""0002 app_metadata — singleton key/value table for instance-wide state.
|
||||
|
||||
Used by the startup version-check client to persist the anonymous
|
||||
instance UUID and a one-shot "notice shown" flag. Both values are tiny
|
||||
plain-text strings; this is a deliberate generic-config table rather
|
||||
than dedicated columns so future one-off settings (telemetry opt-in
|
||||
timestamps, feature-flag overrides, etc.) don't each need their own
|
||||
migration.
|
||||
|
||||
Revision ID: 0002_app_metadata
|
||||
Revises: 0001_initial
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "0002_app_metadata"
|
||||
down_revision: Union[str, None] = "0001_initial"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TABLE app_metadata (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP TABLE IF EXISTS app_metadata;")
|
||||
@@ -74,76 +74,57 @@ class AnswerResource(Resource, BaseAnswerResource):
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
try:
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
stream = self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
agent = processor.build_agent(data.get("question", ""))
|
||||
if not processor.decoded_token:
|
||||
return make_response({"error": "Unauthorized"}, 401)
|
||||
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
|
||||
stream = self.complete_stream(
|
||||
question=data["question"],
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
isNoneDoc=data.get("isNoneDoc"),
|
||||
index=None,
|
||||
should_save_conversation=data.get("save_conversation", True),
|
||||
agent_id=processor.agent_id,
|
||||
is_shared_usage=processor.is_shared_usage,
|
||||
shared_token=processor.shared_token,
|
||||
model_id=processor.model_id,
|
||||
)
|
||||
stream_result = self.process_response_stream(stream)
|
||||
|
||||
if stream_result["error"]:
|
||||
return make_response({"error": stream_result["error"]}, 400)
|
||||
if len(stream_result) == 7:
|
||||
(
|
||||
conversation_id,
|
||||
response,
|
||||
sources,
|
||||
tool_calls,
|
||||
thought,
|
||||
error,
|
||||
structured_info,
|
||||
) = stream_result
|
||||
else:
|
||||
conversation_id, response, sources, tool_calls, thought, error = (
|
||||
stream_result
|
||||
)
|
||||
structured_info = None
|
||||
|
||||
if error:
|
||||
return make_response({"error": error}, 400)
|
||||
result = {
|
||||
"conversation_id": stream_result["conversation_id"],
|
||||
"answer": stream_result["answer"],
|
||||
"sources": stream_result["sources"],
|
||||
"tool_calls": stream_result["tool_calls"],
|
||||
"thought": stream_result["thought"],
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
}
|
||||
|
||||
extra_info = stream_result.get("extra")
|
||||
if extra_info:
|
||||
result.update(extra_info)
|
||||
if structured_info:
|
||||
result.update(structured_info)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, Generator, List, Optional
|
||||
from flask import jsonify, make_response, Response
|
||||
from flask_restx import Namespace
|
||||
|
||||
from application.api.answer.services.continuation_service import ContinuationService
|
||||
from application.api.answer.services.conversation_service import ConversationService
|
||||
from application.core.model_utils import (
|
||||
get_api_key_for_provider,
|
||||
@@ -14,13 +13,10 @@ from application.core.model_utils import (
|
||||
get_provider_from_model_id,
|
||||
)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.error import sanitize_api_error
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -33,22 +29,17 @@ class BaseAnswerResource:
|
||||
"""Shared base class for answer endpoints"""
|
||||
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.db = db
|
||||
self.user_logs_collection = db["user_logs"]
|
||||
self.default_model_id = get_default_model_id()
|
||||
self.conversation_service = ConversationService()
|
||||
|
||||
def validate_request(
|
||||
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||
) -> Optional[Response]:
|
||||
"""Common request validation.
|
||||
|
||||
Continuation requests (``tool_actions`` present) require
|
||||
``conversation_id`` but not ``question``.
|
||||
"""
|
||||
if data.get("tool_actions"):
|
||||
# Continuation mode — question is not required
|
||||
if missing := check_required_fields(data, ["conversation_id"]):
|
||||
return missing
|
||||
return None
|
||||
"""Common request validation"""
|
||||
required_fields = ["question"]
|
||||
if require_conversation_id:
|
||||
required_fields.append("conversation_id")
|
||||
@@ -90,8 +81,8 @@ class BaseAnswerResource:
|
||||
api_key = agent_config.get("user_api_key")
|
||||
if not api_key:
|
||||
return None
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
agents_collection = self.db["agents"]
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
|
||||
if not agent:
|
||||
return make_response(
|
||||
@@ -112,32 +103,41 @@ class BaseAnswerResource:
|
||||
)
|
||||
|
||||
token_limit = int(
|
||||
agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
||||
)
|
||||
request_limit = int(
|
||||
agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||
)
|
||||
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
token_usage_collection = self.db["token_usage"]
|
||||
|
||||
end_date = datetime.datetime.now()
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
|
||||
if limited_token_mode or limited_request_mode:
|
||||
with db_readonly() as conn:
|
||||
token_repo = TokenUsageRepository(conn)
|
||||
if limited_token_mode:
|
||||
daily_token_usage = token_repo.sum_tokens_in_range(
|
||||
start=start_date, end=end_date, api_key=api_key,
|
||||
)
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_repo.count_in_range(
|
||||
start=start_date, end=end_date, api_key=api_key,
|
||||
)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
match_query = {
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
||||
if limited_token_mode:
|
||||
token_pipeline = [
|
||||
{"$match": match_query},
|
||||
{
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
},
|
||||
]
|
||||
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
||||
else:
|
||||
daily_token_usage = 0
|
||||
if limited_request_mode:
|
||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
||||
else:
|
||||
daily_request_usage = 0
|
||||
if not limited_token_mode and not limited_request_mode:
|
||||
return None
|
||||
@@ -177,7 +177,6 @@ class BaseAnswerResource:
|
||||
is_shared_usage: bool = False,
|
||||
shared_token: Optional[str] = None,
|
||||
model_id: Optional[str] = None,
|
||||
_continuation: Optional[Dict] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Generator function that streams the complete conversation response.
|
||||
@@ -208,19 +207,8 @@ class BaseAnswerResource:
|
||||
schema_info = None
|
||||
structured_chunks = []
|
||||
query_metadata = {}
|
||||
paused = False
|
||||
|
||||
if _continuation:
|
||||
gen_iter = agent.gen_continuation(
|
||||
messages=_continuation["messages"],
|
||||
tools_dict=_continuation["tools_dict"],
|
||||
pending_tool_calls=_continuation["pending_tool_calls"],
|
||||
tool_actions=_continuation["tool_actions"],
|
||||
)
|
||||
else:
|
||||
gen_iter = agent.gen(query=question)
|
||||
|
||||
for line in gen_iter:
|
||||
for line in agent.gen(query=question):
|
||||
if "metadata" in line:
|
||||
query_metadata.update(line["metadata"])
|
||||
elif "answer" in line:
|
||||
@@ -256,21 +244,15 @@ class BaseAnswerResource:
|
||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||
yield f"data: {data}\n\n"
|
||||
elif "type" in line:
|
||||
if line.get("type") == "tool_calls_pending":
|
||||
# Save continuation state and end the stream
|
||||
paused = True
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
elif line.get("type") == "error":
|
||||
if line.get("type") == "error":
|
||||
sanitized_error = {
|
||||
"type": "error",
|
||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||
}
|
||||
data = json.dumps(sanitized_error)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
data = json.dumps(line)
|
||||
yield f"data: {data}\n\n"
|
||||
yield f"data: {data}\n\n"
|
||||
if is_structured and structured_chunks:
|
||||
structured_data = {
|
||||
"type": "structured_answer",
|
||||
@@ -280,93 +262,6 @@ class BaseAnswerResource:
|
||||
}
|
||||
data = json.dumps(structured_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# ---- Paused: save continuation state and end stream early ----
|
||||
if paused:
|
||||
continuation = getattr(agent, "_pending_continuation", None)
|
||||
if continuation:
|
||||
# Ensure we have a conversation_id — create a partial
|
||||
# conversation if this is the first turn.
|
||||
if not conversation_id and should_save_conversation:
|
||||
try:
|
||||
provider = (
|
||||
get_provider_from_model_id(model_id)
|
||||
if model_id
|
||||
else settings.LLM_PROVIDER
|
||||
)
|
||||
sys_api_key = get_api_key_for_provider(
|
||||
provider or settings.LLM_PROVIDER
|
||||
)
|
||||
llm = LLMCreator.create_llm(
|
||||
provider or settings.LLM_PROVIDER,
|
||||
api_key=sys_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
conversation_id = (
|
||||
self.conversation_service.save_conversation(
|
||||
None,
|
||||
question,
|
||||
response_full,
|
||||
thought,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
llm,
|
||||
model_id or self.default_model_id,
|
||||
decoded_token,
|
||||
api_key=user_api_key,
|
||||
agent_id=agent_id,
|
||||
is_shared_usage=is_shared_usage,
|
||||
shared_token=shared_token,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create conversation for continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
try:
|
||||
cont_service = ContinuationService()
|
||||
cont_service.save_state(
|
||||
conversation_id=str(conversation_id),
|
||||
user=decoded_token.get("sub", "local"),
|
||||
messages=continuation["messages"],
|
||||
pending_tool_calls=continuation["pending_tool_calls"],
|
||||
tools_dict=continuation["tools_dict"],
|
||||
tool_schemas=getattr(agent, "tools", []),
|
||||
agent_config={
|
||||
"model_id": model_id or self.default_model_id,
|
||||
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||
"api_key": getattr(agent, "api_key", None),
|
||||
"user_api_key": user_api_key,
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent.__class__.__name__,
|
||||
"prompt": getattr(agent, "prompt", ""),
|
||||
"json_schema": getattr(agent, "json_schema", None),
|
||||
"retriever_config": getattr(agent, "retriever_config", None),
|
||||
},
|
||||
client_tools=getattr(
|
||||
agent.tool_executor, "client_tools", None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save continuation state: {str(e)}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
id_data = {"type": "id", "id": str(conversation_id)}
|
||||
data = json.dumps(id_data)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
if isNoneDoc:
|
||||
for doc in source_log_docs:
|
||||
doc["source"] = "None"
|
||||
@@ -457,18 +352,7 @@ class BaseAnswerResource:
|
||||
for key, value in log_data.items():
|
||||
if isinstance(value, str) and len(value) > 10000:
|
||||
log_data[key] = value[:10000]
|
||||
try:
|
||||
with db_session() as conn:
|
||||
UserLogsRepository(conn).insert(
|
||||
user_id=log_data.get("user"),
|
||||
endpoint="stream_answer",
|
||||
data=log_data,
|
||||
)
|
||||
except Exception as log_err:
|
||||
logger.error(
|
||||
f"Failed to persist stream_answer user log: {log_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
self.user_logs_collection.insert_one(log_data)
|
||||
|
||||
data = json.dumps({"type": "end"})
|
||||
yield f"data: {data}\n\n"
|
||||
@@ -541,13 +425,8 @@ class BaseAnswerResource:
|
||||
yield f"data: {data}\n\n"
|
||||
return
|
||||
|
||||
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||
"""Process the stream response for non-streaming endpoint.
|
||||
|
||||
Returns:
|
||||
Dict with keys: conversation_id, answer, sources, tool_calls,
|
||||
thought, error, and optional extra.
|
||||
"""
|
||||
def process_response_stream(self, stream):
|
||||
"""Process the stream response for non-streaming endpoint"""
|
||||
conversation_id = ""
|
||||
response_full = ""
|
||||
source_log_docs = []
|
||||
@@ -556,7 +435,6 @@ class BaseAnswerResource:
|
||||
stream_ended = False
|
||||
is_structured = False
|
||||
schema_info = None
|
||||
pending_tool_calls = None
|
||||
|
||||
for line in stream:
|
||||
try:
|
||||
@@ -575,22 +453,11 @@ class BaseAnswerResource:
|
||||
source_log_docs = event["source"]
|
||||
elif event["type"] == "tool_calls":
|
||||
tool_calls = event["tool_calls"]
|
||||
elif event["type"] == "tool_calls_pending":
|
||||
pending_tool_calls = event.get("data", {}).get(
|
||||
"pending_tool_calls", []
|
||||
)
|
||||
elif event["type"] == "thought":
|
||||
thought = event["thought"]
|
||||
elif event["type"] == "error":
|
||||
logger.error(f"Error from stream: {event['error']}")
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": event["error"],
|
||||
}
|
||||
return None, None, None, None, event["error"], None
|
||||
elif event["type"] == "end":
|
||||
stream_ended = True
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
@@ -598,30 +465,18 @@ class BaseAnswerResource:
|
||||
continue
|
||||
if not stream_ended:
|
||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||
return {
|
||||
"conversation_id": None,
|
||||
"answer": None,
|
||||
"sources": None,
|
||||
"tool_calls": None,
|
||||
"thought": None,
|
||||
"error": "Stream ended unexpectedly",
|
||||
}
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"answer": response_full,
|
||||
"sources": source_log_docs,
|
||||
"tool_calls": tool_calls,
|
||||
"thought": thought,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
if pending_tool_calls is not None:
|
||||
result["extra"] = {"pending_tool_calls": pending_tool_calls}
|
||||
return None, None, None, None, "Stream ended unexpectedly", None
|
||||
result = (
|
||||
conversation_id,
|
||||
response_full,
|
||||
source_log_docs,
|
||||
tool_calls,
|
||||
thought,
|
||||
None,
|
||||
)
|
||||
|
||||
if is_structured:
|
||||
result["extra"] = {"structured": True, "schema": schema_info}
|
||||
|
||||
result = result + ({"structured": True, "schema": schema_info},)
|
||||
return result
|
||||
|
||||
def error_stream_generate(self, err_response):
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import fields, Resource
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from application.api.answer.routes.base import answer_ns
|
||||
from application.services.search_service import (
|
||||
InvalidAPIKey,
|
||||
SearchFailed,
|
||||
search,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@answer_ns.route("/api/search")
|
||||
class SearchResource(Resource):
|
||||
"""Fast search endpoint for retrieving relevant documents."""
|
||||
"""Fast search endpoint for retrieving relevant documents"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
mongo = MongoDB.get_client()
|
||||
self.db = mongo[settings.MONGO_DB_NAME]
|
||||
self.agents_collection = self.db["agents"]
|
||||
|
||||
search_model = answer_ns.model(
|
||||
"SearchModel",
|
||||
@@ -32,10 +39,116 @@ class SearchResource(Resource):
|
||||
},
|
||||
)
|
||||
|
||||
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
|
||||
"""Get source IDs connected to the API key/agent.
|
||||
|
||||
"""
|
||||
agent_data = self.agents_collection.find_one({"key": api_key})
|
||||
if not agent_data:
|
||||
return []
|
||||
|
||||
source_ids = []
|
||||
|
||||
# Handle multiple sources (only if non-empty)
|
||||
sources = agent_data.get("sources", [])
|
||||
if sources and isinstance(sources, list) and len(sources) > 0:
|
||||
for source_ref in sources:
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
if source_ref == "default":
|
||||
continue
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
|
||||
# Handle single source (legacy) - check if sources was empty or didn't yield results
|
||||
if not source_ids:
|
||||
source = agent_data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
source_ids.append(str(source_doc["_id"]))
|
||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
||||
elif source and source != "default":
|
||||
source_ids.append(source)
|
||||
|
||||
return source_ids
|
||||
|
||||
def _search_vectorstores(
|
||||
self, query: str, source_ids: List[str], chunks: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Search across vectorstores and return results"""
|
||||
if not source_ids:
|
||||
return []
|
||||
|
||||
results = []
|
||||
chunks_per_source = max(1, chunks // len(source_ids))
|
||||
seen_texts = set()
|
||||
|
||||
for source_id in source_ids:
|
||||
if not source_id or not source_id.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
docsearch = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
|
||||
)
|
||||
docs = docsearch.search(query, k=chunks_per_source * 2)
|
||||
|
||||
for doc in docs:
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
|
||||
page_content = doc.page_content
|
||||
metadata = doc.metadata
|
||||
else:
|
||||
page_content = doc.get("text", doc.get("page_content", ""))
|
||||
metadata = doc.get("metadata", {})
|
||||
|
||||
# Skip duplicates
|
||||
text_hash = hash(page_content[:200])
|
||||
if text_hash in seen_texts:
|
||||
continue
|
||||
seen_texts.add(text_hash)
|
||||
|
||||
title = metadata.get(
|
||||
"title", metadata.get("post_title", "")
|
||||
)
|
||||
if not isinstance(title, str):
|
||||
title = str(title) if title else ""
|
||||
|
||||
# Clean up title
|
||||
if title:
|
||||
title = title.split("/")[-1]
|
||||
else:
|
||||
# Use filename or first part of content as title
|
||||
title = metadata.get("filename", page_content[:50] + "...")
|
||||
|
||||
source = metadata.get("source", source_id)
|
||||
|
||||
results.append({
|
||||
"text": page_content,
|
||||
"title": title,
|
||||
"source": source,
|
||||
})
|
||||
|
||||
if len(results) >= chunks:
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error searching vectorstore {source_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
return results[:chunks]
|
||||
|
||||
@answer_ns.expect(search_model)
|
||||
@answer_ns.doc(description="Search for relevant documents based on query")
|
||||
def post(self):
|
||||
data = request.get_json() or {}
|
||||
data = request.get_json()
|
||||
|
||||
question = data.get("question")
|
||||
api_key = data.get("api_key")
|
||||
@@ -43,13 +156,31 @@ class SearchResource(Resource):
|
||||
|
||||
if not question:
|
||||
return make_response({"error": "question is required"}, 400)
|
||||
|
||||
if not api_key:
|
||||
return make_response({"error": "api_key is required"}, 400)
|
||||
|
||||
try:
|
||||
return make_response(search(api_key, question, chunks), 200)
|
||||
except InvalidAPIKey:
|
||||
# Validate API key
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response({"error": "Invalid API key"}, 401)
|
||||
except SearchFailed:
|
||||
logger.exception("/api/search failed")
|
||||
|
||||
try:
|
||||
# Get sources connected to this API key
|
||||
source_ids = self._get_sources_from_api_key(api_key)
|
||||
|
||||
if not source_ids:
|
||||
return make_response([], 200)
|
||||
|
||||
# Perform search
|
||||
results = self._search_vectorstores(question, source_ids, chunks)
|
||||
|
||||
return make_response(results, 200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/api/search - error: {str(e)}",
|
||||
extra={"error": str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
return make_response({"error": "Search failed"}, 500)
|
||||
|
||||
@@ -79,47 +79,7 @@ class StreamResource(Resource, BaseAnswerResource):
|
||||
return error
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
processor = StreamProcessor(data, decoded_token)
|
||||
|
||||
try:
|
||||
# ---- Continuation mode ----
|
||||
if data.get("tool_actions"):
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
data["tool_actions"], data["conversation_id"]
|
||||
)
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
self.error_stream_generate("Unauthorized"),
|
||||
status=401,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
if error := self.check_usage(processor.agent_config):
|
||||
return error
|
||||
return Response(
|
||||
self.complete_stream(
|
||||
question="",
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
_continuation={
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
},
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
# ---- Normal mode ----
|
||||
agent = processor.build_agent(data["question"])
|
||||
if not processor.decoded_token:
|
||||
return Response(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Message reconstruction utilities for compression."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional
|
||||
@@ -50,35 +49,28 @@ class MessageBuilder:
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
@@ -188,35 +180,28 @@ class MessageBuilder:
|
||||
if include_tool_calls and "tool_calls" in query:
|
||||
for tool_call in query["tool_calls"]:
|
||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||
args = tool_call.get("arguments")
|
||||
args_str = (
|
||||
json.dumps(args)
|
||||
if isinstance(args, dict)
|
||||
else (args or "{}")
|
||||
|
||||
function_call_dict = {
|
||||
"function_call": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"args": tool_call.get("arguments"),
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
function_response_dict = {
|
||||
"function_response": {
|
||||
"name": tool_call.get("action_name"),
|
||||
"response": {"result": tool_call.get("result")},
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
|
||||
rebuilt_messages.append(
|
||||
{"role": "assistant", "content": [function_call_dict]}
|
||||
)
|
||||
rebuilt_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.get("action_name", ""),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
result = tool_call.get("result")
|
||||
result_str = (
|
||||
json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else (result or "")
|
||||
rebuilt_messages.append(
|
||||
{"role": "tool", "content": [function_response_dict]}
|
||||
)
|
||||
rebuilt_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": result_str,
|
||||
})
|
||||
|
||||
# If no recent queries (everything was compressed), add a continuation user message
|
||||
if len(recent_queries) == 0 and compressed_summary:
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""Service for saving and restoring tool-call continuation state.
|
||||
|
||||
When a stream pauses (tool needs approval or client-side execution),
|
||||
the full execution state is persisted to Postgres so the client can
|
||||
resume later by sending tool_actions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TTL for pending states — auto-cleaned after this period
|
||||
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _make_serializable(obj: Any) -> Any:
|
||||
"""Recursively coerce non-JSON values into JSON-safe forms.
|
||||
|
||||
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
|
||||
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
|
||||
our writers produce them anymore.
|
||||
"""
|
||||
if isinstance(obj, UUID):
|
||||
return str(obj)
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_make_serializable(v) for v in obj]
|
||||
if isinstance(obj, bytes):
|
||||
return obj.decode("utf-8", errors="replace")
|
||||
return obj
|
||||
|
||||
|
||||
class ContinuationService:
|
||||
"""Manages pending tool-call state in Postgres."""
|
||||
|
||||
def __init__(self):
|
||||
# No-op constructor retained for call-site compatibility. State
|
||||
# lives in Postgres now; each operation opens its own short-lived
|
||||
# session rather than holding a connection on the service.
|
||||
pass
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user: str,
|
||||
messages: List[Dict],
|
||||
pending_tool_calls: List[Dict],
|
||||
tools_dict: Dict,
|
||||
tool_schemas: List[Dict],
|
||||
agent_config: Dict,
|
||||
client_tools: Optional[List[Dict]] = None,
|
||||
) -> str:
|
||||
"""Save execution state for later continuation.
|
||||
|
||||
``conversation_id`` may be a Postgres UUID or the legacy Mongo
|
||||
``ObjectId`` string — the latter is resolved via
|
||||
``conversations.legacy_mongo_id`` to find the matching row.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation this state belongs to.
|
||||
user: Owner user ID.
|
||||
messages: Full messages array at the pause point.
|
||||
pending_tool_calls: Tool calls awaiting client action.
|
||||
tools_dict: Serializable tools configuration dict.
|
||||
tool_schemas: LLM-formatted tool schemas (agent.tools).
|
||||
agent_config: Config needed to recreate the agent on resume.
|
||||
client_tools: Client-provided tool schemas for client-side execution.
|
||||
|
||||
Returns:
|
||||
The string ID (conversation_id as provided) of the saved state.
|
||||
"""
|
||||
with db_session() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
|
||||
# would raise and poison the save. Surface the mismatch so
|
||||
# the caller can decide (the stream loop in routes/base.py
|
||||
# already wraps this in try/except).
|
||||
raise ValueError(
|
||||
f"Cannot save continuation state: conversation_id "
|
||||
f"{conversation_id!r} is neither a PG UUID nor a "
|
||||
f"backfilled legacy Mongo id."
|
||||
)
|
||||
PendingToolStateRepository(conn).save_state(
|
||||
pg_conv_id,
|
||||
user,
|
||||
messages=_make_serializable(messages),
|
||||
pending_tool_calls=_make_serializable(pending_tool_calls),
|
||||
tools_dict=_make_serializable(tools_dict),
|
||||
tool_schemas=_make_serializable(tool_schemas),
|
||||
agent_config=_make_serializable(agent_config),
|
||||
client_tools=_make_serializable(client_tools) if client_tools else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Saved continuation state for conversation {conversation_id} "
|
||||
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||
)
|
||||
return conversation_id
|
||||
|
||||
def load_state(
|
||||
self, conversation_id: str, user: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load pending continuation state.
|
||||
|
||||
Returns:
|
||||
The state dict, or None if no pending state exists.
|
||||
"""
|
||||
with db_readonly() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId → no state can exist for it.
|
||||
return None
|
||||
doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
|
||||
if not doc:
|
||||
return None
|
||||
return doc
|
||||
|
||||
def delete_state(self, conversation_id: str, user: str) -> bool:
|
||||
"""Delete pending state after successful resumption.
|
||||
|
||||
Returns:
|
||||
True if a row was deleted.
|
||||
"""
|
||||
with db_session() as conn:
|
||||
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||
if conv is not None:
|
||||
pg_conv_id = conv["id"]
|
||||
elif looks_like_uuid(conversation_id):
|
||||
pg_conv_id = conversation_id
|
||||
else:
|
||||
# Unresolvable legacy ObjectId → nothing to delete.
|
||||
return False
|
||||
deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Deleted continuation state for conversation {conversation_id}"
|
||||
)
|
||||
return deleted
|
||||
@@ -1,51 +1,44 @@
|
||||
"""Conversation persistence service backed by Postgres.
|
||||
|
||||
Handles create / append / update / compression for conversations during
|
||||
the answer-streaming path. Connections are opened per-operation rather
|
||||
than held for the duration of a stream.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationService:
|
||||
def __init__(self):
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
self.conversations_collection = db["conversations"]
|
||||
self.agents_collection = db["agents"]
|
||||
|
||||
def get_conversation(
|
||||
self, conversation_id: str, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a conversation with owner-or-shared access control.
|
||||
|
||||
Returns a dict in the legacy Mongo shape — ``queries`` is a list
|
||||
of message dicts (prompt/response/...) — for compatibility with
|
||||
the streaming pipeline that consumes this shape.
|
||||
"""
|
||||
"""Retrieve a conversation with proper access control"""
|
||||
if not conversation_id or not user_id:
|
||||
return None
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
logger.warning(
|
||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||
)
|
||||
return None
|
||||
messages = repo.get_messages(str(conv["id"]))
|
||||
conv["queries"] = messages
|
||||
conv["_id"] = str(conv["id"])
|
||||
return conv
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"$or": [{"user": user_id}, {"shared_with": user_id}],
|
||||
}
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
logger.warning(
|
||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||
)
|
||||
return None
|
||||
conversation["_id"] = str(conversation["_id"])
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||
return None
|
||||
@@ -69,11 +62,7 @@ class ConversationService:
|
||||
attachment_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Save or update a conversation in Postgres.
|
||||
|
||||
Returns the string conversation id (PG UUID as string, or the
|
||||
caller-provided id if it was already a UUID).
|
||||
"""
|
||||
"""Save or update a conversation in the database"""
|
||||
if decoded_token is None:
|
||||
raise ValueError("Invalid or missing authentication token")
|
||||
user_id = decoded_token.get("sub")
|
||||
@@ -81,47 +70,78 @@ class ConversationService:
|
||||
raise ValueError("User ID not found in token")
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# Trim huge inline source text to a reasonable max before persist.
|
||||
# clean up in sources array such that we save max 1k characters for text part
|
||||
for source in sources:
|
||||
if "text" in source and isinstance(source["text"], str):
|
||||
source["text"] = source["text"][:1000]
|
||||
|
||||
message_payload = {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
"timestamp": current_time,
|
||||
}
|
||||
if metadata:
|
||||
message_payload["metadata"] = metadata
|
||||
|
||||
if conversation_id is not None and index is not None:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
conv_pg_id = str(conv["id"])
|
||||
repo.update_message_at(conv_pg_id, index, message_payload)
|
||||
repo.truncate_after(conv_pg_id, index)
|
||||
# Update existing conversation with new query
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{index}.prompt": question,
|
||||
f"queries.{index}.response": response,
|
||||
f"queries.{index}.thought": thought,
|
||||
f"queries.{index}.sources": sources,
|
||||
f"queries.{index}.tool_calls": tool_calls,
|
||||
f"queries.{index}.timestamp": current_time,
|
||||
f"queries.{index}.attachments": attachment_ids,
|
||||
f"queries.{index}.model_id": model_id,
|
||||
**(
|
||||
{f"queries.{index}.metadata": metadata}
|
||||
if metadata
|
||||
else {}
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
self.conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(conversation_id),
|
||||
"user": user_id,
|
||||
f"queries.{index}": {"$exists": True},
|
||||
},
|
||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||
)
|
||||
return conversation_id
|
||||
elif conversation_id:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is None:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
conv_pg_id = str(conv["id"])
|
||||
# append_message expects 'metadata' key either way; normalise.
|
||||
append_payload = dict(message_payload)
|
||||
append_payload.setdefault("metadata", metadata or {})
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
# Append new message to existing conversation
|
||||
|
||||
result = self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id), "user": user_id},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
**({"metadata": metadata} if metadata else {}),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if result.matched_count == 0:
|
||||
raise ValueError("Conversation not found or unauthorized")
|
||||
return conversation_id
|
||||
else:
|
||||
# Create new conversation
|
||||
|
||||
messages_summary = [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -143,64 +163,70 @@ class ConversationService:
|
||||
if not completion or not completion.strip():
|
||||
completion = question[:50] if question else "New Conversation"
|
||||
|
||||
resolved_api_key: Optional[str] = None
|
||||
resolved_agent_id: Optional[str] = None
|
||||
if api_key:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if agent:
|
||||
resolved_api_key = agent.get("key")
|
||||
if agent_id:
|
||||
resolved_agent_id = agent_id
|
||||
query_doc = {
|
||||
"prompt": question,
|
||||
"response": response,
|
||||
"thought": thought,
|
||||
"sources": sources,
|
||||
"tool_calls": tool_calls,
|
||||
"timestamp": current_time,
|
||||
"attachments": attachment_ids,
|
||||
"model_id": model_id,
|
||||
}
|
||||
if metadata:
|
||||
query_doc["metadata"] = metadata
|
||||
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.create(
|
||||
user_id,
|
||||
completion,
|
||||
agent_id=resolved_agent_id,
|
||||
api_key=resolved_api_key,
|
||||
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
|
||||
shared_token=(
|
||||
shared_token
|
||||
if (resolved_agent_id and is_shared_usage)
|
||||
else None
|
||||
),
|
||||
)
|
||||
conv_pg_id = str(conv["id"])
|
||||
append_payload = dict(message_payload)
|
||||
append_payload.setdefault("metadata", metadata or {})
|
||||
repo.append_message(conv_pg_id, append_payload)
|
||||
return conv_pg_id
|
||||
conversation_data = {
|
||||
"user": user_id,
|
||||
"date": current_time,
|
||||
"name": completion,
|
||||
"queries": [query_doc],
|
||||
}
|
||||
|
||||
if api_key:
|
||||
if agent_id:
|
||||
conversation_data["agent_id"] = agent_id
|
||||
if is_shared_usage:
|
||||
conversation_data["is_shared_usage"] = is_shared_usage
|
||||
conversation_data["shared_token"] = shared_token
|
||||
agent = self.agents_collection.find_one({"key": api_key})
|
||||
if agent:
|
||||
conversation_data["api_key"] = agent["key"]
|
||||
result = self.conversations_collection.insert_one(conversation_data)
|
||||
return str(result.inserted_id)
|
||||
|
||||
def update_compression_metadata(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist compression flags and append a compression point.
|
||||
"""
|
||||
Update conversation with compression metadata.
|
||||
|
||||
Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
|
||||
``compression_metadata`` but goes through the PG repo API.
|
||||
Uses $push with $slice to keep only the most recent compression points,
|
||||
preventing unbounded array growth. Since each compression incorporates
|
||||
previous compressions, older points become redundant.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
compression_metadata: Compression point data
|
||||
"""
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
# conversation_id here comes from the streaming pipeline
|
||||
# which has already resolved it; accept either UUID or
|
||||
# legacy id for safety.
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
conv_pg_id = (
|
||||
str(conv["id"]) if conv is not None else conversation_id
|
||||
)
|
||||
repo.set_compression_flags(
|
||||
conv_pg_id,
|
||||
is_compressed=True,
|
||||
last_compression_at=compression_metadata.get("timestamp"),
|
||||
)
|
||||
repo.append_compression_point(
|
||||
conv_pg_id,
|
||||
compression_metadata,
|
||||
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
)
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$set": {
|
||||
"compression_metadata.is_compressed": True,
|
||||
"compression_metadata.last_compression_at": compression_metadata.get(
|
||||
"timestamp"
|
||||
),
|
||||
},
|
||||
"$push": {
|
||||
"compression_metadata.compression_points": {
|
||||
"$each": [compression_metadata],
|
||||
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Updated compression metadata for conversation {conversation_id}"
|
||||
)
|
||||
@@ -213,34 +239,34 @@ class ConversationService:
|
||||
def append_compression_message(
|
||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Append a synthetic compression summary message to the conversation."""
|
||||
"""
|
||||
Append a synthetic compression summary entry into the conversation history.
|
||||
This makes the summary visible in the DB alongside normal queries.
|
||||
"""
|
||||
try:
|
||||
summary = compression_metadata.get("compressed_summary", "")
|
||||
if not summary:
|
||||
return
|
||||
timestamp = compression_metadata.get(
|
||||
"timestamp", datetime.now(timezone.utc)
|
||||
)
|
||||
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
|
||||
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
conv_pg_id = (
|
||||
str(conv["id"]) if conv is not None else conversation_id
|
||||
)
|
||||
repo.append_message(conv_pg_id, {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
"timestamp": timestamp,
|
||||
})
|
||||
logger.info(
|
||||
f"Appended compression summary to conversation {conversation_id}"
|
||||
self.conversations_collection.update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{
|
||||
"$push": {
|
||||
"queries": {
|
||||
"prompt": "[Context Compression Summary]",
|
||||
"response": summary,
|
||||
"thought": "",
|
||||
"sources": [],
|
||||
"tool_calls": [],
|
||||
"timestamp": timestamp,
|
||||
"attachments": [],
|
||||
"model_id": compression_metadata.get("model_used"),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error appending compression summary: {str(e)}", exc_info=True
|
||||
@@ -249,30 +275,20 @@ class ConversationService:
|
||||
def get_compression_metadata(
|
||||
self, conversation_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch the stored compression metadata JSONB blob for a conversation."""
|
||||
"""
|
||||
Get compression metadata for a conversation.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
|
||||
Returns:
|
||||
Compression metadata dict or None
|
||||
"""
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_by_legacy_id(conversation_id)
|
||||
if conv is None:
|
||||
# Fallback to UUID lookup without user scoping — the
|
||||
# caller already holds an authenticated conversation
|
||||
# id from the streaming path. Gate on id shape so a
|
||||
# non-UUID (legacy ObjectId that wasn't backfilled)
|
||||
# doesn't reach CAST — the cast raises and spams the
|
||||
# logs with a stack trace on every call.
|
||||
if not looks_like_uuid(conversation_id):
|
||||
return None
|
||||
result = conn.execute(
|
||||
sql_text(
|
||||
"SELECT compression_metadata FROM conversations "
|
||||
"WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": conversation_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row[0] if row is not None else None
|
||||
return conv.get("compression_metadata") if conv else None
|
||||
conversation = self.conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
|
||||
)
|
||||
return conversation.get("compression_metadata") if conversation else None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting compression metadata: {str(e)}", exc_info=True
|
||||
|
||||
@@ -5,6 +5,10 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from bson.dbref import DBRef
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.api.answer.services.compression import CompressionOrchestrator
|
||||
from application.api.answer.services.compression.token_counter import TokenCounter
|
||||
@@ -16,16 +20,8 @@ from application.core.model_utils import (
|
||||
get_provider_from_model_id,
|
||||
validate_model_id,
|
||||
)
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from sqlalchemy import text as sql_text
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.retriever.retriever_creator import RetrieverCreator
|
||||
from application.utils import (
|
||||
calculate_doc_token_budget,
|
||||
@@ -36,41 +32,28 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_prompt(prompt_id: str, prompts_collection=None) -> str:
|
||||
"""Get a prompt by preset name or Postgres ID (UUID or legacy ObjectId).
|
||||
|
||||
The ``prompts_collection`` parameter is retained for backwards
|
||||
compatibility with call sites that still pass it positionally; it is
|
||||
ignored post-cutover.
|
||||
"""
|
||||
del prompts_collection # unused — retained for call-site compatibility
|
||||
# Callers may pass a ``uuid.UUID`` (from a PG ``prompt_id`` column) or a
|
||||
# plain string ("default"/"creative"/legacy ObjectId). Normalise to str
|
||||
# so both the preset lookup and the UUID-vs-legacy branching work.
|
||||
# ``None`` / empty means "use the default prompt" — agents that never
|
||||
# set a custom prompt land here (PG ``agents.prompt_id`` is NULL).
|
||||
if prompt_id is None or prompt_id == "":
|
||||
prompt_id = "default"
|
||||
elif not isinstance(prompt_id, str):
|
||||
prompt_id = str(prompt_id)
|
||||
Get a prompt by preset name or MongoDB ID
|
||||
"""
|
||||
current_dir = Path(__file__).resolve().parents[3]
|
||||
prompts_dir = current_dir / "prompts"
|
||||
|
||||
# Maps for classic agent types
|
||||
CLASSIC_PRESETS = {
|
||||
"default": "chat_combine_default.txt",
|
||||
"creative": "chat_combine_creative.txt",
|
||||
"strict": "chat_combine_strict.txt",
|
||||
"reduce": "chat_reduce_prompt.txt",
|
||||
}
|
||||
|
||||
# Agentic counterparts — same styles, but with search tool instructions
|
||||
AGENTIC_PRESETS = {
|
||||
"default": "agentic/default.txt",
|
||||
"creative": "agentic/creative.txt",
|
||||
"strict": "agentic/strict.txt",
|
||||
}
|
||||
|
||||
preset_mapping = {
|
||||
**CLASSIC_PRESETS,
|
||||
**{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()},
|
||||
}
|
||||
preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
|
||||
|
||||
if prompt_id in preset_mapping:
|
||||
file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
|
||||
@@ -80,18 +63,14 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Prompt file not found: {file_path}")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = PromptsRepository(conn)
|
||||
prompt_doc = None
|
||||
if looks_like_uuid(prompt_id):
|
||||
prompt_doc = repo.get_for_rendering(prompt_id)
|
||||
if prompt_doc is None:
|
||||
prompt_doc = repo.get_by_legacy_id(prompt_id)
|
||||
if prompts_collection is None:
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
prompts_collection = db["prompts"]
|
||||
prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
|
||||
if not prompt_doc:
|
||||
raise ValueError(f"Prompt with ID {prompt_id} not found")
|
||||
return prompt_doc["content"]
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
|
||||
|
||||
@@ -100,9 +79,12 @@ class StreamProcessor:
|
||||
def __init__(
|
||||
self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
|
||||
):
|
||||
# Legacy attribute retained as None for any external callers that
|
||||
# introspect the processor; all DB access uses per-op connections.
|
||||
self.prompts_collection = None
|
||||
mongo = MongoDB.get_client()
|
||||
self.db = mongo[settings.MONGO_DB_NAME]
|
||||
self.agents_collection = self.db["agents"]
|
||||
self.attachments_collection = self.db["attachments"]
|
||||
self.prompts_collection = self.db["prompts"]
|
||||
|
||||
self.data = request_data
|
||||
self.decoded_token = decoded_token
|
||||
self.initial_user_id = (
|
||||
@@ -130,7 +112,6 @@ class StreamProcessor:
|
||||
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||
self.compressed_summary: Optional[str] = None
|
||||
self.compressed_summary_tokens: int = 0
|
||||
self._agent_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize all required components for processing"""
|
||||
@@ -262,21 +243,17 @@ class StreamProcessor:
|
||||
if not attachment_ids:
|
||||
return []
|
||||
attachments = []
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = AttachmentsRepository(conn)
|
||||
for attachment_id in attachment_ids:
|
||||
try:
|
||||
attachment_doc = repo.get_any(str(attachment_id), user_id)
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error opening attachments connection: {e}", exc_info=True)
|
||||
for attachment_id in attachment_ids:
|
||||
try:
|
||||
attachment_doc = self.attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id), "user": user_id}
|
||||
)
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
|
||||
)
|
||||
return attachments
|
||||
|
||||
def _validate_and_set_model(self):
|
||||
@@ -307,127 +284,97 @@ class StreamProcessor:
|
||||
self.model_id = get_default_model_id()
|
||||
|
||||
def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
|
||||
"""Get API key for agent with access control."""
|
||||
"""Get API key for agent with access control"""
|
||||
if not agent_id:
|
||||
return None, False, None
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
# Lookup without user scoping — access control is done
|
||||
# against ``user_id`` / ``shared_with`` / ``shared`` flags
|
||||
# right below, matching the legacy Mongo semantics.
|
||||
repo = AgentsRepository(conn)
|
||||
agent = None
|
||||
if looks_like_uuid(str(agent_id)):
|
||||
result = conn.execute(
|
||||
sql_text(
|
||||
"SELECT * FROM agents WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": str(agent_id)},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row is not None:
|
||||
agent = row_to_dict(row)
|
||||
if agent is None:
|
||||
agent = repo.get_by_legacy_id(str(agent_id))
|
||||
agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
|
||||
if agent is None:
|
||||
raise Exception("Agent not found")
|
||||
agent_owner = agent.get("user_id")
|
||||
is_owner = agent_owner == user_id
|
||||
is_shared_with_user = bool(agent.get("shared", False))
|
||||
is_owner = agent.get("user") == user_id
|
||||
is_shared_with_user = agent.get(
|
||||
"shared_publicly", False
|
||||
) or user_id in agent.get("shared_with", [])
|
||||
|
||||
if not (is_owner or is_shared_with_user):
|
||||
raise Exception("Unauthorized access to the agent")
|
||||
if is_owner:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
AgentsRepository(conn).update(
|
||||
str(agent["id"]), agent_owner,
|
||||
{"last_used_at": now},
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to update last_used_at for agent",
|
||||
exc_info=True,
|
||||
)
|
||||
return (
|
||||
str(agent["key"]) if agent.get("key") else None,
|
||||
not is_owner,
|
||||
agent.get("shared_token"),
|
||||
)
|
||||
self.agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)},
|
||||
{
|
||||
"$set": {
|
||||
"lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
|
||||
}
|
||||
},
|
||||
)
|
||||
return str(agent["key"]), not is_owner, agent.get("shared_token")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
if not agent:
|
||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||
sources_repo = SourcesRepository(conn)
|
||||
# The repo dict uses "user_id" — the streaming path expects
|
||||
# a "user" key (legacy Mongo shape) for identity propagation.
|
||||
data: Dict[str, Any] = dict(agent)
|
||||
data["user"] = agent.get("user_id")
|
||||
|
||||
# Resolve the primary source row (if any) for retriever/chunks.
|
||||
source_id = agent.get("source_id")
|
||||
if source_id:
|
||||
source_doc = sources_repo.get(str(source_id), agent.get("user_id"))
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["id"])
|
||||
data["retriever"] = source_doc.get(
|
||||
"retriever", data.get("retriever")
|
||||
)
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
data = self.agents_collection.find_one({"key": api_key})
|
||||
if not data:
|
||||
raise Exception("Invalid API Key, please generate a new key", 401)
|
||||
source = data.get("source")
|
||||
if isinstance(source, DBRef):
|
||||
source_doc = self.db.dereference(source)
|
||||
if source_doc:
|
||||
data["source"] = str(source_doc["_id"])
|
||||
data["retriever"] = source_doc.get("retriever", data.get("retriever"))
|
||||
data["chunks"] = source_doc.get("chunks", data.get("chunks"))
|
||||
else:
|
||||
data["source"] = None
|
||||
elif source == "default":
|
||||
data["source"] = "default"
|
||||
else:
|
||||
data["source"] = None
|
||||
|
||||
sources = data.get("sources", [])
|
||||
if sources and isinstance(sources, list):
|
||||
sources_list = []
|
||||
extra = agent.get("extra_source_ids") or []
|
||||
if extra:
|
||||
for sid in extra:
|
||||
source_doc = sources_repo.get(str(sid), agent.get("user_id"))
|
||||
for i, source_ref in enumerate(sources):
|
||||
if source_ref == "default":
|
||||
processed_source = {
|
||||
"id": "default",
|
||||
"retriever": "classic",
|
||||
"chunks": data.get("chunks", "2"),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
elif isinstance(source_ref, DBRef):
|
||||
source_doc = self.db.dereference(source_ref)
|
||||
if source_doc:
|
||||
sources_list.append(
|
||||
{
|
||||
"id": str(source_doc["id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get(
|
||||
"chunks", data.get("chunks", "2")
|
||||
),
|
||||
}
|
||||
)
|
||||
data["sources"] = sources_list
|
||||
processed_source = {
|
||||
"id": str(source_doc["_id"]),
|
||||
"retriever": source_doc.get("retriever", "classic"),
|
||||
"chunks": source_doc.get("chunks", data.get("chunks", "2")),
|
||||
}
|
||||
sources_list.append(processed_source)
|
||||
data["sources"] = sources_list
|
||||
else:
|
||||
data["sources"] = []
|
||||
|
||||
data["default_model_id"] = data.get("default_model_id", "")
|
||||
|
||||
return data
|
||||
|
||||
def _configure_source(self):
|
||||
"""Configure the source based on agent data.
|
||||
"""Configure the source based on agent data"""
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
The literal string ``"default"`` is a placeholder meaning "no
|
||||
ingested source" and is normalized to an empty source so that no
|
||||
retrieval is attempted.
|
||||
"""
|
||||
if self._agent_data:
|
||||
agent_data = self._agent_data
|
||||
if api_key:
|
||||
agent_data = self._get_data_from_api_key(api_key)
|
||||
|
||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||
source_ids = [
|
||||
source["id"]
|
||||
for source in agent_data["sources"]
|
||||
if source.get("id") and source["id"] != "default"
|
||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
||||
]
|
||||
if source_ids:
|
||||
self.source = {"active_docs": source_ids}
|
||||
else:
|
||||
self.source = {}
|
||||
self.all_sources = [
|
||||
s for s in agent_data["sources"] if s.get("id") != "default"
|
||||
]
|
||||
elif agent_data.get("source") and agent_data["source"] != "default":
|
||||
self.all_sources = agent_data["sources"]
|
||||
elif agent_data.get("source"):
|
||||
self.source = {"active_docs": agent_data["source"]}
|
||||
self.all_sources = [
|
||||
{
|
||||
@@ -440,24 +387,11 @@ class StreamProcessor:
|
||||
self.all_sources = []
|
||||
return
|
||||
if "active_docs" in self.data:
|
||||
active_docs = self.data["active_docs"]
|
||||
if active_docs and active_docs != "default":
|
||||
self.source = {"active_docs": active_docs}
|
||||
else:
|
||||
self.source = {}
|
||||
self.source = {"active_docs": self.data["active_docs"]}
|
||||
return
|
||||
self.source = {}
|
||||
self.all_sources = []
|
||||
|
||||
def _has_active_docs(self) -> bool:
|
||||
"""Return True if a real document source is configured for retrieval."""
|
||||
active_docs = self.source.get("active_docs") if self.source else None
|
||||
if not active_docs:
|
||||
return False
|
||||
if active_docs == "default":
|
||||
return False
|
||||
return True
|
||||
|
||||
def _resolve_agent_id(self) -> Optional[str]:
|
||||
"""Resolve agent_id from request, then fall back to conversation context."""
|
||||
request_agent_id = self.data.get("agent_id")
|
||||
@@ -499,45 +433,48 @@ class StreamProcessor:
|
||||
effective_key = self.data.get("api_key") or self.agent_key
|
||||
|
||||
if effective_key:
|
||||
self._agent_data = self._get_data_from_api_key(effective_key)
|
||||
if self._agent_data.get("_id"):
|
||||
self.agent_id = str(self._agent_data.get("_id"))
|
||||
data_key = self._get_data_from_api_key(effective_key)
|
||||
if data_key.get("_id"):
|
||||
self.agent_id = str(data_key.get("_id"))
|
||||
|
||||
self.agent_config.update(
|
||||
{
|
||||
"prompt_id": self._agent_data.get("prompt_id", "default"),
|
||||
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
|
||||
"prompt_id": data_key.get("prompt_id", "default"),
|
||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
||||
"user_api_key": effective_key,
|
||||
"json_schema": self._agent_data.get("json_schema"),
|
||||
"default_model_id": self._agent_data.get("default_model_id", ""),
|
||||
"models": self._agent_data.get("models", []),
|
||||
"allow_system_prompt_override": self._agent_data.get(
|
||||
"allow_system_prompt_override", False
|
||||
),
|
||||
"json_schema": data_key.get("json_schema"),
|
||||
"default_model_id": data_key.get("default_model_id", ""),
|
||||
"models": data_key.get("models", []),
|
||||
}
|
||||
)
|
||||
|
||||
# Set identity context
|
||||
if self.data.get("api_key"):
|
||||
# External API key: use the key owner's identity
|
||||
self.initial_user_id = self._agent_data.get("user")
|
||||
self.decoded_token = {"sub": self._agent_data.get("user")}
|
||||
self.initial_user_id = data_key.get("user")
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
elif self.is_shared_usage:
|
||||
# Shared agent: keep the caller's identity
|
||||
pass
|
||||
else:
|
||||
# Owner using their own agent
|
||||
self.decoded_token = {"sub": self._agent_data.get("user")}
|
||||
self.decoded_token = {"sub": data_key.get("user")}
|
||||
|
||||
# PG row exposes the workflow as ``workflow_id`` (UUID column);
|
||||
# legacy Mongo shape used the key ``workflow``. Accept either so
|
||||
# API-key-invoked workflow agents bind correctly downstream.
|
||||
wf_ref = self._agent_data.get("workflow") or self._agent_data.get(
|
||||
"workflow_id"
|
||||
)
|
||||
if wf_ref:
|
||||
self.agent_config["workflow"] = str(wf_ref)
|
||||
self.agent_config["workflow_owner"] = self._agent_data.get("user")
|
||||
if data_key.get("source"):
|
||||
self.source = {"active_docs": data_key["source"]}
|
||||
if data_key.get("workflow"):
|
||||
self.agent_config["workflow"] = data_key["workflow"]
|
||||
self.agent_config["workflow_owner"] = data_key.get("user")
|
||||
if data_key.get("retriever"):
|
||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
||||
if data_key.get("chunks") is not None:
|
||||
try:
|
||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
||||
)
|
||||
self.retriever_config["chunks"] = 2
|
||||
else:
|
||||
# No API key — default/workflow configuration
|
||||
agent_type = settings.AGENT_NAME
|
||||
@@ -560,45 +497,14 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
def _configure_retriever(self):
|
||||
"""Assemble retriever config with precedence: request > agent > default."""
|
||||
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
|
||||
|
||||
# Start with defaults
|
||||
retriever_name = "classic"
|
||||
chunks = 2
|
||||
|
||||
# Layer agent-level config (if present)
|
||||
if self._agent_data:
|
||||
if self._agent_data.get("retriever"):
|
||||
retriever_name = self._agent_data["retriever"]
|
||||
if self._agent_data.get("chunks") is not None:
|
||||
try:
|
||||
chunks = int(self._agent_data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
# Explicit request values win over agent config
|
||||
if "retriever" in self.data:
|
||||
retriever_name = self.data["retriever"]
|
||||
if "chunks" in self.data:
|
||||
try:
|
||||
chunks = int(self.data["chunks"])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid request chunks value: {self.data['chunks']}, "
|
||||
"using default value 2"
|
||||
)
|
||||
|
||||
self.retriever_config = {
|
||||
"retriever_name": retriever_name,
|
||||
"chunks": chunks,
|
||||
"retriever_name": self.data.get("retriever", "classic"),
|
||||
"chunks": int(self.data.get("chunks", 2)),
|
||||
"doc_token_limit": doc_token_limit,
|
||||
}
|
||||
|
||||
# isNoneDoc without an API key forces no retrieval
|
||||
api_key = self.data.get("api_key") or self.agent_key
|
||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||
self.retriever_config["chunks"] = 0
|
||||
@@ -622,9 +528,6 @@ class StreamProcessor:
|
||||
if self.data.get("isNoneDoc", False) and not self.agent_id:
|
||||
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
||||
return None, None
|
||||
if not self._has_active_docs():
|
||||
logger.info("Pre-fetch skipped: no active docs configured")
|
||||
return None, None
|
||||
try:
|
||||
retriever = self.create_retriever()
|
||||
logger.info(
|
||||
@@ -671,9 +574,12 @@ class StreamProcessor:
|
||||
filtering_enabled = required_tool_actions is not None
|
||||
|
||||
try:
|
||||
user_tools_collection = self.db["user_tools"]
|
||||
user_id = self.initial_user_id or "local"
|
||||
with db_readonly() as conn:
|
||||
user_tools = UserToolsRepository(conn).list_active_for_user(user_id)
|
||||
|
||||
user_tools = list(
|
||||
user_tools_collection.find({"user": user_id, "status": True})
|
||||
)
|
||||
|
||||
if not user_tools:
|
||||
return None
|
||||
@@ -865,121 +771,6 @@ class StreamProcessor:
|
||||
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
||||
return None
|
||||
|
||||
def resume_from_tool_actions(
|
||||
self,
|
||||
tool_actions: list,
|
||||
conversation_id: str,
|
||||
):
|
||||
"""Resume a paused agent from saved continuation state.
|
||||
|
||||
Loads the pending state from MongoDB, recreates the agent with
|
||||
the saved configuration, and returns an agent ready to call
|
||||
``gen_continuation()``.
|
||||
|
||||
Args:
|
||||
tool_actions: Client-provided actions (approvals / results).
|
||||
conversation_id: The conversation being resumed.
|
||||
|
||||
Returns:
|
||||
Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).
|
||||
"""
|
||||
from application.api.answer.services.continuation_service import (
|
||||
ContinuationService,
|
||||
)
|
||||
from application.agents.agent_creator import AgentCreator
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.llm.llm_creator import LLMCreator
|
||||
|
||||
cont_service = ContinuationService()
|
||||
state = cont_service.load_state(conversation_id, self.initial_user_id)
|
||||
if not state:
|
||||
raise ValueError("No pending tool state found for this conversation")
|
||||
|
||||
messages = state["messages"]
|
||||
pending_tool_calls = state["pending_tool_calls"]
|
||||
tools_dict = state["tools_dict"]
|
||||
tool_schemas = state.get("tool_schemas", [])
|
||||
agent_config = state["agent_config"]
|
||||
|
||||
model_id = agent_config.get("model_id")
|
||||
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
|
||||
api_key = agent_config.get("api_key")
|
||||
user_api_key = agent_config.get("user_api_key")
|
||||
agent_id = agent_config.get("agent_id")
|
||||
prompt = agent_config.get("prompt", "")
|
||||
json_schema = agent_config.get("json_schema")
|
||||
retriever_config = agent_config.get("retriever_config")
|
||||
|
||||
# Recreate dependencies
|
||||
system_api_key = api_key or get_api_key_for_provider(llm_name)
|
||||
llm = LLMCreator.create_llm(
|
||||
llm_name,
|
||||
api_key=system_api_key,
|
||||
user_api_key=user_api_key,
|
||||
decoded_token=self.decoded_token,
|
||||
model_id=model_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
|
||||
tool_executor = ToolExecutor(
|
||||
user_api_key=user_api_key,
|
||||
user=self.initial_user_id,
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
tool_executor.conversation_id = conversation_id
|
||||
# Restore client tools so they stay available for subsequent LLM calls
|
||||
saved_client_tools = state.get("client_tools")
|
||||
if saved_client_tools:
|
||||
tool_executor.client_tools = saved_client_tools
|
||||
# Re-merge into tools_dict (they may have been stripped during serialization)
|
||||
tool_executor.merge_client_tools(tools_dict, saved_client_tools)
|
||||
|
||||
agent_type = agent_config.get("agent_type", "ClassicAgent")
|
||||
# Map class names back to agent creator keys
|
||||
type_map = {
|
||||
"ClassicAgent": "classic",
|
||||
"AgenticAgent": "agentic",
|
||||
"ResearchAgent": "research",
|
||||
"WorkflowAgent": "workflow",
|
||||
}
|
||||
agent_key = type_map.get(agent_type, "classic")
|
||||
|
||||
agent_kwargs = {
|
||||
"endpoint": "stream",
|
||||
"llm_name": llm_name,
|
||||
"model_id": model_id,
|
||||
"api_key": system_api_key,
|
||||
"agent_id": agent_id,
|
||||
"user_api_key": user_api_key,
|
||||
"prompt": prompt,
|
||||
"chat_history": [],
|
||||
"decoded_token": self.decoded_token,
|
||||
"json_schema": json_schema,
|
||||
"llm": llm,
|
||||
"llm_handler": llm_handler,
|
||||
"tool_executor": tool_executor,
|
||||
}
|
||||
|
||||
if agent_key in ("agentic", "research") and retriever_config:
|
||||
agent_kwargs["retriever_config"] = retriever_config
|
||||
|
||||
agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
|
||||
agent.conversation_id = conversation_id
|
||||
agent.initial_user_id = self.initial_user_id
|
||||
agent.tools = tool_schemas
|
||||
|
||||
# Store config for the route layer
|
||||
self.model_id = model_id
|
||||
self.agent_id = agent_id
|
||||
self.agent_config["user_api_key"] = user_api_key
|
||||
self.conversation_id = conversation_id
|
||||
|
||||
# Delete state so it can't be replayed
|
||||
cont_service.delete_state(conversation_id, self.initial_user_id)
|
||||
|
||||
return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
docs_together: Optional[str] = None,
|
||||
@@ -1004,23 +795,15 @@ class StreamProcessor:
|
||||
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
|
||||
self._prompt_content = raw_prompt
|
||||
|
||||
# Allow API callers to override the system prompt when the agent
|
||||
# has opted in via allow_system_prompt_override.
|
||||
if (
|
||||
self.agent_config.get("allow_system_prompt_override", False)
|
||||
and self.data.get("system_prompt_override")
|
||||
):
|
||||
rendered_prompt = self.data["system_prompt_override"]
|
||||
else:
|
||||
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||
prompt_content=raw_prompt,
|
||||
user_id=self.initial_user_id,
|
||||
request_id=self.data.get("request_id"),
|
||||
passthrough_data=self.data.get("passthrough"),
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||
prompt_content=raw_prompt,
|
||||
user_id=self.initial_user_id,
|
||||
request_id=self.data.get("request_id"),
|
||||
passthrough_data=self.data.get("passthrough"),
|
||||
docs=docs,
|
||||
docs_together=docs_together,
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
provider = (
|
||||
get_provider_from_model_id(self.model_id)
|
||||
@@ -1034,10 +817,8 @@ class StreamProcessor:
|
||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||
from application.agents.tool_executor import ToolExecutor
|
||||
|
||||
# Compute backup models: agent's configured models minus the active one.
|
||||
# PG agents may carry an explicit ``models: NULL`` (not absent), so
|
||||
# ``.get("models", [])`` isn't enough — coerce None → [].
|
||||
agent_models = self.agent_config.get("models") or []
|
||||
# Compute backup models: agent's configured models minus the active one
|
||||
agent_models = self.agent_config.get("models", [])
|
||||
backup_models = [m for m in agent_models if m != self.model_id]
|
||||
|
||||
llm = LLMCreator.create_llm(
|
||||
@@ -1060,10 +841,6 @@ class StreamProcessor:
|
||||
decoded_token=self.decoded_token,
|
||||
)
|
||||
tool_executor.conversation_id = self.conversation_id
|
||||
# Pass client-side tools so they get merged in get_tools()
|
||||
client_tools = self.data.get("client_tools")
|
||||
if client_tools:
|
||||
tool_executor.client_tools = client_tools
|
||||
|
||||
# Base agent kwargs
|
||||
agent_kwargs = {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import base64
|
||||
import datetime
|
||||
import html
|
||||
import json
|
||||
import uuid
|
||||
from urllib.parse import urlencode
|
||||
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Blueprint,
|
||||
current_app,
|
||||
@@ -15,18 +17,22 @@ from flask import (
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.tasks import (
|
||||
ingest_connector_task,
|
||||
)
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.api import api
|
||||
|
||||
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
sources_collection = db["sources"]
|
||||
sessions_collection = db["connector_sessions"]
|
||||
|
||||
connector = Blueprint("connector", __name__)
|
||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||
api.add_namespace(connectors_ns)
|
||||
@@ -62,14 +68,16 @@ class ConnectorAuth(Resource):
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user_id = decoded_token.get('sub')
|
||||
|
||||
with db_session() as conn:
|
||||
session_row = ConnectorSessionsRepository(conn).upsert(
|
||||
user_id, provider, status="pending",
|
||||
)
|
||||
session_pg_id = str(session_row["id"])
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
result = sessions_collection.insert_one({
|
||||
"provider": provider,
|
||||
"user": user_id,
|
||||
"status": "pending",
|
||||
"created_at": now
|
||||
})
|
||||
state_dict = {
|
||||
"provider": provider,
|
||||
"object_id": session_pg_id,
|
||||
"object_id": str(result.inserted_id)
|
||||
}
|
||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||
|
||||
@@ -152,25 +160,17 @@ class ConnectorsCallback(Resource):
|
||||
|
||||
sanitized_token_info = auth.sanitize_token_info(token_info)
|
||||
|
||||
# ``object_id`` in the OAuth state is the PG session row
|
||||
# UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
|
||||
# issued state). Try UUID update first; fall back to
|
||||
# legacy id path.
|
||||
patch = {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized",
|
||||
}
|
||||
with db_session() as conn:
|
||||
repo = ConnectorSessionsRepository(conn)
|
||||
if state_object_id:
|
||||
value = str(state_object_id)
|
||||
updated = False
|
||||
if len(value) == 36 and "-" in value:
|
||||
updated = repo.update(value, patch)
|
||||
if not updated:
|
||||
repo.update_by_legacy_id(value, patch)
|
||||
sessions_collection.find_one_and_update(
|
||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
||||
{
|
||||
"$set": {
|
||||
"session_token": session_token,
|
||||
"token_info": sanitized_token_info,
|
||||
"user_email": user_email,
|
||||
"status": "authorized"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Redirect to success page with session token and user email
|
||||
return redirect(build_callback_redirect({
|
||||
@@ -222,11 +222,8 @@ class ConnectorFiles(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
if not session or session.get("user_id") != user:
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||
|
||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||
@@ -291,11 +288,8 @@ class ConnectorValidateSession(Resource):
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
user = decoded_token.get('sub')
|
||||
|
||||
with db_readonly() as conn:
|
||||
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
if not session or session.get("user_id") != user or not session.get("token_info"):
|
||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
||||
if not session or "token_info" not in session:
|
||||
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||
|
||||
token_info = session["token_info"]
|
||||
@@ -306,11 +300,10 @@ class ConnectorValidateSession(Resource):
|
||||
try:
|
||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
|
||||
with db_session() as conn:
|
||||
repo = ConnectorSessionsRepository(conn)
|
||||
row = repo.get_by_session_token(session_token)
|
||||
if row:
|
||||
repo.update(str(row["id"]), {"token_info": sanitized_token_info})
|
||||
sessions_collection.update_one(
|
||||
{"session_token": session_token},
|
||||
{"$set": {"token_info": sanitized_token_info}}
|
||||
)
|
||||
token_info = sanitized_token_info
|
||||
is_expired = False
|
||||
except Exception as refresh_error:
|
||||
@@ -354,11 +347,8 @@ class ConnectorDisconnect(Resource):
|
||||
|
||||
|
||||
if session_token:
|
||||
with db_session() as conn:
|
||||
ConnectorSessionsRepository(conn).delete_by_session_token(
|
||||
session_token,
|
||||
)
|
||||
|
||||
sessions_collection.delete_one({"session_token": session_token})
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
|
||||
@@ -395,28 +385,32 @@ class ConnectorSync(Resource):
|
||||
}),
|
||||
400
|
||||
)
|
||||
user_id = decoded_token.get('sub')
|
||||
with db_readonly() as conn:
|
||||
source = SourcesRepository(conn).get_any(source_id, user_id)
|
||||
source = sources_collection.find_one({"_id": ObjectId(source_id)})
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Source not found"
|
||||
}),
|
||||
}),
|
||||
404
|
||||
)
|
||||
|
||||
# ``get_any`` already scopes by ``user_id``; an extra guard
|
||||
# here would be dead code.
|
||||
if source.get('user') != decoded_token.get('sub'):
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"error": "Unauthorized access to source"
|
||||
}),
|
||||
403
|
||||
)
|
||||
|
||||
remote_data = source.get('remote_data') or {}
|
||||
if isinstance(remote_data, str):
|
||||
try:
|
||||
remote_data = json.loads(remote_data)
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
remote_data = {}
|
||||
try:
|
||||
if source.get('remote_data'):
|
||||
remote_data = json.loads(source.get('remote_data'))
|
||||
except json.JSONDecodeError:
|
||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||
remote_data = {}
|
||||
|
||||
source_type = remote_data.get('provider')
|
||||
if not source_type:
|
||||
@@ -444,7 +438,7 @@ class ConnectorSync(Resource):
|
||||
recursive=recursive,
|
||||
retriever=source.get('retriever', 'classic'),
|
||||
operation_mode="sync",
|
||||
doc_id=str(source.get('id') or source_id),
|
||||
doc_id=source_id,
|
||||
sync_frequency=source.get('sync_frequency', 'never')
|
||||
)
|
||||
|
||||
|
||||
@@ -3,16 +3,18 @@ import datetime
|
||||
import json
|
||||
from flask import Blueprint, request, send_from_directory, jsonify
|
||||
from werkzeug.utils import secure_filename
|
||||
from bson.objectid import ObjectId
|
||||
import logging
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -24,20 +26,12 @@ internal = Blueprint("internal", __name__)
|
||||
|
||||
@internal.before_request
|
||||
def verify_internal_key():
|
||||
"""Verify INTERNAL_KEY for all internal endpoint requests.
|
||||
|
||||
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
|
||||
"""
|
||||
if not settings.INTERNAL_KEY:
|
||||
logger.warning(
|
||||
f"Internal API request rejected from {request.remote_addr}: "
|
||||
"INTERNAL_KEY is not configured"
|
||||
)
|
||||
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
|
||||
internal_key = request.headers.get("X-Internal-Key")
|
||||
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
||||
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
||||
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
||||
"""Verify INTERNAL_KEY for all internal endpoint requests."""
|
||||
if settings.INTERNAL_KEY:
|
||||
internal_key = request.headers.get("X-Internal-Key")
|
||||
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
||||
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
||||
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
||||
|
||||
|
||||
@internal.route("/api/download", methods=["get"])
|
||||
@@ -54,21 +48,21 @@ def upload_index_files():
|
||||
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
||||
if "user" not in request.form:
|
||||
return {"status": "no user"}
|
||||
user = request.form["user"]
|
||||
user = request.form["user"]
|
||||
if "name" not in request.form:
|
||||
return {"status": "no name"}
|
||||
job_name = request.form["name"]
|
||||
tokens = request.form["tokens"]
|
||||
retriever = request.form["retriever"]
|
||||
source_id = request.form["id"]
|
||||
id = request.form["id"]
|
||||
type = request.form["type"]
|
||||
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
||||
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
|
||||
|
||||
|
||||
file_path = request.form.get("file_path")
|
||||
directory_structure = request.form.get("directory_structure")
|
||||
file_name_map = request.form.get("file_name_map")
|
||||
|
||||
|
||||
if directory_structure:
|
||||
try:
|
||||
directory_structure = json.loads(directory_structure)
|
||||
@@ -87,8 +81,8 @@ def upload_index_files():
|
||||
file_name_map = None
|
||||
|
||||
storage = StorageCreator.get_storage()
|
||||
index_base_path = f"indexes/{source_id}"
|
||||
|
||||
index_base_path = f"indexes/{id}"
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
if "file_faiss" not in request.files:
|
||||
logger.error("No file_faiss part")
|
||||
@@ -109,48 +103,46 @@ def upload_index_files():
|
||||
storage.save_file(file_faiss, faiss_storage_path)
|
||||
storage.save_file(file_pkl, pkl_storage_path)
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
update_fields = {
|
||||
"name": job_name,
|
||||
"type": type,
|
||||
"language": job_name,
|
||||
"date": now,
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
update_fields["file_name_map"] = file_name_map
|
||||
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
existing = None
|
||||
if looks_like_uuid(source_id):
|
||||
existing = repo.get(source_id, user)
|
||||
if existing is None:
|
||||
existing = repo.get_by_legacy_id(source_id, user)
|
||||
if existing is not None:
|
||||
repo.update(str(existing["id"]), user, update_fields)
|
||||
else:
|
||||
repo.create(
|
||||
job_name,
|
||||
source_id=source_id if looks_like_uuid(source_id) else None,
|
||||
user_id=user,
|
||||
type=type,
|
||||
tokens=tokens,
|
||||
retriever=retriever,
|
||||
remote_data=remote_data,
|
||||
sync_frequency=sync_frequency,
|
||||
file_path=file_path,
|
||||
directory_structure=directory_structure,
|
||||
file_name_map=file_name_map,
|
||||
language=job_name,
|
||||
model=settings.EMBEDDINGS_NAME,
|
||||
date=now,
|
||||
legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
|
||||
)
|
||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
||||
if existing_entry:
|
||||
update_fields = {
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
update_fields["file_name_map"] = file_name_map
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(id)},
|
||||
{"$set": update_fields},
|
||||
)
|
||||
else:
|
||||
insert_doc = {
|
||||
"_id": ObjectId(id),
|
||||
"user": user,
|
||||
"name": job_name,
|
||||
"language": job_name,
|
||||
"date": datetime.datetime.now(),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"type": type,
|
||||
"tokens": tokens,
|
||||
"retriever": retriever,
|
||||
"remote_data": remote_data,
|
||||
"sync_frequency": sync_frequency,
|
||||
"file_path": file_path,
|
||||
"directory_structure": directory_structure,
|
||||
}
|
||||
if file_name_map is not None:
|
||||
insert_doc["file_name_map"] = file_name_map
|
||||
sources_collection.insert_one(insert_doc)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -3,50 +3,27 @@ Agent folders management routes.
|
||||
Provides virtual folder organization for agents (Google Drive-like structure).
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
from application.api.user.base import (
|
||||
agent_folders_collection,
|
||||
agents_collection,
|
||||
)
|
||||
|
||||
agents_folders_ns = Namespace(
|
||||
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
|
||||
"""Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
|
||||
if not folder_id:
|
||||
return None
|
||||
if looks_like_uuid(folder_id):
|
||||
row = repo.get(folder_id, user)
|
||||
if row is not None:
|
||||
return row
|
||||
return repo.get_by_legacy_id(folder_id, user)
|
||||
|
||||
|
||||
def _folder_error_response(message: str, err: Exception):
|
||||
current_app.logger.error(f"{message}: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "message": message}), 400)
|
||||
|
||||
|
||||
def _serialize_folder(f: dict) -> dict:
|
||||
created_at = f.get("created_at")
|
||||
updated_at = f.get("updated_at")
|
||||
return {
|
||||
"id": str(f["id"]),
|
||||
"name": f.get("name"),
|
||||
"parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
|
||||
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
|
||||
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
|
||||
}
|
||||
|
||||
|
||||
@agents_folders_ns.route("/")
|
||||
class AgentFolders(Resource):
|
||||
@api.doc(description="Get all folders for the user")
|
||||
@@ -56,9 +33,17 @@ class AgentFolders(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
folders = AgentFoldersRepository(conn).list_for_user(user)
|
||||
result = [_serialize_folder(f) for f in folders]
|
||||
folders = list(agent_folders_collection.find({"user": user}))
|
||||
result = [
|
||||
{
|
||||
"id": str(f["_id"]),
|
||||
"name": f["name"],
|
||||
"parent_id": f.get("parent_id"),
|
||||
"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
|
||||
"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
|
||||
}
|
||||
for f in folders
|
||||
]
|
||||
return make_response(jsonify({"folders": result}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to fetch folders", err)
|
||||
@@ -82,34 +67,24 @@ class AgentFolders(Resource):
|
||||
if not data or not data.get("name"):
|
||||
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
|
||||
|
||||
parent_id_input = data.get("parent_id")
|
||||
description = data.get("description")
|
||||
parent_id = data.get("parent_id")
|
||||
if parent_id:
|
||||
parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
|
||||
if not parent:
|
||||
return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
pg_parent_id = None
|
||||
if parent_id_input:
|
||||
parent = _resolve_folder_id(repo, parent_id_input, user)
|
||||
if not parent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_parent_id = str(parent["id"])
|
||||
folder = repo.create(
|
||||
user, data["name"],
|
||||
description=description,
|
||||
parent_id=pg_parent_id,
|
||||
)
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
folder = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"parent_id": parent_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
result = agent_folders_collection.insert_one(folder)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"id": str(folder["id"]),
|
||||
"name": folder["name"],
|
||||
"parent_id": pg_parent_id,
|
||||
}
|
||||
),
|
||||
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -125,51 +100,26 @@ class AgentFolder(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
|
||||
agents_rows = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT id, name, description FROM agents "
|
||||
"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
|
||||
"ORDER BY created_at DESC"
|
||||
),
|
||||
{"user_id": user, "fid": pg_folder_id},
|
||||
).fetchall()
|
||||
agents_list = [
|
||||
{
|
||||
"id": str(row._mapping["id"]),
|
||||
"name": row._mapping["name"],
|
||||
"description": row._mapping.get("description", "") or "",
|
||||
}
|
||||
for row in agents_rows
|
||||
]
|
||||
|
||||
subfolders = folders_repo.list_children(pg_folder_id, user)
|
||||
subfolders_list = [
|
||||
{"id": str(sf["id"]), "name": sf["name"]}
|
||||
for sf in subfolders
|
||||
]
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
|
||||
agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
|
||||
agents_list = [
|
||||
{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
|
||||
for a in agents
|
||||
]
|
||||
subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
|
||||
subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"id": pg_folder_id,
|
||||
"name": folder["name"],
|
||||
"parent_id": (
|
||||
str(folder["parent_id"]) if folder.get("parent_id") else None
|
||||
),
|
||||
"agents": agents_list,
|
||||
"subfolders": subfolders_list,
|
||||
}
|
||||
),
|
||||
jsonify({
|
||||
"id": str(folder["_id"]),
|
||||
"name": folder["name"],
|
||||
"parent_id": folder.get("parent_id"),
|
||||
"agents": agents_list,
|
||||
"subfolders": subfolders_list,
|
||||
}),
|
||||
200,
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -186,57 +136,19 @@ class AgentFolder(Resource):
|
||||
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
|
||||
update_fields: dict = {}
|
||||
if "name" in data:
|
||||
update_fields["name"] = data["name"]
|
||||
if "description" in data:
|
||||
update_fields["description"] = data["description"]
|
||||
if "parent_id" in data:
|
||||
parent_input = data.get("parent_id")
|
||||
if parent_input:
|
||||
if parent_input == folder_id or parent_input == pg_folder_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Cannot set folder as its own parent",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
parent = _resolve_folder_id(repo, parent_input, user)
|
||||
if not parent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||
404,
|
||||
)
|
||||
if str(parent["id"]) == pg_folder_id:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Cannot set folder as its own parent",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
update_fields["parent_id"] = str(parent["id"])
|
||||
else:
|
||||
update_fields["parent_id"] = None
|
||||
|
||||
if update_fields:
|
||||
repo.update(pg_folder_id, user, update_fields)
|
||||
update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
|
||||
if "name" in data:
|
||||
update_fields["name"] = data["name"]
|
||||
if "parent_id" in data:
|
||||
if data["parent_id"] == folder_id:
|
||||
return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
|
||||
update_fields["parent_id"] = data["parent_id"]
|
||||
|
||||
result = agent_folders_collection.update_one(
|
||||
{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to update folder", err)
|
||||
@@ -248,24 +160,15 @@ class AgentFolder(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(repo, folder_id, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
# Clear folder assignments from agents; self-FK
|
||||
# ``ON DELETE SET NULL`` handles child folders.
|
||||
AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
|
||||
deleted = repo.delete(pg_folder_id, user)
|
||||
if not deleted:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
agents_collection.update_many(
|
||||
{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
|
||||
)
|
||||
agent_folders_collection.update_many(
|
||||
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
|
||||
)
|
||||
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if result.deleted_count == 0:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to delete folder", err)
|
||||
@@ -292,29 +195,26 @@ class MoveAgentToFolder(Resource):
|
||||
if not data or not data.get("agent_id"):
|
||||
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
|
||||
|
||||
agent_id_input = data["agent_id"]
|
||||
folder_id_input = data.get("folder_id")
|
||||
agent_id = data["agent_id"]
|
||||
folder_id = data.get("folder_id")
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
agents_repo = AgentsRepository(conn)
|
||||
agent = agents_repo.get_any(agent_id_input, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = None
|
||||
if folder_id_input:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||
agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
|
||||
if not agent:
|
||||
return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
|
||||
|
||||
if folder_id:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
|
||||
)
|
||||
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agent", err)
|
||||
@@ -342,25 +242,25 @@ class BulkMoveAgents(Resource):
|
||||
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
|
||||
|
||||
agent_ids = data["agent_ids"]
|
||||
folder_id_input = data.get("folder_id")
|
||||
folder_id = data.get("folder_id")
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
agents_repo = AgentsRepository(conn)
|
||||
pg_folder_id = None
|
||||
if folder_id_input:
|
||||
folders_repo = AgentFoldersRepository(conn)
|
||||
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||
if not folder:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Folder not found"}),
|
||||
404,
|
||||
)
|
||||
pg_folder_id = str(folder["id"])
|
||||
for agent_id_input in agent_ids:
|
||||
agent = agents_repo.get_any(agent_id_input, user)
|
||||
if agent is not None:
|
||||
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||
if folder_id:
|
||||
folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
|
||||
if not folder:
|
||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||
|
||||
object_ids = [ObjectId(aid) for aid in agent_ids]
|
||||
if folder_id:
|
||||
agents_collection.update_many(
|
||||
{"_id": {"$in": object_ids}, "user": user},
|
||||
{"$set": {"folder_id": folder_id}},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_many(
|
||||
{"_id": {"$in": object_ids}, "user": user},
|
||||
{"$unset": {"folder_id": ""}},
|
||||
)
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
return _folder_error_response("Failed to move agents", err)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,17 +3,21 @@
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from bson import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.core.settings import settings
|
||||
from application.api.user.base import resolve_tool_details
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
db,
|
||||
ensure_user_doc,
|
||||
resolve_tool_details,
|
||||
user_tools_collection,
|
||||
users_collection,
|
||||
)
|
||||
from application.utils import generate_image_url
|
||||
|
||||
agents_sharing_ns = Namespace(
|
||||
@@ -21,38 +25,6 @@ agents_sharing_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
def _serialize_agent_basic(agent: dict) -> dict:
|
||||
"""Shape a PG agent row into the API response dict."""
|
||||
source_id = agent.get("source_id")
|
||||
return {
|
||||
"id": str(agent["id"]),
|
||||
"user": agent.get("user_id", ""),
|
||||
"name": agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"description": agent.get("description", ""),
|
||||
"source": str(source_id) if source_id else "",
|
||||
"chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
|
||||
"retriever": agent.get("retriever", "classic") or "classic",
|
||||
"prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
|
||||
"tools": agent.get("tools", []) or [],
|
||||
"tool_details": resolve_tool_details(agent.get("tools", []) or []),
|
||||
"agent_type": agent.get("agent_type", "") or "",
|
||||
"status": agent.get("status", "") or "",
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
"created_at": agent.get("created_at", ""),
|
||||
"updated_at": agent.get("updated_at", ""),
|
||||
"shared": bool(agent.get("shared", False)),
|
||||
"shared_token": agent.get("shared_token", "") or "",
|
||||
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||
}
|
||||
|
||||
|
||||
@agents_sharing_ns.route("/shared_agent")
|
||||
class SharedAgent(Resource):
|
||||
@api.doc(
|
||||
@@ -69,33 +41,70 @@ class SharedAgent(Resource):
|
||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
shared_agent = AgentsRepository(conn).find_by_shared_token(
|
||||
shared_token,
|
||||
)
|
||||
query = {
|
||||
"shared_publicly": True,
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
shared_agent = agents_collection.find_one(query)
|
||||
if not shared_agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||
404,
|
||||
)
|
||||
agent_id = str(shared_agent["id"])
|
||||
data = _serialize_agent_basic(shared_agent)
|
||||
agent_id = str(shared_agent["_id"])
|
||||
data = {
|
||||
"id": agent_id,
|
||||
"user": shared_agent.get("user", ""),
|
||||
"name": shared_agent.get("name", ""),
|
||||
"image": (
|
||||
generate_image_url(shared_agent["image"])
|
||||
if shared_agent.get("image")
|
||||
else ""
|
||||
),
|
||||
"description": shared_agent.get("description", ""),
|
||||
"source": (
|
||||
str(source_doc["_id"])
|
||||
if isinstance(shared_agent.get("source"), DBRef)
|
||||
and (source_doc := db.dereference(shared_agent.get("source")))
|
||||
else ""
|
||||
),
|
||||
"chunks": shared_agent.get("chunks", "0"),
|
||||
"retriever": shared_agent.get("retriever", "classic"),
|
||||
"prompt_id": shared_agent.get("prompt_id", "default"),
|
||||
"tools": shared_agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
||||
"agent_type": shared_agent.get("agent_type", ""),
|
||||
"status": shared_agent.get("status", ""),
|
||||
"json_schema": shared_agent.get("json_schema"),
|
||||
"limited_token_mode": shared_agent.get("limited_token_mode", False),
|
||||
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": shared_agent.get("limited_request_mode", False),
|
||||
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": shared_agent.get("createdAt", ""),
|
||||
"updated_at": shared_agent.get("updatedAt", ""),
|
||||
"shared": shared_agent.get("shared_publicly", False),
|
||||
"shared_token": shared_agent.get("shared_token", ""),
|
||||
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
||||
}
|
||||
|
||||
if data["tools"]:
|
||||
enriched_tools = []
|
||||
for detail in data["tool_details"]:
|
||||
enriched_tools.append(detail.get("name", ""))
|
||||
for tool in data["tools"]:
|
||||
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
||||
if tool_data:
|
||||
enriched_tools.append(tool_data.get("name", ""))
|
||||
data["tools"] = enriched_tools
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
if decoded_token:
|
||||
user_id = decoded_token.get("sub")
|
||||
owner_id = shared_agent.get("user_id")
|
||||
owner_id = shared_agent.get("user")
|
||||
|
||||
if user_id != owner_id:
|
||||
with db_session() as conn:
|
||||
users_repo = UsersRepository(conn)
|
||||
users_repo.upsert(user_id)
|
||||
users_repo.add_shared(user_id, agent_id)
|
||||
ensure_user_doc(user_id)
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||
)
|
||||
return make_response(jsonify(data), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||
@@ -112,73 +121,52 @@ class SharedAgents(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
with db_session() as conn:
|
||||
users_repo = UsersRepository(conn)
|
||||
user_doc = users_repo.upsert(user_id)
|
||||
shared_with_ids = (
|
||||
user_doc.get("agent_preferences", {}).get("shared_with_me", [])
|
||||
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||
else []
|
||||
user_doc = ensure_user_doc(user_id)
|
||||
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
||||
"shared_with_me", []
|
||||
)
|
||||
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
||||
|
||||
shared_agents_cursor = agents_collection.find(
|
||||
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
||||
)
|
||||
shared_agents = list(shared_agents_cursor)
|
||||
|
||||
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
||||
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
||||
if stale_ids:
|
||||
users_collection.update_one(
|
||||
{"user_id": user_id},
|
||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||
)
|
||||
# Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
|
||||
uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
|
||||
non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
|
||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||
|
||||
if uuid_ids:
|
||||
result = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM agents "
|
||||
"WHERE id = ANY(CAST(:ids AS uuid[])) "
|
||||
"AND shared = true"
|
||||
),
|
||||
{"ids": uuid_ids},
|
||||
)
|
||||
shared_agents = [dict(row._mapping) for row in result.fetchall()]
|
||||
else:
|
||||
shared_agents = []
|
||||
|
||||
found_ids_set = {str(agent["id"]) for agent in shared_agents}
|
||||
stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
|
||||
stale_ids.extend(non_uuid_ids)
|
||||
if stale_ids:
|
||||
users_repo.remove_shared_bulk(user_id, stale_ids)
|
||||
|
||||
pinned_ids = set(
|
||||
user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||
else []
|
||||
)
|
||||
|
||||
list_shared_agents = []
|
||||
for agent in shared_agents:
|
||||
agent_id_str = str(agent["id"])
|
||||
list_shared_agents.append(
|
||||
{
|
||||
"id": agent_id_str,
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []) or [],
|
||||
"tool_details": resolve_tool_details(
|
||||
agent.get("tools", []) or []
|
||||
),
|
||||
"agent_type": agent.get("agent_type", "") or "",
|
||||
"status": agent.get("status", "") or "",
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||
"created_at": agent.get("created_at", ""),
|
||||
"updated_at": agent.get("updated_at", ""),
|
||||
"pinned": agent_id_str in pinned_ids,
|
||||
"shared": bool(agent.get("shared", False)),
|
||||
"shared_token": agent.get("shared_token", "") or "",
|
||||
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||
}
|
||||
)
|
||||
list_shared_agents = [
|
||||
{
|
||||
"id": str(agent["_id"]),
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
"image": (
|
||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||
),
|
||||
"tools": agent.get("tools", []),
|
||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
||||
"agent_type": agent.get("agent_type", ""),
|
||||
"status": agent.get("status", ""),
|
||||
"json_schema": agent.get("json_schema"),
|
||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
||||
"created_at": agent.get("createdAt", ""),
|
||||
"updated_at": agent.get("updatedAt", ""),
|
||||
"pinned": str(agent["_id"]) in pinned_ids,
|
||||
"shared": agent.get("shared_publicly", False),
|
||||
"shared_token": agent.get("shared_token", ""),
|
||||
"shared_metadata": agent.get("shared_metadata", {}),
|
||||
}
|
||||
for agent in shared_agents
|
||||
]
|
||||
|
||||
return make_response(jsonify(list_shared_agents), 200)
|
||||
except Exception as err:
|
||||
@@ -232,43 +220,44 @@ class ShareAgent(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
shared_token = None
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = AgentsRepository(conn)
|
||||
agent = repo.get_any(agent_id, user)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat(),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
repo.update(
|
||||
str(agent["id"]), user,
|
||||
{
|
||||
"shared": True,
|
||||
"shared_token": shared_token,
|
||||
try:
|
||||
agent_oid = ObjectId(agent_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
||||
)
|
||||
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
if shared:
|
||||
shared_metadata = {
|
||||
"shared_by": username,
|
||||
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
||||
}
|
||||
shared_token = secrets.token_urlsafe(32)
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{
|
||||
"$set": {
|
||||
"shared_publicly": shared,
|
||||
"shared_metadata": shared_metadata,
|
||||
},
|
||||
)
|
||||
else:
|
||||
repo.update(
|
||||
str(agent["id"]), user,
|
||||
{
|
||||
"shared": False,
|
||||
"shared_token": None,
|
||||
"shared_metadata": None,
|
||||
},
|
||||
)
|
||||
"shared_token": shared_token,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
agents_collection.update_one(
|
||||
{"_id": agent_oid, "user": user},
|
||||
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
||||
{"$unset": {"shared_metadata": ""}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
|
||||
shared_token = shared_token if shared else None
|
||||
return make_response(
|
||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||
)
|
||||
|
||||
@@ -2,15 +2,14 @@
|
||||
|
||||
import secrets
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import require_agent
|
||||
from application.api.user.base import agents_collection, require_agent
|
||||
from application.api.user.tasks import process_agent_webhook
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
|
||||
|
||||
agents_webhooks_ns = Namespace(
|
||||
@@ -35,8 +34,9 @@ class AgentWebhook(Resource):
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).get_any(agent_id, user)
|
||||
agent = agents_collection.find_one(
|
||||
{"_id": ObjectId(agent_id), "user": user}
|
||||
)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
@@ -44,11 +44,10 @@ class AgentWebhook(Resource):
|
||||
webhook_token = agent.get("incoming_webhook_token")
|
||||
if not webhook_token:
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
with db_session() as conn:
|
||||
AgentsRepository(conn).update(
|
||||
str(agent["id"]), user,
|
||||
{"incoming_webhook_token": webhook_token},
|
||||
)
|
||||
agents_collection.update_one(
|
||||
{"_id": ObjectId(agent_id), "user": user},
|
||||
{"$set": {"incoming_webhook_token": webhook_token}},
|
||||
)
|
||||
base_url = settings.API_URL.rstrip("/")
|
||||
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
|
||||
except Exception as err:
|
||||
|
||||
@@ -2,84 +2,26 @@
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
conversations_collection,
|
||||
generate_date_range,
|
||||
generate_hourly_range,
|
||||
generate_minute_range,
|
||||
token_usage_collection,
|
||||
user_logs_collection,
|
||||
)
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
|
||||
analytics_ns = Namespace(
|
||||
"analytics", description="Analytics and reporting operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
_FILTER_BUCKETS = {
|
||||
"last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
|
||||
"last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
|
||||
"last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
|
||||
"last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
|
||||
"last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
|
||||
}
|
||||
|
||||
|
||||
def _range_for_filter(filter_option: str):
|
||||
"""Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.
|
||||
|
||||
Returns ``None`` on invalid filter.
|
||||
"""
|
||||
if filter_option not in _FILTER_BUCKETS:
|
||||
return None
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]
|
||||
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
else:
|
||||
days = {
|
||||
"last_7_days": 6,
|
||||
"last_15_days": 14,
|
||||
"last_30_days": 29,
|
||||
}[filter_option]
|
||||
start_date = end_date - datetime.timedelta(days=days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
return start_date, end_date, bucket_unit, pg_fmt
|
||||
|
||||
|
||||
def _intervals_for_filter(filter_option, start_date, end_date):
|
||||
if filter_option == "last_hour":
|
||||
return generate_minute_range(start_date, end_date)
|
||||
if filter_option == "last_24_hour":
|
||||
return generate_hourly_range(start_date, end_date)
|
||||
return generate_date_range(start_date, end_date)
|
||||
|
||||
|
||||
def _resolve_api_key(conn, api_key_id, user_id):
|
||||
"""Look up the ``agents.key`` value for a given agent id.
|
||||
|
||||
Scoped by ``user_id`` so an authenticated caller can't probe another
|
||||
user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
|
||||
"""
|
||||
if not api_key_id:
|
||||
return None
|
||||
agent = AgentsRepository(conn).get_any(api_key_id, user_id)
|
||||
return (agent or {}).get("key") if agent else None
|
||||
|
||||
|
||||
@analytics_ns.route("/get_message_analytics")
|
||||
class GetMessageAnalytics(Resource):
|
||||
get_message_analytics_model = api.model(
|
||||
@@ -90,7 +32,13 @@ class GetMessageAnalytics(Resource):
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -102,54 +50,88 @@ class GetMessageAnalytics(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, _bucket_unit, pg_fmt = window
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
|
||||
# Count messages per bucket, filtered by the conversation's
|
||||
# owner (user_id) and optionally the agent api_key. The
|
||||
# ``user_id`` filter is always applied post-cutover to
|
||||
# prevent cross-tenant leakage on admin dashboards.
|
||||
clauses = [
|
||||
"c.user_id = :user_id",
|
||||
"m.timestamp >= :start",
|
||||
"m.timestamp <= :end",
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
params: dict = {
|
||||
"user_id": user,
|
||||
"start": start_date,
|
||||
"end": end_date,
|
||||
"fmt": pg_fmt,
|
||||
}
|
||||
if api_key:
|
||||
clauses.append("c.api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
sql = (
|
||||
"SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
|
||||
"COUNT(*) AS count "
|
||||
"FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
f"WHERE {where} "
|
||||
"GROUP BY bucket ORDER BY bucket ASC"
|
||||
)
|
||||
rows = conn.execute(_sql_text(sql), params).fetchall()
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else 14 if filter_option == "last_15_days" else 29
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user": user,
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{
|
||||
"$match": {
|
||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.timestamp",
|
||||
}
|
||||
},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
message_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_messages = {interval: 0 for interval in intervals}
|
||||
for row in rows:
|
||||
daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
|
||||
|
||||
for entry in message_data:
|
||||
daily_messages[entry["_id"]] = entry["count"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting message analytics: {err}", exc_info=True
|
||||
@@ -170,7 +152,13 @@ class GetTokenAnalytics(Resource):
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -182,36 +170,123 @@ class GetTokenAnalytics(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, bucket_unit, _pg_fmt = window
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
# ``bucketed_totals`` applies user_id / api_key filters
|
||||
# directly — no need to reshape a Mongo pipeline.
|
||||
rows = TokenUsageRepository(conn).bucketed_totals(
|
||||
bucket_unit=bucket_unit,
|
||||
user_id=user,
|
||||
api_key=api_key,
|
||||
timestamp_gte=start_date,
|
||||
timestamp_lt=end_date,
|
||||
)
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
for entry in rows:
|
||||
daily_token_usage[entry["bucket"]] = int(
|
||||
entry["prompt_tokens"] + entry["generated_tokens"]
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"minute": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"hour": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
group_stage = {
|
||||
"$group": {
|
||||
"_id": {
|
||||
"day": {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$timestamp",
|
||||
}
|
||||
}
|
||||
},
|
||||
"total_tokens": {
|
||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
||||
},
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"user_id": user,
|
||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
token_usage_data = token_usage_collection.aggregate(
|
||||
[
|
||||
match_stage,
|
||||
group_stage,
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_token_usage = {interval: 0 for interval in intervals}
|
||||
|
||||
for entry in token_usage_data:
|
||||
if filter_option == "last_hour":
|
||||
daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
|
||||
elif filter_option == "last_24_hour":
|
||||
daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
|
||||
else:
|
||||
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting token analytics: {err}", exc_info=True
|
||||
@@ -232,7 +307,13 @@ class GetFeedbackAnalytics(Resource):
|
||||
required=False,
|
||||
description="Filter option for analytics",
|
||||
default="last_30_days",
|
||||
enum=list(_FILTER_BUCKETS.keys()),
|
||||
enum=[
|
||||
"last_hour",
|
||||
"last_24_hour",
|
||||
"last_7_days",
|
||||
"last_15_days",
|
||||
"last_30_days",
|
||||
],
|
||||
),
|
||||
},
|
||||
)
|
||||
@@ -244,64 +325,128 @@ class GetFeedbackAnalytics(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
data = request.get_json()
|
||||
api_key_id = data.get("api_key_id")
|
||||
filter_option = data.get("filter_option", "last_30_days")
|
||||
|
||||
window = _range_for_filter(filter_option)
|
||||
if window is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date, end_date, _bucket_unit, pg_fmt = window
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
|
||||
# Feedback lives inside the ``conversation_messages.feedback``
|
||||
# JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
|
||||
# There is no scalar ``feedback_timestamp`` column — extract
|
||||
# the timestamp from the JSONB and cast it to timestamptz for
|
||||
# the range filter + bucket grouping.
|
||||
clauses = [
|
||||
"c.user_id = :user_id",
|
||||
"m.feedback IS NOT NULL",
|
||||
"(m.feedback->>'timestamp')::timestamptz >= :start",
|
||||
"(m.feedback->>'timestamp')::timestamptz <= :end",
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
||||
"key"
|
||||
]
|
||||
params: dict = {
|
||||
"user_id": user,
|
||||
"start": start_date,
|
||||
"end": end_date,
|
||||
"fmt": pg_fmt,
|
||||
}
|
||||
if api_key:
|
||||
clauses.append("c.api_key = :api_key")
|
||||
params["api_key"] = api_key
|
||||
where = " AND ".join(clauses)
|
||||
sql = (
|
||||
"SELECT to_char("
|
||||
"(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
|
||||
") AS bucket, "
|
||||
"SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
|
||||
"SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
|
||||
"FROM conversation_messages m "
|
||||
"JOIN conversations c ON c.id = m.conversation_id "
|
||||
f"WHERE {where} "
|
||||
"GROUP BY bucket ORDER BY bucket ASC"
|
||||
)
|
||||
rows = conn.execute(_sql_text(sql), params).fetchall()
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||
if filter_option == "last_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=1)
|
||||
group_format = "%Y-%m-%d %H:%M:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
elif filter_option == "last_24_hour":
|
||||
start_date = end_date - datetime.timedelta(hours=24)
|
||||
group_format = "%Y-%m-%d %H:00"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
else:
|
||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
||||
filter_days = (
|
||||
6
|
||||
if filter_option == "last_7_days"
|
||||
else (14 if filter_option == "last_15_days" else 29)
|
||||
)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||
)
|
||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_date = end_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999
|
||||
)
|
||||
group_format = "%Y-%m-%d"
|
||||
date_field = {
|
||||
"$dateToString": {
|
||||
"format": group_format,
|
||||
"date": "$queries.feedback_timestamp",
|
||||
}
|
||||
}
|
||||
try:
|
||||
match_stage = {
|
||||
"$match": {
|
||||
"queries.feedback_timestamp": {
|
||||
"$gte": start_date,
|
||||
"$lte": end_date,
|
||||
},
|
||||
"queries.feedback": {"$exists": True},
|
||||
}
|
||||
}
|
||||
if api_key:
|
||||
match_stage["$match"]["api_key"] = api_key
|
||||
pipeline = [
|
||||
match_stage,
|
||||
{"$unwind": "$queries"},
|
||||
{"$match": {"queries.feedback": {"$exists": True}}},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {"time": date_field, "feedback": "$queries.feedback"},
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$_id.time",
|
||||
"positive": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "LIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
"negative": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$eq": ["$_id.feedback", "DISLIKE"]},
|
||||
"$count",
|
||||
0,
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": {"_id": 1}},
|
||||
]
|
||||
|
||||
feedback_data = conversations_collection.aggregate(pipeline)
|
||||
|
||||
if filter_option == "last_hour":
|
||||
intervals = generate_minute_range(start_date, end_date)
|
||||
elif filter_option == "last_24_hour":
|
||||
intervals = generate_hourly_range(start_date, end_date)
|
||||
else:
|
||||
intervals = generate_date_range(start_date, end_date)
|
||||
daily_feedback = {
|
||||
interval: {"positive": 0, "negative": 0} for interval in intervals
|
||||
}
|
||||
for row in rows:
|
||||
bucket = row._mapping["bucket"]
|
||||
daily_feedback[bucket] = {
|
||||
"positive": int(row._mapping["positive"] or 0),
|
||||
"negative": int(row._mapping["negative"] or 0),
|
||||
|
||||
for entry in feedback_data:
|
||||
daily_feedback[entry["_id"]] = {
|
||||
"positive": entry["positive"],
|
||||
"negative": entry["negative"],
|
||||
}
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
@@ -339,89 +484,47 @@ class GetUserLogs(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
data = request.get_json() or {}
|
||||
data = request.get_json()
|
||||
page = int(data.get("page", 1))
|
||||
api_key_id = data.get("api_key_id")
|
||||
page_size = int(data.get("page_size", 10))
|
||||
skip = (page - 1) * page_size
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||
logs_repo = UserLogsRepository(conn)
|
||||
if api_key:
|
||||
# ``find_by_api_key`` filters on ``data->>'api_key'``
|
||||
# — the PG shape of the legacy top-level ``api_key``
|
||||
# filter. Paginate client-side using offset/limit.
|
||||
all_rows = logs_repo.find_by_api_key(api_key)
|
||||
offset = (page - 1) * page_size
|
||||
window = all_rows[offset: offset + page_size + 1]
|
||||
items = window
|
||||
else:
|
||||
items, has_more_flag = logs_repo.list_paginated(
|
||||
user_id=user,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
# list_paginated already trims to page_size and
|
||||
# returns has_more separately.
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("id") or item.get("_id")),
|
||||
"action": (item.get("data") or {}).get("action"),
|
||||
"level": (item.get("data") or {}).get("level"),
|
||||
"user": item.get("user_id"),
|
||||
"question": (item.get("data") or {}).get("question"),
|
||||
"sources": (item.get("data") or {}).get("sources"),
|
||||
"retriever_params": (item.get("data") or {}).get(
|
||||
"retriever_params"
|
||||
),
|
||||
"timestamp": (
|
||||
item["timestamp"].isoformat()
|
||||
if hasattr(item.get("timestamp"), "isoformat")
|
||||
else item.get("timestamp")
|
||||
),
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"logs": results,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": has_more_flag,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
|
||||
has_more = len(items) > page_size
|
||||
items = items[:page_size]
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("id") or item.get("_id")),
|
||||
"action": (item.get("data") or {}).get("action"),
|
||||
"level": (item.get("data") or {}).get("level"),
|
||||
"user": item.get("user_id"),
|
||||
"question": (item.get("data") or {}).get("question"),
|
||||
"sources": (item.get("data") or {}).get("sources"),
|
||||
"retriever_params": (item.get("data") or {}).get(
|
||||
"retriever_params"
|
||||
),
|
||||
"timestamp": (
|
||||
item["timestamp"].isoformat()
|
||||
if hasattr(item.get("timestamp"), "isoformat")
|
||||
else item.get("timestamp")
|
||||
),
|
||||
}
|
||||
for item in items
|
||||
]
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting user logs: {err}", exc_info=True
|
||||
api_key = (
|
||||
agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
|
||||
if api_key_id
|
||||
else None
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
query = {"user": user}
|
||||
if api_key:
|
||||
query = {"api_key": api_key}
|
||||
items_cursor = (
|
||||
user_logs_collection.find(query)
|
||||
.sort("timestamp", -1)
|
||||
.skip(skip)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
items = list(items_cursor)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": str(item.get("_id")),
|
||||
"action": item.get("action"),
|
||||
"level": item.get("level"),
|
||||
"user": item.get("user"),
|
||||
"question": item.get("question"),
|
||||
"sources": item.get("sources"),
|
||||
"retriever_params": item.get("retriever_params"),
|
||||
"timestamp": item.get("timestamp"),
|
||||
}
|
||||
for item in items[:page_size]
|
||||
]
|
||||
|
||||
has_more = len(items) > page_size
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -4,16 +4,13 @@ import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import uuid
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.stt.constants import (
|
||||
SUPPORTED_AUDIO_EXTENSIONS,
|
||||
SUPPORTED_AUDIO_MIME_TYPES,
|
||||
@@ -51,13 +48,14 @@ def _resolve_authenticated_user():
|
||||
return safe_filename(decoded_token.get("sub"))
|
||||
|
||||
if api_key:
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||
from application.api.user.base import agents_collection
|
||||
|
||||
agent = agents_collection.find_one({"key": api_key})
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid API key"}), 401
|
||||
)
|
||||
return safe_filename(agent.get("user_id"))
|
||||
return safe_filename(agent.get("user"))
|
||||
|
||||
return None
|
||||
|
||||
@@ -159,7 +157,7 @@ class StoreAttachment(Resource):
|
||||
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
attachment_id = uuid.uuid4()
|
||||
attachment_id = ObjectId()
|
||||
original_filename = safe_filename(os.path.basename(file.filename))
|
||||
_enforce_uploaded_audio_size_limit(file, original_filename)
|
||||
relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
|
||||
@@ -614,10 +612,6 @@ class LiveSpeechToTextFinish(Resource):
|
||||
class ServeImage(Resource):
|
||||
@api.doc(description="Serve an image from storage")
|
||||
def get(self, image_path):
|
||||
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||
)
|
||||
try:
|
||||
from application.api.user.base import storage
|
||||
|
||||
@@ -635,10 +629,6 @@ class ServeImage(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Image not found"}), 404
|
||||
)
|
||||
except ValueError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error serving image: {e}")
|
||||
return make_response(
|
||||
|
||||
@@ -8,15 +8,13 @@ import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, Response
|
||||
from pymongo import ReturnDocument
|
||||
from werkzeug.utils import secure_filename
|
||||
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||
from application.storage.db.repositories.users import UsersRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
|
||||
@@ -24,6 +22,56 @@ from application.vectorstore.vector_creator import VectorCreator
|
||||
storage = StorageCreator.get_storage()
|
||||
|
||||
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
|
||||
|
||||
conversations_collection = db["conversations"]
|
||||
sources_collection = db["sources"]
|
||||
prompts_collection = db["prompts"]
|
||||
feedback_collection = db["feedback"]
|
||||
agents_collection = db["agents"]
|
||||
agent_folders_collection = db["agent_folders"]
|
||||
token_usage_collection = db["token_usage"]
|
||||
shared_conversations_collections = db["shared_conversations"]
|
||||
users_collection = db["users"]
|
||||
user_logs_collection = db["user_logs"]
|
||||
user_tools_collection = db["user_tools"]
|
||||
attachments_collection = db["attachments"]
|
||||
workflow_runs_collection = db["workflow_runs"]
|
||||
workflows_collection = db["workflows"]
|
||||
workflow_nodes_collection = db["workflow_nodes"]
|
||||
workflow_edges_collection = db["workflow_edges"]
|
||||
|
||||
|
||||
try:
|
||||
agents_collection.create_index(
|
||||
[("shared", 1)],
|
||||
name="shared_index",
|
||||
background=True,
|
||||
)
|
||||
users_collection.create_index("user_id", unique=True)
|
||||
workflows_collection.create_index(
|
||||
[("user", 1)], name="workflow_user_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1)], name="node_workflow_index", background=True
|
||||
)
|
||||
workflow_nodes_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="node_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1)], name="edge_workflow_index", background=True
|
||||
)
|
||||
workflow_edges_collection.create_index(
|
||||
[("workflow_id", 1), ("graph_version", 1)],
|
||||
name="edge_workflow_graph_version_index",
|
||||
background=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error creating indexes:", e)
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
@@ -55,95 +103,66 @@ def generate_date_range(start_date, end_date):
|
||||
|
||||
def ensure_user_doc(user_id):
|
||||
"""
|
||||
Ensure a Postgres ``users`` row exists for ``user_id``.
|
||||
|
||||
Returns the row as a dict with the shape legacy callers expect — in
|
||||
particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
|
||||
``shared_with_me`` list keys always present).
|
||||
Ensure user document exists with proper agent preferences structure.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to ensure
|
||||
|
||||
Returns:
|
||||
The user document as a dict.
|
||||
The user document
|
||||
"""
|
||||
with db_session() as conn:
|
||||
user_doc = UsersRepository(conn).upsert(user_id)
|
||||
default_prefs = {
|
||||
"pinned": [],
|
||||
"shared_with_me": [],
|
||||
}
|
||||
|
||||
prefs = user_doc.get("agent_preferences") or {}
|
||||
if not isinstance(prefs, dict):
|
||||
prefs = {}
|
||||
prefs.setdefault("pinned", [])
|
||||
prefs.setdefault("shared_with_me", [])
|
||||
user_doc["agent_preferences"] = prefs
|
||||
user_doc = users_collection.find_one_and_update(
|
||||
{"user_id": user_id},
|
||||
{"$setOnInsert": {"agent_preferences": default_prefs}},
|
||||
upsert=True,
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
|
||||
prefs = user_doc.get("agent_preferences", {})
|
||||
updates = {}
|
||||
if "pinned" not in prefs:
|
||||
updates["agent_preferences.pinned"] = []
|
||||
if "shared_with_me" not in prefs:
|
||||
updates["agent_preferences.shared_with_me"] = []
|
||||
if updates:
|
||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||
user_doc = users_collection.find_one({"user_id": user_id})
|
||||
return user_doc
|
||||
|
||||
|
||||
def resolve_tool_details(tool_ids):
|
||||
"""
|
||||
Resolve tool IDs to their display details.
|
||||
|
||||
Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
|
||||
lists are supported — each id is looked up via ``get_any``, which
|
||||
resolves to whichever column matches). Unknown ids are silently
|
||||
skipped.
|
||||
Resolve tool IDs to their details.
|
||||
|
||||
Args:
|
||||
tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).
|
||||
tool_ids: List of tool IDs
|
||||
|
||||
Returns:
|
||||
List of tool details with ``id``, ``name``, and ``display_name``.
|
||||
List of tool details with id, name, and display_name
|
||||
"""
|
||||
if not tool_ids:
|
||||
return []
|
||||
|
||||
uuid_ids: list[str] = []
|
||||
legacy_ids: list[str] = []
|
||||
valid_ids = []
|
||||
for tid in tool_ids:
|
||||
if not tid:
|
||||
try:
|
||||
valid_ids.append(ObjectId(tid))
|
||||
except Exception:
|
||||
continue
|
||||
tid_str = str(tid)
|
||||
if looks_like_uuid(tid_str):
|
||||
uuid_ids.append(tid_str)
|
||||
else:
|
||||
legacy_ids.append(tid_str)
|
||||
|
||||
if not uuid_ids and not legacy_ids:
|
||||
return []
|
||||
|
||||
rows: list[dict] = []
|
||||
with db_readonly() as conn:
|
||||
if uuid_ids:
|
||||
result = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM user_tools "
|
||||
"WHERE id = ANY(CAST(:ids AS uuid[]))"
|
||||
),
|
||||
{"ids": uuid_ids},
|
||||
)
|
||||
rows.extend(row_to_dict(r) for r in result.fetchall())
|
||||
if legacy_ids:
|
||||
result = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM user_tools "
|
||||
"WHERE legacy_mongo_id = ANY(:ids)"
|
||||
),
|
||||
{"ids": legacy_ids},
|
||||
)
|
||||
rows.extend(row_to_dict(r) for r in result.fetchall())
|
||||
|
||||
tools = user_tools_collection.find(
|
||||
{"_id": {"$in": valid_ids}}
|
||||
) if valid_ids else []
|
||||
return [
|
||||
{
|
||||
"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
|
||||
"name": tool.get("name", "") or "",
|
||||
"display_name": (
|
||||
tool.get("custom_name")
|
||||
or tool.get("display_name")
|
||||
or tool.get("name", "")
|
||||
or ""
|
||||
),
|
||||
"id": str(tool["_id"]),
|
||||
"name": tool.get("name", ""),
|
||||
"display_name": tool.get("customName")
|
||||
or tool.get("displayName")
|
||||
or tool.get("name", ""),
|
||||
}
|
||||
for tool in rows
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
|
||||
@@ -213,15 +232,14 @@ def require_agent(func):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
|
||||
webhook_token = kwargs.get("webhook_token")
|
||||
if not webhook_token:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Webhook token missing"}), 400
|
||||
)
|
||||
with db_readonly() as conn:
|
||||
agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
|
||||
agent = agents_collection.find_one(
|
||||
{"incoming_webhook_token": webhook_token}, {"_id": 1}
|
||||
)
|
||||
if not agent:
|
||||
current_app.logger.warning(
|
||||
f"Webhook attempt with invalid token: {webhook_token}"
|
||||
@@ -230,7 +248,7 @@ def require_agent(func):
|
||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||
)
|
||||
kwargs["agent"] = agent
|
||||
kwargs["agent_id_str"] = str(agent["id"])
|
||||
kwargs["agent_id_str"] = str(agent["_id"])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
|
||||
import datetime
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.api.user.base import attachments_collection, conversations_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
conversations_ns = Namespace(
|
||||
@@ -31,13 +30,10 @@ class DeleteConversation(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
user_id = decoded_token["sub"]
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(conversation_id, user_id)
|
||||
if conv is not None:
|
||||
repo.delete(str(conv["id"]), user_id)
|
||||
conversations_collection.delete_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting conversation: {err}", exc_info=True
|
||||
@@ -57,8 +53,7 @@ class DeleteAllConversations(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
with db_session() as conn:
|
||||
ConversationsRepository(conn).delete_all_for_user(user_id)
|
||||
conversations_collection.delete_many({"user": user_id})
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting all conversations: {err}", exc_info=True
|
||||
@@ -76,21 +71,26 @@ class GetConversations(Resource):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
conversations = ConversationsRepository(conn).list_for_user(
|
||||
user_id, limit=30
|
||||
conversations = (
|
||||
conversations_collection.find(
|
||||
{
|
||||
"$or": [
|
||||
{"api_key": {"$exists": False}},
|
||||
{"agent_id": {"$exists": True}},
|
||||
],
|
||||
"user": decoded_token.get("sub"),
|
||||
}
|
||||
)
|
||||
.sort("date", -1)
|
||||
.limit(30)
|
||||
)
|
||||
|
||||
list_conversations = [
|
||||
{
|
||||
"id": str(conversation["id"]),
|
||||
"id": str(conversation["_id"]),
|
||||
"name": conversation["name"],
|
||||
"agent_id": (
|
||||
str(conversation["agent_id"])
|
||||
if conversation.get("agent_id")
|
||||
else None
|
||||
),
|
||||
"agent_id": conversation.get("agent_id", None),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
@@ -119,67 +119,38 @@ class GetSingleConversation(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "ID is required"}), 400
|
||||
)
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conversation = repo.get_any(conversation_id, user_id)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
conv_pg_id = str(conversation["id"])
|
||||
messages = repo.get_messages(conv_pg_id)
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not conversation:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
# Process queries to include attachment names
|
||||
|
||||
# Resolve attachment details (id, fileName) for each message.
|
||||
attachments_repo = AttachmentsRepository(conn)
|
||||
queries = []
|
||||
for msg in messages:
|
||||
query = {
|
||||
"prompt": msg.get("prompt"),
|
||||
"response": msg.get("response"),
|
||||
"thought": msg.get("thought"),
|
||||
"sources": msg.get("sources") or [],
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"timestamp": msg.get("timestamp"),
|
||||
"model_id": msg.get("model_id"),
|
||||
}
|
||||
if msg.get("metadata"):
|
||||
query["metadata"] = msg["metadata"]
|
||||
# Feedback on conversation_messages is a JSONB blob with
|
||||
# shape {"text": <str>, "timestamp": <iso>}. The legacy
|
||||
# frontend consumed a flat scalar feedback string, so
|
||||
# unwrap the ``text`` field for compat.
|
||||
feedback = msg.get("feedback")
|
||||
if feedback is not None:
|
||||
if isinstance(feedback, dict):
|
||||
query["feedback"] = feedback.get("text")
|
||||
if feedback.get("timestamp"):
|
||||
query["feedback_timestamp"] = feedback["timestamp"]
|
||||
else:
|
||||
query["feedback"] = feedback
|
||||
attachments = msg.get("attachments") or []
|
||||
if attachments:
|
||||
attachment_details = []
|
||||
for attachment_id in attachments:
|
||||
try:
|
||||
att = attachments_repo.get_any(
|
||||
str(attachment_id), user_id
|
||||
queries = conversation["queries"]
|
||||
for query in queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
if att:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(att["id"]),
|
||||
"fileName": att.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
queries.append(query)
|
||||
except Exception as e:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving attachment {attachment_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving conversation: {err}", exc_info=True
|
||||
@@ -187,9 +158,7 @@ class GetSingleConversation(Resource):
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
data = {
|
||||
"queries": queries,
|
||||
"agent_id": (
|
||||
str(conversation["agent_id"]) if conversation.get("agent_id") else None
|
||||
),
|
||||
"agent_id": conversation.get("agent_id"),
|
||||
"is_shared_usage": conversation.get("is_shared_usage", False),
|
||||
"shared_token": conversation.get("shared_token", None),
|
||||
}
|
||||
@@ -221,13 +190,11 @@ class UpdateConversationName(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user_id = decoded_token.get("sub")
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(data["id"], user_id)
|
||||
if conv is not None:
|
||||
repo.rename(str(conv["id"]), user_id, data["name"])
|
||||
conversations_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
||||
{"$set": {"name": data["name"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating conversation name: {err}", exc_info=True
|
||||
@@ -270,33 +237,43 @@ class SubmitFeedback(Resource):
|
||||
missing_fields = check_required_fields(data, required_fields)
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
user_id = decoded_token.get("sub")
|
||||
feedback_value = data["feedback"]
|
||||
question_index = int(data["question_index"])
|
||||
# Normalize string feedback to lowercase so analytics queries
|
||||
# (which match 'like'/'dislike') count rows correctly. Tolerate
|
||||
# legacy uppercase clients on ingest. Non-string values pass through.
|
||||
if isinstance(feedback_value, str):
|
||||
feedback_value = feedback_value.lower()
|
||||
feedback_payload = (
|
||||
None
|
||||
if feedback_value is None
|
||||
else {
|
||||
"text": feedback_value,
|
||||
"timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
).isoformat(),
|
||||
}
|
||||
)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = ConversationsRepository(conn)
|
||||
conv = repo.get_any(data["conversation_id"], user_id)
|
||||
if conv is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Not found"}), 404
|
||||
)
|
||||
repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
|
||||
if data["feedback"] is None:
|
||||
# Remove feedback and feedback_timestamp if feedback is null
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$unset": {
|
||||
f"queries.{data['question_index']}.feedback": "",
|
||||
f"queries.{data['question_index']}.feedback_timestamp": "",
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Set feedback and feedback_timestamp if feedback has a value
|
||||
|
||||
conversations_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(data["conversation_id"]),
|
||||
"user": decoded_token.get("sub"),
|
||||
f"queries.{data['question_index']}": {"$exists": True},
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
f"queries.{data['question_index']}.feedback": data[
|
||||
"feedback"
|
||||
],
|
||||
f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
|
||||
import os
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import current_dir
|
||||
from application.storage.db.repositories.prompts import PromptsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.api.user.base import current_dir, prompts_collection
|
||||
from application.utils import check_required_fields
|
||||
|
||||
prompts_ns = Namespace(
|
||||
@@ -41,9 +40,15 @@ class CreatePrompt(Resource):
|
||||
return missing_fields
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_session() as conn:
|
||||
prompt = PromptsRepository(conn).create(user, data["name"], data["content"])
|
||||
new_id = str(prompt["id"])
|
||||
|
||||
resp = prompts_collection.insert_one(
|
||||
{
|
||||
"name": data["name"],
|
||||
"content": data["content"],
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -59,17 +64,17 @@ class GetPrompts(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
prompts = PromptsRepository(conn).list_for_user(user)
|
||||
prompts = prompts_collection.find({"user": user})
|
||||
list_prompts = [
|
||||
{"id": "default", "name": "default", "type": "public"},
|
||||
{"id": "creative", "name": "creative", "type": "public"},
|
||||
{"id": "strict", "name": "strict", "type": "public"},
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
list_prompts.append(
|
||||
{
|
||||
"id": str(prompt["id"]),
|
||||
"id": str(prompt["_id"]),
|
||||
"name": prompt["name"],
|
||||
"type": "private",
|
||||
}
|
||||
@@ -114,12 +119,9 @@ class GetSinglePrompt(Resource):
|
||||
) as f:
|
||||
chat_reduce_strict = f.read()
|
||||
return make_response(jsonify({"content": chat_reduce_strict}), 200)
|
||||
with db_readonly() as conn:
|
||||
prompt = PromptsRepository(conn).get_any(prompt_id, user)
|
||||
if not prompt:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Prompt not found"}), 404
|
||||
)
|
||||
prompt = prompts_collection.find_one(
|
||||
{"_id": ObjectId(prompt_id), "user": user}
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -146,15 +148,7 @@ class DeletePrompt(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = PromptsRepository(conn)
|
||||
prompt = repo.get_any(data["id"], user)
|
||||
if not prompt:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Prompt not found"}),
|
||||
404,
|
||||
)
|
||||
repo.delete(str(prompt["id"]), user)
|
||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -187,15 +181,10 @@ class UpdatePrompt(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = PromptsRepository(conn)
|
||||
prompt = repo.get_any(data["id"], user)
|
||||
if not prompt:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Prompt not found"}),
|
||||
404,
|
||||
)
|
||||
repo.update(str(prompt["id"]), user, data["name"], data["content"])
|
||||
prompts_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
@@ -2,126 +2,26 @@
|
||||
|
||||
import uuid
|
||||
|
||||
from bson.binary import Binary, UuidRepresentation
|
||||
from bson.dbref import DBRef
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, inputs, Namespace, Resource
|
||||
from sqlalchemy import text as _sql_text
|
||||
|
||||
from application.api import api
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.repositories.attachments import AttachmentsRepository
|
||||
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||
from application.storage.db.repositories.shared_conversations import (
|
||||
SharedConversationsRepository,
|
||||
from application.api.user.base import (
|
||||
agents_collection,
|
||||
attachments_collection,
|
||||
conversations_collection,
|
||||
shared_conversations_collections,
|
||||
)
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
|
||||
sharing_ns = Namespace(
|
||||
"sharing", description="Conversation sharing operations", path="/api"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_prompt_pg_id(conn, prompt_id_raw, user_id):
|
||||
"""Translate an incoming prompt id (UUID or legacy Mongo ObjectId) to a PG UUID.
|
||||
|
||||
Scoped by ``user_id`` so a caller can't link another user's prompt
|
||||
into their share record. Returns ``None`` for sentinel values
|
||||
(``"default"``) or unresolved ids.
|
||||
"""
|
||||
if not prompt_id_raw or prompt_id_raw == "default":
|
||||
return None
|
||||
value = str(prompt_id_raw)
|
||||
# Already UUID — trust it but still require ownership. A shape-gate
|
||||
# (rather than a loose ``len == 36 and '-' in value`` check) keeps
|
||||
# non-UUID input out of ``CAST(:pid AS uuid)``; the cast would raise
|
||||
# and poison the readonly transaction otherwise.
|
||||
if looks_like_uuid(value):
|
||||
row = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT id FROM prompts WHERE id = CAST(:pid AS uuid) "
|
||||
"AND user_id = :uid"
|
||||
),
|
||||
{"pid": value, "uid": user_id},
|
||||
).fetchone()
|
||||
return str(row[0]) if row else None
|
||||
# Legacy Mongo ObjectId fallback.
|
||||
row = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT id FROM prompts WHERE legacy_mongo_id = :pid "
|
||||
"AND user_id = :uid"
|
||||
),
|
||||
{"pid": value, "uid": user_id},
|
||||
).fetchone()
|
||||
return str(row[0]) if row else None
|
||||
|
||||
|
||||
def _resolve_source_pg_id(conn, source_raw):
|
||||
"""Translate a source id (UUID or legacy Mongo ObjectId) to a PG UUID."""
|
||||
if not source_raw:
|
||||
return None
|
||||
value = str(source_raw)
|
||||
# See ``_resolve_prompt_pg_id`` for the shape-gate rationale.
|
||||
if looks_like_uuid(value):
|
||||
row = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT id FROM sources WHERE id = CAST(:sid AS uuid)"
|
||||
),
|
||||
{"sid": value},
|
||||
).fetchone()
|
||||
return str(row[0]) if row else None
|
||||
row = conn.execute(
|
||||
_sql_text("SELECT id FROM sources WHERE legacy_mongo_id = :sid"),
|
||||
{"sid": value},
|
||||
).fetchone()
|
||||
return str(row[0]) if row else None
|
||||
|
||||
|
||||
def _find_reusable_share_agent(
|
||||
conn, user_id, *, prompt_pg_id, chunks, source_pg_id, retriever,
|
||||
):
|
||||
"""Find an existing share-as-agent key row matching these parameters.
|
||||
|
||||
Mirrors the legacy Mongo ``agents_collection.find_one`` pre-existence
|
||||
check. Used to reuse an api key across repeated shares of the same
|
||||
conversation with the same prompt/chunks/source/retriever.
|
||||
"""
|
||||
clauses = ["user_id = :uid", "key IS NOT NULL"]
|
||||
params: dict = {"uid": user_id}
|
||||
if prompt_pg_id is None:
|
||||
clauses.append("prompt_id IS NULL")
|
||||
else:
|
||||
clauses.append("prompt_id = CAST(:pid AS uuid)")
|
||||
params["pid"] = prompt_pg_id
|
||||
if chunks is None:
|
||||
clauses.append("chunks IS NULL")
|
||||
else:
|
||||
clauses.append("chunks = :chunks")
|
||||
params["chunks"] = int(chunks)
|
||||
if source_pg_id is None:
|
||||
clauses.append("source_id IS NULL")
|
||||
else:
|
||||
clauses.append("source_id = CAST(:sid AS uuid)")
|
||||
params["sid"] = source_pg_id
|
||||
if retriever is None:
|
||||
clauses.append("retriever IS NULL")
|
||||
else:
|
||||
clauses.append("retriever = :retr")
|
||||
params["retr"] = retriever
|
||||
sql = (
|
||||
"SELECT * FROM agents WHERE "
|
||||
+ " AND ".join(clauses)
|
||||
+ " LIMIT 1"
|
||||
)
|
||||
row = conn.execute(_sql_text(sql), params).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
mapping = dict(row._mapping)
|
||||
mapping["id"] = str(mapping["id"]) if mapping.get("id") else None
|
||||
return mapping
|
||||
|
||||
|
||||
@sharing_ns.route("/share")
|
||||
class ShareConversation(Resource):
|
||||
share_conversation_model = api.model(
|
||||
@@ -156,94 +56,146 @@ class ShareConversation(Resource):
|
||||
conversation_id = data["conversation_id"]
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
conv_repo = ConversationsRepository(conn)
|
||||
shared_repo = SharedConversationsRepository(conn)
|
||||
agents_repo = AgentsRepository(conn)
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": ObjectId(conversation_id)}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
current_n_queries = len(conversation["queries"])
|
||||
explicit_binary = Binary.from_uuid(
|
||||
uuid.uuid4(), UuidRepresentation.STANDARD
|
||||
)
|
||||
|
||||
conversation = conv_repo.get_any(conversation_id, user)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
if is_promptable:
|
||||
prompt_id = data.get("prompt_id", "default")
|
||||
chunks = data.get("chunks", "2")
|
||||
|
||||
name = conversation["name"] + "(shared)"
|
||||
new_api_key_data = {
|
||||
"prompt_id": prompt_id,
|
||||
"chunks": chunks,
|
||||
"user": user,
|
||||
}
|
||||
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
conv_pg_id = str(conversation["id"])
|
||||
current_n_queries = conv_repo.message_count(conv_pg_id)
|
||||
|
||||
if is_promptable:
|
||||
prompt_id_raw = data.get("prompt_id", "default")
|
||||
chunks_raw = data.get("chunks", "2")
|
||||
try:
|
||||
chunks_int = int(chunks_raw) if chunks_raw not in (None, "") else None
|
||||
except (TypeError, ValueError):
|
||||
chunks_int = None
|
||||
|
||||
prompt_pg_id = _resolve_prompt_pg_id(conn, prompt_id_raw, user)
|
||||
source_pg_id = _resolve_source_pg_id(conn, data.get("source"))
|
||||
retriever = data.get("retriever")
|
||||
|
||||
reusable = _find_reusable_share_agent(
|
||||
conn, user,
|
||||
prompt_pg_id=prompt_pg_id,
|
||||
chunks=chunks_int,
|
||||
source_pg_id=source_pg_id,
|
||||
retriever=retriever,
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
|
||||
if pre_existing_api_document:
|
||||
api_uuid = pre_existing_api_document["key"]
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
if reusable:
|
||||
api_uuid = reusable.get("key")
|
||||
else:
|
||||
api_uuid = str(uuid.uuid4())
|
||||
name = (conversation.get("name") or "") + "(shared)"
|
||||
agents_repo.create(
|
||||
user,
|
||||
name,
|
||||
"published",
|
||||
key=api_uuid,
|
||||
retriever=retriever,
|
||||
chunks=chunks_int,
|
||||
prompt_id=prompt_pg_id,
|
||||
source_id=source_pg_id,
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201,
|
||||
)
|
||||
else:
|
||||
api_uuid = str(uuid.uuid4())
|
||||
new_api_key_data["key"] = api_uuid
|
||||
new_api_key_data["name"] = name
|
||||
|
||||
share = shared_repo.get_or_create(
|
||||
conv_pg_id,
|
||||
user,
|
||||
is_promptable=True,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=api_uuid,
|
||||
prompt_id=prompt_pg_id,
|
||||
chunks=chunks_int,
|
||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
||||
new_api_key_data["source"] = DBRef(
|
||||
"sources", ObjectId(data["source"])
|
||||
)
|
||||
if "retriever" in data:
|
||||
new_api_key_data["retriever"] = data["retriever"]
|
||||
agents_collection.insert_one(new_api_key_data)
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
"api_key": api_uuid,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(share["uuid"]),
|
||||
"identifier": str(explicit_binary.as_uuid()),
|
||||
}
|
||||
),
|
||||
201 if reusable is None else 200,
|
||||
201,
|
||||
)
|
||||
|
||||
# Non-promptable share path.
|
||||
share = shared_repo.get_or_create(
|
||||
conv_pg_id,
|
||||
user,
|
||||
is_promptable=False,
|
||||
first_n_queries=current_n_queries,
|
||||
api_key=None,
|
||||
)
|
||||
pre_existing = shared_conversations_collections.find_one(
|
||||
{
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
if pre_existing is not None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"identifier": str(share["uuid"]),
|
||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
else:
|
||||
shared_conversations_collections.insert_one(
|
||||
{
|
||||
"uuid": explicit_binary,
|
||||
"conversation_id": ObjectId(conversation_id),
|
||||
"isPromptable": is_promptable,
|
||||
"first_n_queries": current_n_queries,
|
||||
"user": user,
|
||||
}
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||
),
|
||||
201,
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -258,13 +210,37 @@ class GetPubliclySharedConversations(Resource):
|
||||
@api.doc(description="Get publicly shared conversations by identifier")
|
||||
def get(self, identifier: str):
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
shared_repo = SharedConversationsRepository(conn)
|
||||
conv_repo = ConversationsRepository(conn)
|
||||
attach_repo = AttachmentsRepository(conn)
|
||||
query_uuid = Binary.from_uuid(
|
||||
uuid.UUID(identifier), UuidRepresentation.STANDARD
|
||||
)
|
||||
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
|
||||
conversation_queries = []
|
||||
|
||||
shared = shared_repo.find_by_uuid(identifier)
|
||||
if not shared or not shared.get("conversation_id"):
|
||||
if (
|
||||
shared
|
||||
and "conversation_id" in shared
|
||||
):
|
||||
# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
|
||||
conversation_id = shared["conversation_id"]
|
||||
if isinstance(conversation_id, DBRef):
|
||||
conversation_id = conversation_id.id
|
||||
elif isinstance(conversation_id, dict):
|
||||
# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
|
||||
if "$id" in conversation_id:
|
||||
conv_id = conversation_id["$id"]
|
||||
# $id might be a dict like {"$oid": "..."} or a string
|
||||
if isinstance(conv_id, dict) and "$oid" in conv_id:
|
||||
conversation_id = ObjectId(conv_id["$oid"])
|
||||
else:
|
||||
conversation_id = ObjectId(conv_id)
|
||||
elif "_id" in conversation_id:
|
||||
conversation_id = ObjectId(conversation_id["_id"])
|
||||
elif isinstance(conversation_id, str):
|
||||
conversation_id = ObjectId(conversation_id)
|
||||
conversation = conversations_collection.find_one(
|
||||
{"_id": conversation_id}
|
||||
)
|
||||
if conversation is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -274,60 +250,22 @@ class GetPubliclySharedConversations(Resource):
|
||||
),
|
||||
404,
|
||||
)
|
||||
conv_pg_id = str(shared["conversation_id"])
|
||||
owner_user = shared.get("user_id")
|
||||
conversation_queries = conversation["queries"][
|
||||
: (shared["first_n_queries"])
|
||||
]
|
||||
|
||||
conversation = conv_repo.get_owned(conv_pg_id, owner_user) if owner_user else None
|
||||
if conversation is None:
|
||||
# Fall back to any-user lookup in case shared row's
|
||||
# user_id is missing — still keyed by PG UUID.
|
||||
row = conn.execute(
|
||||
_sql_text(
|
||||
"SELECT * FROM conversations WHERE id = CAST(:id AS uuid)"
|
||||
),
|
||||
{"id": conv_pg_id},
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
conversation = dict(row._mapping)
|
||||
|
||||
messages = conv_repo.get_messages(conv_pg_id)
|
||||
first_n = shared.get("first_n_queries") or 0
|
||||
conversation_queries = []
|
||||
for msg in messages[:first_n]:
|
||||
query = {
|
||||
"prompt": msg.get("prompt"),
|
||||
"response": msg.get("response"),
|
||||
"thought": msg.get("thought"),
|
||||
"sources": msg.get("sources") or [],
|
||||
"tool_calls": msg.get("tool_calls") or [],
|
||||
"timestamp": (
|
||||
msg["timestamp"].isoformat()
|
||||
if hasattr(msg.get("timestamp"), "isoformat")
|
||||
else msg.get("timestamp")
|
||||
),
|
||||
"feedback": msg.get("feedback"),
|
||||
}
|
||||
attachments = msg.get("attachments") or []
|
||||
if attachments:
|
||||
for query in conversation_queries:
|
||||
if "attachments" in query and query["attachments"]:
|
||||
attachment_details = []
|
||||
for attachment_id in attachments:
|
||||
for attachment_id in query["attachments"]:
|
||||
try:
|
||||
attachment = attach_repo.get_any(
|
||||
str(attachment_id), owner_user,
|
||||
) if owner_user else None
|
||||
attachment = attachments_collection.find_one(
|
||||
{"_id": ObjectId(attachment_id)}
|
||||
)
|
||||
if attachment:
|
||||
attachment_details.append(
|
||||
{
|
||||
"id": str(attachment["id"]),
|
||||
"id": str(attachment["_id"]),
|
||||
"fileName": attachment.get(
|
||||
"filename", "Unknown file"
|
||||
),
|
||||
@@ -339,23 +277,26 @@ class GetPubliclySharedConversations(Resource):
|
||||
exc_info=True,
|
||||
)
|
||||
query["attachments"] = attachment_details
|
||||
conversation_queries.append(query)
|
||||
|
||||
created = conversation.get("created_at") or conversation.get("date")
|
||||
date_iso = (
|
||||
created.isoformat()
|
||||
if hasattr(created, "isoformat")
|
||||
else (str(created) if created is not None else None)
|
||||
else:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "might have broken url or the conversation does not exist",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
res = {
|
||||
"success": True,
|
||||
"queries": conversation_queries,
|
||||
"title": conversation.get("name"),
|
||||
"timestamp": date_iso,
|
||||
}
|
||||
if shared.get("is_promptable") and shared.get("api_key"):
|
||||
res["api_key"] = shared["api_key"]
|
||||
return make_response(jsonify(res), 200)
|
||||
date = conversation["_id"].generation_time.isoformat()
|
||||
res = {
|
||||
"success": True,
|
||||
"queries": conversation_queries,
|
||||
"title": conversation["name"],
|
||||
"timestamp": date,
|
||||
}
|
||||
if shared["isPromptable"] and "api_key" in shared:
|
||||
res["api_key"] = shared["api_key"]
|
||||
return make_response(jsonify(res), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error getting shared conversation: {err}", exc_info=True
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
"""Source document management chunk management."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import get_vector_store
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
from application.api.user.base import get_vector_store, sources_collection
|
||||
from application.utils import check_required_fields, num_tokens_from_string
|
||||
|
||||
sources_chunks_ns = Namespace(
|
||||
@@ -14,15 +13,6 @@ sources_chunks_ns = Namespace(
|
||||
)
|
||||
|
||||
|
||||
def _resolve_source(doc_id: str, user: str):
|
||||
"""Resolve a source (UUID or legacy ObjectId) for the caller.
|
||||
|
||||
Returns the row dict (with PG UUID in ``id``) or ``None`` if missing.
|
||||
"""
|
||||
with db_readonly() as conn:
|
||||
return SourcesRepository(conn).get_any(doc_id, user)
|
||||
|
||||
|
||||
@sources_chunks_ns.route("/get_chunks")
|
||||
class GetChunks(Resource):
|
||||
@api.doc(
|
||||
@@ -46,34 +36,36 @@ class GetChunks(Resource):
|
||||
path = request.args.get("path")
|
||||
search_term = request.args.get("search", "").strip().lower()
|
||||
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
resolved_id = str(doc["id"])
|
||||
try:
|
||||
store = get_vector_store(resolved_id)
|
||||
store = get_vector_store(doc_id)
|
||||
chunks = store.get_chunks()
|
||||
|
||||
filtered_chunks = []
|
||||
for chunk in chunks:
|
||||
metadata = chunk.get("metadata", {})
|
||||
|
||||
# Filter by path if provided
|
||||
|
||||
if path:
|
||||
chunk_source = metadata.get("source", "")
|
||||
chunk_file_path = metadata.get("file_path", "")
|
||||
# Check if the chunk matches the requested path
|
||||
# For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
|
||||
# For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
|
||||
source_match = chunk_source and chunk_source.endswith(path)
|
||||
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
|
||||
|
||||
if not (source_match or file_path_match):
|
||||
continue
|
||||
# Filter by search term if provided
|
||||
|
||||
if search_term:
|
||||
text_match = search_term in chunk.get("text", "").lower()
|
||||
title_match = search_term in metadata.get("title", "").lower()
|
||||
@@ -140,17 +132,15 @@ class AddChunk(Resource):
|
||||
token_count = num_tokens_from_string(text)
|
||||
metadata["token_count"] = token_count
|
||||
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(str(doc["id"]))
|
||||
store = get_vector_store(doc_id)
|
||||
chunk_id = store.add_chunk(text, metadata)
|
||||
return make_response(
|
||||
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
||||
@@ -175,17 +165,15 @@ class DeleteChunk(Resource):
|
||||
doc_id = request.args.get("id")
|
||||
chunk_id = request.args.get("chunk_id")
|
||||
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(str(doc["id"]))
|
||||
store = get_vector_store(doc_id)
|
||||
deleted = store.delete_chunk(chunk_id)
|
||||
if deleted:
|
||||
return make_response(
|
||||
@@ -244,17 +232,15 @@ class UpdateChunk(Resource):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["token_count"] = token_count
|
||||
try:
|
||||
doc = _resolve_source(doc_id, user)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
)
|
||||
try:
|
||||
store = get_vector_store(str(doc["id"]))
|
||||
store = get_vector_store(doc_id)
|
||||
|
||||
chunks = store.get_chunks()
|
||||
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
import json
|
||||
import math
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import sync_source
|
||||
from application.core.settings import settings
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.utils import check_required_fields
|
||||
from application.vectorstore.vector_creator import VectorCreator
|
||||
@@ -56,20 +56,11 @@ class CombinedJson(Resource):
|
||||
]
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
indexes = SourcesRepository(conn).list_for_user(user)
|
||||
# list_for_user sorts by created_at DESC; legacy shape sorted by
|
||||
# "date" DESC. Both are monotonic on creation so the ordering is
|
||||
# equivalent for dev; re-sort defensively.
|
||||
indexes = sorted(
|
||||
indexes, key=lambda r: r.get("date") or r.get("created_at") or "",
|
||||
reverse=True,
|
||||
)
|
||||
for index in indexes:
|
||||
for index in sources_collection.find({"user": user}).sort("date", -1):
|
||||
provider = _get_provider_from_remote_data(index.get("remote_data"))
|
||||
data.append(
|
||||
{
|
||||
"id": str(index["id"]),
|
||||
"id": str(index["_id"]),
|
||||
"name": index.get("name"),
|
||||
"date": index.get("date"),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
@@ -79,7 +70,9 @@ class CombinedJson(Resource):
|
||||
"syncFrequency": index.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"is_nested": bool(index.get("directory_structure")),
|
||||
"type": index.get("type", "file"),
|
||||
"type": index.get(
|
||||
"type", "file"
|
||||
), # Add type field with default "file"
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
@@ -96,55 +89,61 @@ class PaginatedSources(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
sort_field = request.args.get("sort", "date")
|
||||
sort_order = request.args.get("order", "desc")
|
||||
page = max(1, int(request.args.get("page", 1)))
|
||||
rows_per_page = max(1, int(request.args.get("rows", 10)))
|
||||
search_term = request.args.get("search", "").strip() or None
|
||||
sort_field = request.args.get("sort", "date") # Default to 'date'
|
||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
||||
page = int(request.args.get("page", 1)) # Default to 1
|
||||
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
||||
# add .strip() to remove leading and trailing whitespaces
|
||||
|
||||
search_term = request.args.get(
|
||||
"search", ""
|
||||
).strip() # add search for filter documents
|
||||
|
||||
# Prepare query for filtering
|
||||
|
||||
query = {"user": user}
|
||||
if search_term:
|
||||
query["name"] = {
|
||||
"$regex": search_term,
|
||||
"$options": "i", # using case-insensitive search
|
||||
}
|
||||
total_documents = sources_collection.count_documents(query)
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
page = min(
|
||||
max(1, page), total_pages
|
||||
) # add this to make sure page inbound is within the range
|
||||
sort_order = 1 if sort_order == "asc" else -1
|
||||
skip = (page - 1) * rows_per_page
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
total_documents = repo.count_for_user(
|
||||
user, search_term=search_term,
|
||||
)
|
||||
# Prior in-Python implementation returned ``totalPages = 1``
|
||||
# for empty result sets (``max(1, ceil(0/rows))``); we
|
||||
# preserve that contract so the frontend pager stays stable.
|
||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||
effective_page = min(page, total_pages)
|
||||
offset = (effective_page - 1) * rows_per_page
|
||||
window = repo.list_for_user(
|
||||
user,
|
||||
limit=rows_per_page,
|
||||
offset=offset,
|
||||
search_term=search_term,
|
||||
sort_field=sort_field,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
documents = (
|
||||
sources_collection.find(query)
|
||||
.sort(sort_field, sort_order)
|
||||
.skip(skip)
|
||||
.limit(rows_per_page)
|
||||
)
|
||||
|
||||
paginated_docs = []
|
||||
for doc in window:
|
||||
for doc in documents:
|
||||
provider = _get_provider_from_remote_data(doc.get("remote_data"))
|
||||
paginated_docs.append(
|
||||
{
|
||||
"id": str(doc["id"]),
|
||||
"name": doc.get("name", ""),
|
||||
"date": doc.get("date", ""),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
)
|
||||
doc_data = {
|
||||
"id": str(doc["_id"]),
|
||||
"name": doc.get("name", ""),
|
||||
"date": doc.get("date", ""),
|
||||
"model": settings.EMBEDDINGS_NAME,
|
||||
"location": "local",
|
||||
"tokens": doc.get("tokens", ""),
|
||||
"retriever": doc.get("retriever", "classic"),
|
||||
"syncFrequency": doc.get("sync_frequency", ""),
|
||||
"provider": provider,
|
||||
"isNested": bool(doc.get("directory_structure")),
|
||||
"type": doc.get("type", "file"),
|
||||
}
|
||||
paginated_docs.append(doc_data)
|
||||
response = {
|
||||
"total": total_documents,
|
||||
"totalPages": total_pages,
|
||||
"currentPage": effective_page,
|
||||
"currentPage": page,
|
||||
"paginated": paginated_docs,
|
||||
}
|
||||
return make_response(jsonify(response), 200)
|
||||
@@ -155,6 +154,28 @@ class PaginatedSources(Resource):
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_by_ids")
|
||||
class DeleteByIds(Resource):
|
||||
@api.doc(
|
||||
description="Deletes documents from the vector store by IDs",
|
||||
params={"path": "Comma-separated list of IDs"},
|
||||
)
|
||||
def get(self):
|
||||
ids = request.args.get("path")
|
||||
if not ids:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
try:
|
||||
result = sources_collection.delete_index(ids=ids)
|
||||
if result:
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
|
||||
@sources_ns.route("/delete_old")
|
||||
class DeleteOldIndexes(Resource):
|
||||
@api.doc(
|
||||
@@ -165,33 +186,30 @@ class DeleteOldIndexes(Resource):
|
||||
decoded_token = request.decoded_token
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
source_id = request.args.get("source_id")
|
||||
if not source_id:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
||||
)
|
||||
if not doc:
|
||||
return make_response(jsonify({"status": "not found"}), 404)
|
||||
storage = StorageCreator.get_storage()
|
||||
resolved_id = str(doc["id"])
|
||||
|
||||
try:
|
||||
# Delete vector index
|
||||
|
||||
if settings.VECTOR_STORE == "faiss":
|
||||
index_path = f"indexes/{resolved_id}"
|
||||
index_path = f"indexes/{str(doc['_id'])}"
|
||||
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||
storage.delete_file(f"{index_path}/index.faiss")
|
||||
if storage.file_exists(f"{index_path}/index.pkl"):
|
||||
storage.delete_file(f"{index_path}/index.pkl")
|
||||
else:
|
||||
vectorstore = VectorCreator.create_vectorstore(
|
||||
settings.VECTOR_STORE, source_id=resolved_id
|
||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
||||
)
|
||||
vectorstore.delete_index()
|
||||
if "file_path" in doc and doc["file_path"]:
|
||||
@@ -209,14 +227,7 @@ class DeleteOldIndexes(Resource):
|
||||
f"Error deleting files and indexes: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
try:
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).delete(resolved_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error deleting source row: {err}", exc_info=True
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
||||
return make_response(jsonify({"success": True}), 200)
|
||||
|
||||
|
||||
@@ -261,16 +272,15 @@ class ManageSync(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
||||
)
|
||||
update_data = {"$set": {"sync_frequency": sync_frequency}}
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = SourcesRepository(conn)
|
||||
doc = repo.get_any(source_id, user)
|
||||
if doc is None:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}),
|
||||
404,
|
||||
)
|
||||
repo.update(str(doc["id"]), user, {"sync_frequency": sync_frequency})
|
||||
sources_collection.update_one(
|
||||
{
|
||||
"_id": ObjectId(source_id),
|
||||
"user": user,
|
||||
},
|
||||
update_data,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating sync frequency: {err}", exc_info=True
|
||||
@@ -299,20 +309,19 @@ class SyncSource(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
source_id = data["source_id"]
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
|
||||
if not ObjectId.is_valid(source_id):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID"}), 400
|
||||
)
|
||||
doc = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
)
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Source not found"}), 404
|
||||
)
|
||||
source_type = doc.get("type", "")
|
||||
if source_type and source_type.startswith("connector"):
|
||||
if source_type.startswith("connector"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
@@ -335,7 +344,7 @@ class SyncSource(Resource):
|
||||
loader=source_type,
|
||||
sync_frequency=doc.get("sync_frequency", "never"),
|
||||
retriever=doc.get("retriever", "classic"),
|
||||
doc_id=str(doc["id"]),
|
||||
doc_id=source_id,
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
@@ -361,9 +370,10 @@ class DirectoryStructure(Resource):
|
||||
|
||||
if not doc_id:
|
||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||
if not ObjectId.is_valid(doc_id):
|
||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
doc = SourcesRepository(conn).get_any(doc_id, user)
|
||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
||||
if not doc:
|
||||
return make_response(
|
||||
jsonify({"error": "Document not found or access denied"}), 404
|
||||
@@ -377,8 +387,6 @@ class DirectoryStructure(Resource):
|
||||
if isinstance(remote_data, str) and remote_data:
|
||||
remote_data_obj = json.loads(remote_data)
|
||||
provider = remote_data_obj.get("provider")
|
||||
elif isinstance(remote_data, dict):
|
||||
provider = remote_data.get("provider")
|
||||
except Exception as e:
|
||||
current_app.logger.warning(
|
||||
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||
@@ -398,7 +406,4 @@ class DirectoryStructure(Resource):
|
||||
current_app.logger.error(
|
||||
f"Error retrieving directory structure: {e}", exc_info=True
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Failed to retrieve directory structure"}),
|
||||
500,
|
||||
)
|
||||
return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)
|
||||
|
||||
@@ -5,16 +5,16 @@ import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.api import api
|
||||
from application.api.user.base import sources_collection
|
||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||
from application.core.settings import settings
|
||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||
from application.storage.db.repositories.sources import SourcesRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.storage.storage_creator import StorageCreator
|
||||
from application.stt.upload_limits import (
|
||||
AudioFileTooLargeError,
|
||||
@@ -329,8 +329,15 @@ class ManageSourceFiles(Resource):
|
||||
400,
|
||||
)
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
source = SourcesRepository(conn).get_any(source_id, user)
|
||||
ObjectId(source_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
||||
)
|
||||
try:
|
||||
source = sources_collection.find_one(
|
||||
{"_id": ObjectId(source_id), "user": user}
|
||||
)
|
||||
if not source:
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -346,7 +353,6 @@ class ManageSourceFiles(Resource):
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Database error"}), 500
|
||||
)
|
||||
resolved_source_id = str(source["id"])
|
||||
try:
|
||||
storage = StorageCreator.get_storage()
|
||||
source_file_path = source.get("file_path", "")
|
||||
@@ -405,18 +411,15 @@ class ManageSourceFiles(Resource):
|
||||
map_updated = True
|
||||
|
||||
if map_updated:
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).update(
|
||||
resolved_source_id, user,
|
||||
{"file_name_map": dict(file_name_map)},
|
||||
)
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
)
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -460,16 +463,6 @@ class ManageSourceFiles(Resource):
|
||||
removed_files = []
|
||||
map_updated = False
|
||||
for file_path in file_paths:
|
||||
if ".." in str(file_path) or str(file_path).startswith("/"):
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Invalid file path",
|
||||
}
|
||||
),
|
||||
400,
|
||||
)
|
||||
full_path = f"{source_file_path}/{file_path}"
|
||||
|
||||
# Remove from storage
|
||||
@@ -482,18 +475,15 @@ class ManageSourceFiles(Resource):
|
||||
map_updated = True
|
||||
|
||||
if map_updated and isinstance(file_name_map, dict):
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).update(
|
||||
resolved_source_id, user,
|
||||
{"file_name_map": dict(file_name_map)},
|
||||
)
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
)
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
@@ -581,19 +571,16 @@ class ManageSourceFiles(Resource):
|
||||
if keys_to_remove:
|
||||
for key in keys_to_remove:
|
||||
file_name_map.pop(key, None)
|
||||
with db_session() as conn:
|
||||
SourcesRepository(conn).update(
|
||||
resolved_source_id, user,
|
||||
{"file_name_map": dict(file_name_map)},
|
||||
)
|
||||
sources_collection.update_one(
|
||||
{"_id": ObjectId(source_id)},
|
||||
{"$set": {"file_name_map": file_name_map}},
|
||||
)
|
||||
|
||||
# Trigger re-ingestion pipeline
|
||||
|
||||
from application.api.user.tasks import reingest_source_task
|
||||
|
||||
task = reingest_source_task.delay(
|
||||
source_id=resolved_source_id, user=user
|
||||
)
|
||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
||||
|
||||
return make_response(
|
||||
jsonify(
|
||||
|
||||
@@ -134,17 +134,6 @@ def setup_periodic_tasks(sender, **kwargs):
|
||||
timedelta(days=30),
|
||||
schedule_syncs.s("monthly"),
|
||||
)
|
||||
# Replaces Mongo's TTL index on pending_tool_state.expires_at.
|
||||
sender.add_periodic_task(
|
||||
timedelta(seconds=60),
|
||||
cleanup_pending_tool_state.s(),
|
||||
name="cleanup-pending-tool-state",
|
||||
)
|
||||
sender.add_periodic_task(
|
||||
timedelta(hours=7),
|
||||
version_check_task.s(),
|
||||
name="version-check",
|
||||
)
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
@@ -157,40 +146,3 @@ def mcp_oauth_task(self, config, user):
|
||||
def mcp_oauth_status_task(self, task_id):
|
||||
resp = mcp_oauth_status(self, task_id)
|
||||
return resp
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def cleanup_pending_tool_state(self):
|
||||
"""Delete pending_tool_state rows past their TTL.
|
||||
|
||||
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
|
||||
no native TTL, so this task runs every 60 seconds to keep
|
||||
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
|
||||
configured (keeps the task runnable in Mongo-only environments).
|
||||
"""
|
||||
from application.core.settings import settings
|
||||
if not settings.POSTGRES_URI:
|
||||
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||
|
||||
from application.storage.db.engine import get_engine
|
||||
from application.storage.db.repositories.pending_tool_state import (
|
||||
PendingToolStateRepository,
|
||||
)
|
||||
|
||||
engine = get_engine()
|
||||
with engine.begin() as conn:
|
||||
deleted = PendingToolStateRepository(conn).cleanup_expired()
|
||||
return {"deleted": deleted}
|
||||
|
||||
|
||||
@celery.task(bind=True)
|
||||
def version_check_task(self):
|
||||
"""Periodic anonymous version check.
|
||||
|
||||
Complements the ``worker_ready`` boot trigger so long-running
|
||||
deployments (>6h cache TTL) still refresh advisories. ``run_check``
|
||||
is fail-silent and coordinates across replicas via Redis lock +
|
||||
cache (see ``application.updates.version_check``).
|
||||
"""
|
||||
from application.updates.version_check import run_check
|
||||
run_check()
|
||||
|
||||
@@ -3,24 +3,26 @@
|
||||
import json
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, redirect, request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
|
||||
from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
|
||||
from application.api import api
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.api.user.tools.routes import transform_actions
|
||||
from application.cache import get_redis_instance
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.storage.db.repositories.connector_sessions import (
|
||||
ConnectorSessionsRepository,
|
||||
)
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields
|
||||
|
||||
tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
|
||||
|
||||
_mongo = MongoDB.get_client()
|
||||
_db = _mongo[settings.MONGO_DB_NAME]
|
||||
_connector_sessions = _db["connector_sessions"]
|
||||
|
||||
_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
|
||||
|
||||
|
||||
@@ -61,21 +63,6 @@ def _extract_auth_credentials(config):
|
||||
return auth_credentials
|
||||
|
||||
|
||||
def _validate_mcp_server_url(config: dict) -> None:
|
||||
"""Validate the server_url in an MCP config to prevent SSRF.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL is missing or points to a blocked address.
|
||||
"""
|
||||
server_url = (config.get("server_url") or "").strip()
|
||||
if not server_url:
|
||||
raise ValueError("server_url is required")
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError as exc:
|
||||
raise ValueError(f"Invalid server URL: {exc}") from exc
|
||||
|
||||
|
||||
@tools_mcp_ns.route("/mcp_server/test")
|
||||
class TestMCPServerConfig(Resource):
|
||||
@api.expect(
|
||||
@@ -110,8 +97,6 @@ class TestMCPServerConfig(Resource):
|
||||
400,
|
||||
)
|
||||
|
||||
_validate_mcp_server_url(config)
|
||||
|
||||
auth_credentials = _extract_auth_credentials(config)
|
||||
test_config = config.copy()
|
||||
test_config["auth_credentials"] = auth_credentials
|
||||
@@ -120,41 +105,15 @@ class TestMCPServerConfig(Resource):
|
||||
result = mcp_tool.test_connection()
|
||||
|
||||
if result.get("requires_oauth"):
|
||||
safe_result = {
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k in ("success", "requires_oauth", "auth_url")
|
||||
}
|
||||
return make_response(jsonify(safe_result), 200)
|
||||
return make_response(jsonify(result), 200)
|
||||
|
||||
if not result.get("success"):
|
||||
if not result.get("success") and "message" in result:
|
||||
current_app.logger.error(
|
||||
f"MCP connection test failed: {result.get('message')}"
|
||||
)
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"message": "Connection test failed",
|
||||
"tools_count": 0,
|
||||
}
|
||||
),
|
||||
200,
|
||||
)
|
||||
result["message"] = "Connection test failed"
|
||||
|
||||
safe_result = {
|
||||
"success": True,
|
||||
"message": result.get("message", "Connection successful"),
|
||||
"tools_count": result.get("tools_count", 0),
|
||||
"tools": result.get("tools", []),
|
||||
}
|
||||
return make_response(jsonify(safe_result), 200)
|
||||
except ValueError as e:
|
||||
current_app.logger.warning(f"Invalid MCP server test request: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
|
||||
400,
|
||||
)
|
||||
return make_response(jsonify(result), 200)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
@@ -206,8 +165,6 @@ class MCPServerSave(Resource):
|
||||
400,
|
||||
)
|
||||
|
||||
_validate_mcp_server_url(config)
|
||||
|
||||
auth_credentials = _extract_auth_credentials(config)
|
||||
auth_type = config.get("auth_type", "none")
|
||||
mcp_config = config.copy()
|
||||
@@ -249,18 +206,15 @@ class MCPServerSave(Resource):
|
||||
storage_config = config.copy()
|
||||
|
||||
tool_id = data.get("id")
|
||||
existing_doc = None
|
||||
existing_encrypted = None
|
||||
if tool_id:
|
||||
with db_readonly() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
existing_doc = repo.get_any(tool_id, user)
|
||||
if existing_doc and existing_doc.get("name") == "mcp_tool":
|
||||
existing_encrypted = (existing_doc.get("config") or {}).get(
|
||||
existing_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
|
||||
)
|
||||
if existing_doc:
|
||||
existing_encrypted = existing_doc.get("config", {}).get(
|
||||
"encrypted_credentials"
|
||||
)
|
||||
else:
|
||||
existing_doc = None
|
||||
|
||||
if auth_credentials:
|
||||
if existing_encrypted:
|
||||
@@ -283,95 +237,48 @@ class MCPServerSave(Resource):
|
||||
]:
|
||||
storage_config.pop(field, None)
|
||||
transformed_actions = transform_actions(actions_metadata)
|
||||
tool_data = {
|
||||
"name": "mcp_tool",
|
||||
"displayName": data["displayName"],
|
||||
"customName": data["displayName"],
|
||||
"description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": data.get("status", True),
|
||||
"user": user,
|
||||
}
|
||||
|
||||
display_name = data["displayName"]
|
||||
description = f"MCP Server: {storage_config.get('server_url', 'Unknown')}"
|
||||
status_bool = bool(data.get("status", True))
|
||||
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
if existing_doc:
|
||||
repo.update(
|
||||
str(existing_doc["id"]), user,
|
||||
{
|
||||
"display_name": display_name,
|
||||
"custom_name": display_name,
|
||||
"description": description,
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": status_bool,
|
||||
},
|
||||
)
|
||||
saved_id = str(existing_doc["id"])
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": saved_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
# Fall back to find_by_user_and_name — the original
|
||||
# dual-write path also ran an existence check before
|
||||
# deciding between insert and update.
|
||||
existing_by_name = repo.find_by_user_and_name(user, "mcp_tool")
|
||||
if tool_id is None and existing_by_name and (
|
||||
(existing_by_name.get("config") or {}).get("server_url")
|
||||
== storage_config.get("server_url")
|
||||
):
|
||||
repo.update(
|
||||
str(existing_by_name["id"]), user,
|
||||
if tool_id:
|
||||
result = user_tools_collection.update_one(
|
||||
{"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
|
||||
{"$set": {k: v for k, v in tool_data.items() if k != "user"}},
|
||||
)
|
||||
if result.matched_count == 0:
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"display_name": display_name,
|
||||
"custom_name": display_name,
|
||||
"description": description,
|
||||
"config": storage_config,
|
||||
"actions": transformed_actions,
|
||||
"status": status_bool,
|
||||
},
|
||||
)
|
||||
saved_id = str(existing_by_name["id"])
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": saved_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
created = repo.create(
|
||||
user, "mcp_tool",
|
||||
config=storage_config,
|
||||
custom_name=display_name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
config_requirements={},
|
||||
actions=transformed_actions,
|
||||
status=status_bool,
|
||||
)
|
||||
saved_id = str(created["id"])
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": saved_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
if tool_id and existing_doc is None:
|
||||
# Client requested update on a non-existent tool id.
|
||||
return make_response(
|
||||
jsonify(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
"success": False,
|
||||
"error": "Tool not found or access denied",
|
||||
}
|
||||
),
|
||||
404,
|
||||
)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
else:
|
||||
result = user_tools_collection.insert_one(tool_data)
|
||||
tool_id = str(result.inserted_id)
|
||||
response_data = {
|
||||
"success": True,
|
||||
"id": tool_id,
|
||||
"message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
|
||||
"tools_count": len(transformed_actions),
|
||||
}
|
||||
return make_response(jsonify(response_data), 200)
|
||||
except ValueError as e:
|
||||
current_app.logger.warning(f"Invalid MCP server save request: {e}")
|
||||
return make_response(
|
||||
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||
return make_response(
|
||||
@@ -500,59 +407,49 @@ class MCPAuthStatus(Resource):
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
tools_repo = UserToolsRepository(conn)
|
||||
sessions_repo = ConnectorSessionsRepository(conn)
|
||||
all_tools = tools_repo.list_for_user(user)
|
||||
mcp_tools = [t for t in all_tools if t.get("name") == "mcp_tool"]
|
||||
if not mcp_tools:
|
||||
return make_response(
|
||||
jsonify({"success": True, "statuses": {}}), 200
|
||||
)
|
||||
mcp_tools = list(
|
||||
user_tools_collection.find(
|
||||
{"user": user, "name": "mcp_tool"},
|
||||
{"_id": 1, "config": 1},
|
||||
)
|
||||
)
|
||||
if not mcp_tools:
|
||||
return make_response(jsonify({"success": True, "statuses": {}}), 200)
|
||||
|
||||
oauth_server_urls: dict = {}
|
||||
statuses: dict = {}
|
||||
for tool in mcp_tools:
|
||||
tool_id = str(tool["id"])
|
||||
config = tool.get("config") or {}
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "oauth":
|
||||
server_url = config.get("server_url", "")
|
||||
if server_url:
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
oauth_server_urls[tool_id] = base_url
|
||||
else:
|
||||
statuses[tool_id] = "needs_auth"
|
||||
oauth_server_urls = {}
|
||||
statuses = {}
|
||||
for tool in mcp_tools:
|
||||
tool_id = str(tool["_id"])
|
||||
config = tool.get("config", {})
|
||||
auth_type = config.get("auth_type", "none")
|
||||
if auth_type == "oauth":
|
||||
server_url = config.get("server_url", "")
|
||||
if server_url:
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
oauth_server_urls[tool_id] = base_url
|
||||
else:
|
||||
statuses[tool_id] = "configured"
|
||||
statuses[tool_id] = "needs_auth"
|
||||
else:
|
||||
statuses[tool_id] = "configured"
|
||||
|
||||
if oauth_server_urls:
|
||||
# Look up a session per distinct base URL. MCP sessions
|
||||
# are stored with ``provider = "mcp:<server_url>"``
|
||||
# and the URL in ``server_url``; reuse the repo's
|
||||
# per-URL accessor rather than an ad-hoc $in query.
|
||||
url_has_tokens: dict = {}
|
||||
for base_url in set(oauth_server_urls.values()):
|
||||
session = sessions_repo.get_by_user_and_server_url(
|
||||
user, base_url,
|
||||
)
|
||||
tokens = (
|
||||
(session or {}).get("session_data", {}) or {}
|
||||
).get("tokens", {}) or {}
|
||||
# MCP code also stashes tokens into token_info on
|
||||
# the row; consider either present as "connected".
|
||||
token_info = (session or {}).get("token_info") or {}
|
||||
url_has_tokens[base_url] = bool(
|
||||
tokens.get("access_token")
|
||||
or token_info.get("access_token")
|
||||
)
|
||||
|
||||
for tool_id, base_url in oauth_server_urls.items():
|
||||
if url_has_tokens.get(base_url):
|
||||
statuses[tool_id] = "connected"
|
||||
else:
|
||||
statuses[tool_id] = "needs_auth"
|
||||
if oauth_server_urls:
|
||||
unique_urls = list(set(oauth_server_urls.values()))
|
||||
sessions = list(
|
||||
_connector_sessions.find(
|
||||
{"user_id": user, "server_url": {"$in": unique_urls}},
|
||||
{"server_url": 1, "tokens": 1},
|
||||
)
|
||||
)
|
||||
url_has_tokens = {
|
||||
doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
|
||||
for doc in sessions
|
||||
}
|
||||
for tool_id, base_url in oauth_server_urls.items():
|
||||
if url_has_tokens.get(base_url):
|
||||
statuses[tool_id] = "connected"
|
||||
else:
|
||||
statuses[tool_id] = "needs_auth"
|
||||
|
||||
return make_response(jsonify({"success": True, "statuses": statuses}), 200)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,59 +1,20 @@
|
||||
"""Tool management routes."""
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from flask import current_app, jsonify, make_response, request
|
||||
from flask_restx import fields, Namespace, Resource
|
||||
|
||||
from application.agents.tools.spec_parser import parse_spec
|
||||
from application.agents.tools.tool_manager import ToolManager
|
||||
from application.api import api
|
||||
from application.core.url_validation import SSRFError, validate_url
|
||||
from application.api.user.base import user_tools_collection
|
||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||
from application.storage.db.repositories.notes import NotesRepository
|
||||
from application.storage.db.repositories.todos import TodosRepository
|
||||
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.utils import check_required_fields, validate_function_name
|
||||
|
||||
tool_config = {}
|
||||
tool_manager = ToolManager(config=tool_config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shape translation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
# The frontend speaks camelCase (``displayName`` / ``customName`` /
|
||||
# ``configRequirements``). The PG ``user_tools`` table stores snake_case
|
||||
# (``display_name`` / ``custom_name`` / ``config_requirements``). Keep the
|
||||
# translation localized to this module so repositories stay pure.
|
||||
|
||||
_CAMEL_TO_SNAKE = {
|
||||
"displayName": "display_name",
|
||||
"customName": "custom_name",
|
||||
"configRequirements": "config_requirements",
|
||||
}
|
||||
_SNAKE_TO_CAMEL = {v: k for k, v in _CAMEL_TO_SNAKE.items()}
|
||||
|
||||
|
||||
def _row_to_api(row: dict) -> dict:
|
||||
"""Rename DB-native snake_case keys to the camelCase shape the frontend expects."""
|
||||
out = dict(row)
|
||||
for snake, camel in _SNAKE_TO_CAMEL.items():
|
||||
if snake in out:
|
||||
out[camel] = out.pop(snake)
|
||||
# ``user_id`` is exposed as ``user`` in the legacy API shape.
|
||||
if "user_id" in out:
|
||||
out["user"] = out.pop("user_id")
|
||||
return out
|
||||
|
||||
|
||||
def _api_to_update_fields(data: dict) -> dict:
|
||||
"""Rename incoming camelCase update keys to the repo's snake_case columns."""
|
||||
fields_out: dict = {}
|
||||
for key, value in data.items():
|
||||
fields_out[_CAMEL_TO_SNAKE.get(key, key)] = value
|
||||
return fields_out
|
||||
|
||||
|
||||
def _encrypt_secret_fields(config, config_requirements, user_id):
|
||||
secret_keys = [
|
||||
key for key, spec in config_requirements.items()
|
||||
@@ -169,8 +130,6 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
|
||||
class AvailableTools(Resource):
|
||||
@api.doc(description="Get available tools for a user")
|
||||
def get(self):
|
||||
if not request.decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
try:
|
||||
tools_metadata = []
|
||||
for tool_name, tool_instance in tool_manager.tools.items():
|
||||
@@ -206,11 +165,12 @@ class GetTools(Resource):
|
||||
if not decoded_token:
|
||||
return make_response(jsonify({"success": False}), 401)
|
||||
user = decoded_token.get("sub")
|
||||
with db_readonly() as conn:
|
||||
rows = UserToolsRepository(conn).list_for_user(user)
|
||||
tools = user_tools_collection.find({"user": user})
|
||||
user_tools = []
|
||||
for row in rows:
|
||||
tool_copy = _row_to_api(row)
|
||||
for tool in tools:
|
||||
tool_copy = {**tool}
|
||||
tool_copy["id"] = str(tool["_id"])
|
||||
tool_copy.pop("_id", None)
|
||||
|
||||
config_req = tool_copy.get("configRequirements", {})
|
||||
if not config_req:
|
||||
@@ -276,16 +236,6 @@ class CreateTool(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
if data["name"] == "mcp_tool":
|
||||
server_url = (data.get("config", {}).get("server_url") or "").strip()
|
||||
if server_url:
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||
400,
|
||||
)
|
||||
tool_instance = tool_manager.tools.get(data["name"])
|
||||
if not tool_instance:
|
||||
return make_response(
|
||||
@@ -318,19 +268,19 @@ class CreateTool(Resource):
|
||||
storage_config = _encrypt_secret_fields(
|
||||
data["config"], config_requirements, user
|
||||
)
|
||||
with db_session() as conn:
|
||||
created = UserToolsRepository(conn).create(
|
||||
user,
|
||||
data["name"],
|
||||
config=storage_config,
|
||||
custom_name=data.get("customName", ""),
|
||||
display_name=data["displayName"],
|
||||
description=data["description"],
|
||||
config_requirements=config_requirements,
|
||||
actions=transformed_actions,
|
||||
status=bool(data.get("status", True)),
|
||||
)
|
||||
new_id = str(created["id"])
|
||||
new_tool = {
|
||||
"user": user,
|
||||
"name": data["name"],
|
||||
"displayName": data["displayName"],
|
||||
"description": data["description"],
|
||||
"customName": data.get("customName", ""),
|
||||
"actions": transformed_actions,
|
||||
"config": storage_config,
|
||||
"configRequirements": config_requirements,
|
||||
"status": data["status"],
|
||||
}
|
||||
resp = user_tools_collection.insert_one(new_tool)
|
||||
new_id = str(resp.inserted_id)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -368,10 +318,17 @@ class UpdateTool(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
update_data: dict = {}
|
||||
for key in ("name", "displayName", "customName", "description", "actions"):
|
||||
if key in data:
|
||||
update_data[key] = data[key]
|
||||
update_data = {}
|
||||
if "name" in data:
|
||||
update_data["name"] = data["name"]
|
||||
if "displayName" in data:
|
||||
update_data["displayName"] = data["displayName"]
|
||||
if "customName" in data:
|
||||
update_data["customName"] = data["customName"]
|
||||
if "description" in data:
|
||||
update_data["description"] = data["description"]
|
||||
if "actions" in data:
|
||||
update_data["actions"] = data["actions"]
|
||||
if "config" in data:
|
||||
if "actions" in data["config"]:
|
||||
for action_name in list(data["config"]["actions"].keys()):
|
||||
@@ -386,61 +343,46 @@ class UpdateTool(Resource):
|
||||
),
|
||||
400,
|
||||
)
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
)
|
||||
tool_name = tool_doc.get("name", data.get("name"))
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements()
|
||||
if tool_instance
|
||||
else {}
|
||||
tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
)
|
||||
existing_config = tool_doc.get("config", {}) or {}
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
tool_name = tool_doc.get("name", data.get("name"))
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements() if tool_instance else {}
|
||||
)
|
||||
existing_config = tool_doc.get("config", {})
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
)
|
||||
if validation_errors:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
)
|
||||
|
||||
update_data["config"] = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
)
|
||||
if "status" in data:
|
||||
update_data["status"] = bool(data["status"])
|
||||
repo.update(
|
||||
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
|
||||
)
|
||||
else:
|
||||
if "status" in data:
|
||||
update_data["status"] = bool(data["status"])
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
if validation_errors:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
)
|
||||
repo.update(
|
||||
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
|
||||
)
|
||||
|
||||
update_data["config"] = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
)
|
||||
if "status" in data:
|
||||
update_data["status"] = data["status"]
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": update_data},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -472,50 +414,43 @@ class UpdateToolConfig(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
tool_doc = user_tools_collection.find_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if not tool_doc:
|
||||
return make_response(jsonify({"success": False}), 404)
|
||||
|
||||
tool_name = tool_doc.get("name")
|
||||
if tool_name == "mcp_tool":
|
||||
server_url = (data["config"].get("server_url") or "").strip()
|
||||
if server_url:
|
||||
try:
|
||||
validate_url(server_url)
|
||||
except SSRFError:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||
400,
|
||||
)
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements() if tool_instance else {}
|
||||
tool_name = tool_doc.get("name")
|
||||
tool_instance = tool_manager.tools.get(tool_name)
|
||||
config_requirements = (
|
||||
tool_instance.get_config_requirements() if tool_instance else {}
|
||||
)
|
||||
existing_config = tool_doc.get("config", {})
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
)
|
||||
existing_config = tool_doc.get("config", {}) or {}
|
||||
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||
|
||||
if config_requirements:
|
||||
validation_errors = _validate_config(
|
||||
data["config"], config_requirements,
|
||||
has_existing_secrets=has_existing_secrets,
|
||||
if validation_errors:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
)
|
||||
if validation_errors:
|
||||
return make_response(
|
||||
jsonify({
|
||||
"success": False,
|
||||
"message": "Validation failed",
|
||||
"errors": validation_errors,
|
||||
}),
|
||||
400,
|
||||
)
|
||||
|
||||
final_config = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
)
|
||||
final_config = _merge_secrets_on_update(
|
||||
data["config"], existing_config, config_requirements, user
|
||||
)
|
||||
|
||||
repo.update(str(tool_doc["id"]), user, {"config": final_config})
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"config": final_config}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool config: {err}", exc_info=True
|
||||
@@ -551,17 +486,10 @@ class UpdateToolActions(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
)
|
||||
repo.update(
|
||||
str(tool_doc["id"]), user, {"actions": data["actions"]},
|
||||
)
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"actions": data["actions"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool actions: {err}", exc_info=True
|
||||
@@ -595,17 +523,10 @@ class UpdateToolStatus(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}),
|
||||
404,
|
||||
)
|
||||
repo.update(
|
||||
str(tool_doc["id"]), user, {"status": bool(data["status"])},
|
||||
)
|
||||
user_tools_collection.update_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user},
|
||||
{"$set": {"status": data["status"]}},
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error updating tool status: {err}", exc_info=True
|
||||
@@ -634,14 +555,13 @@ class DeleteTool(Resource):
|
||||
if missing_fields:
|
||||
return missing_fields
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = UserToolsRepository(conn)
|
||||
tool_doc = repo.get_any(data["id"], user)
|
||||
if not tool_doc:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||
)
|
||||
repo.delete(str(tool_doc["id"]), user)
|
||||
result = user_tools_collection.delete_one(
|
||||
{"_id": ObjectId(data["id"]), "user": user}
|
||||
)
|
||||
if result.deleted_count == 0:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
@@ -710,88 +630,70 @@ class GetArtifact(Resource):
|
||||
user_id = decoded_token.get("sub")
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
notes_repo = NotesRepository(conn)
|
||||
todos_repo = TodosRepository(conn)
|
||||
|
||||
# Artifact IDs may be PG UUIDs (post-cutover) or legacy
|
||||
# Mongo ObjectIds embedded in older conversation history.
|
||||
# Both repos' ``get_any`` handles the id-shape branching
|
||||
# internally so a non-UUID input never reaches
|
||||
# ``CAST(:id AS uuid)`` (which would poison the readonly
|
||||
# transaction and break the fallback below).
|
||||
note_doc = notes_repo.get_any(artifact_id, user_id)
|
||||
|
||||
if note_doc:
|
||||
content = note_doc.get("note", "") or note_doc.get("content", "")
|
||||
line_count = len(content.split("\n")) if content else 0
|
||||
updated = note_doc.get("updated_at")
|
||||
artifact = {
|
||||
"artifact_type": "note",
|
||||
"data": {
|
||||
"content": content,
|
||||
"line_count": line_count,
|
||||
"updated_at": (
|
||||
updated.isoformat()
|
||||
if hasattr(updated, "isoformat")
|
||||
else updated
|
||||
),
|
||||
},
|
||||
}
|
||||
return make_response(
|
||||
jsonify({"success": True, "artifact": artifact}), 200
|
||||
)
|
||||
|
||||
todo_doc = todos_repo.get_any(artifact_id, user_id)
|
||||
if todo_doc:
|
||||
tool_id = todo_doc.get("tool_id")
|
||||
all_todos = todos_repo.list_for_tool(user_id, tool_id) if tool_id else []
|
||||
items = []
|
||||
open_count = 0
|
||||
completed_count = 0
|
||||
for t in all_todos:
|
||||
# PG ``todos`` stores a ``completed BOOLEAN`` column;
|
||||
# the legacy Mongo shape used a ``status`` string.
|
||||
# Keep the response shape stable by translating here.
|
||||
status = "completed" if t.get("completed") else "open"
|
||||
if status == "open":
|
||||
open_count += 1
|
||||
else:
|
||||
completed_count += 1
|
||||
created = t.get("created_at")
|
||||
updated = t.get("updated_at")
|
||||
items.append({
|
||||
"todo_id": t.get("todo_id"),
|
||||
"title": t.get("title", ""),
|
||||
"status": status,
|
||||
"created_at": (
|
||||
created.isoformat()
|
||||
if hasattr(created, "isoformat")
|
||||
else created
|
||||
),
|
||||
"updated_at": (
|
||||
updated.isoformat()
|
||||
if hasattr(updated, "isoformat")
|
||||
else updated
|
||||
),
|
||||
})
|
||||
artifact = {
|
||||
"artifact_type": "todo_list",
|
||||
"data": {
|
||||
"items": items,
|
||||
"total_count": len(items),
|
||||
"open_count": open_count,
|
||||
"completed_count": completed_count,
|
||||
},
|
||||
}
|
||||
return make_response(
|
||||
jsonify({"success": True, "artifact": artifact}), 200
|
||||
)
|
||||
except Exception as err:
|
||||
current_app.logger.error(
|
||||
f"Error retrieving artifact: {err}", exc_info=True
|
||||
obj_id = ObjectId(artifact_id)
|
||||
except Exception:
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Invalid artifact ID"}), 400
|
||||
)
|
||||
return make_response(jsonify({"success": False}), 400)
|
||||
|
||||
from application.core.mongo_db import MongoDB
|
||||
from application.core.settings import settings
|
||||
|
||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
||||
|
||||
note_doc = db["notes"].find_one({"_id": obj_id, "user_id": user_id})
|
||||
if note_doc:
|
||||
content = note_doc.get("note", "")
|
||||
line_count = len(content.split("\n")) if content else 0
|
||||
artifact = {
|
||||
"artifact_type": "note",
|
||||
"data": {
|
||||
"content": content,
|
||||
"line_count": line_count,
|
||||
"updated_at": (
|
||||
note_doc["updated_at"].isoformat()
|
||||
if note_doc.get("updated_at")
|
||||
else None
|
||||
),
|
||||
},
|
||||
}
|
||||
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
|
||||
|
||||
todo_doc = db["todos"].find_one({"_id": obj_id, "user_id": user_id})
|
||||
if todo_doc:
|
||||
tool_id = todo_doc.get("tool_id")
|
||||
query = {"user_id": user_id, "tool_id": tool_id}
|
||||
all_todos = list(db["todos"].find(query))
|
||||
items = []
|
||||
open_count = 0
|
||||
completed_count = 0
|
||||
for t in all_todos:
|
||||
status = t.get("status", "open")
|
||||
if status == "open":
|
||||
open_count += 1
|
||||
elif status == "completed":
|
||||
completed_count += 1
|
||||
items.append({
|
||||
"todo_id": t.get("todo_id"),
|
||||
"title": t.get("title", ""),
|
||||
"status": status,
|
||||
"created_at": (
|
||||
t["created_at"].isoformat() if t.get("created_at") else None
|
||||
),
|
||||
"updated_at": (
|
||||
t["updated_at"].isoformat() if t.get("updated_at") else None
|
||||
),
|
||||
})
|
||||
artifact = {
|
||||
"artifact_type": "todo_list",
|
||||
"data": {
|
||||
"items": items,
|
||||
"total_count": len(items),
|
||||
"open_count": open_count,
|
||||
"completed_count": completed_count,
|
||||
},
|
||||
}
|
||||
return make_response(jsonify({"success": True, "artifact": artifact}), 200)
|
||||
|
||||
return make_response(
|
||||
jsonify({"success": False, "message": "Artifact not found"}), 404
|
||||
|
||||
@@ -1,61 +1,290 @@
|
||||
"""Centralized utilities for API routes.
|
||||
|
||||
Post-Mongo-cutover slim: the old Mongo-shaped helpers (``validate_object_id``,
|
||||
``check_resource_ownership``, ``paginated_response``, ``serialize_object_id``,
|
||||
``safe_db_operation``, ``validate_enum``, ``extract_sort_params``) have been
|
||||
removed — they carried ``bson`` / ``pymongo`` imports and had zero callers.
|
||||
"""
|
||||
"""Centralized utilities for API routes."""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from bson.errors import InvalidId
|
||||
from bson.objectid import ObjectId
|
||||
from flask import (
|
||||
Response,
|
||||
current_app,
|
||||
has_app_context,
|
||||
jsonify,
|
||||
make_response,
|
||||
request,
|
||||
)
|
||||
from pymongo.collection import Collection
|
||||
|
||||
|
||||
def get_user_id() -> Optional[str]:
|
||||
"""Extract user ID from decoded JWT token, or None if unauthenticated."""
|
||||
"""
|
||||
Extract user ID from decoded JWT token.
|
||||
|
||||
Returns:
|
||||
User ID string or None if not authenticated
|
||||
"""
|
||||
decoded_token = getattr(request, "decoded_token", None)
|
||||
return decoded_token.get("sub") if decoded_token else None
|
||||
|
||||
|
||||
def require_auth(func: Callable) -> Callable:
|
||||
"""Decorator to require authentication. Returns 401 when absent."""
|
||||
"""
|
||||
Decorator to require authentication for route handlers.
|
||||
|
||||
Usage:
|
||||
@require_auth
|
||||
def get(self):
|
||||
user_id = get_user_id()
|
||||
...
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
user_id = get_user_id()
|
||||
if not user_id:
|
||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||
return error_response("Unauthorized", 401)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def success_response(
|
||||
data=None, message: Optional[str] = None, status: int = 200
|
||||
data: Optional[Dict[str, Any]] = None, status: int = 200
|
||||
) -> Response:
|
||||
"""Shape a successful JSON response."""
|
||||
body = {"success": True}
|
||||
if data is not None:
|
||||
body["data"] = data
|
||||
if message is not None:
|
||||
body["message"] = message
|
||||
return make_response(jsonify(body), status)
|
||||
"""
|
||||
Create a standardized success response.
|
||||
|
||||
Args:
|
||||
data: Optional data dictionary to include in response
|
||||
status: HTTP status code (default: 200)
|
||||
|
||||
Returns:
|
||||
Flask Response object
|
||||
|
||||
Example:
|
||||
return success_response({"users": [...], "total": 10})
|
||||
"""
|
||||
response = {"success": True}
|
||||
if data:
|
||||
response.update(data)
|
||||
return make_response(jsonify(response), status)
|
||||
|
||||
|
||||
def error_response(message: str, status: int = 400, **kwargs) -> Response:
|
||||
"""Shape an error JSON response; any kwargs are merged into the body."""
|
||||
body = {"success": False, "error": message, **kwargs}
|
||||
return make_response(jsonify(body), status)
|
||||
"""
|
||||
Create a standardized error response.
|
||||
|
||||
Args:
|
||||
message: Error message string
|
||||
status: HTTP status code (default: 400)
|
||||
**kwargs: Additional fields to include in response
|
||||
|
||||
Returns:
|
||||
Flask Response object
|
||||
|
||||
Example:
|
||||
return error_response("Resource not found", 404)
|
||||
return error_response("Invalid input", 400, errors=["field1", "field2"])
|
||||
"""
|
||||
response = {"success": False, "message": message}
|
||||
response.update(kwargs)
|
||||
return make_response(jsonify(response), status)
|
||||
|
||||
|
||||
def require_fields(required: list) -> Callable:
|
||||
"""Decorator: return 400 if any listed field is missing/falsy in the JSON body."""
|
||||
def validate_object_id(
|
||||
id_string: str, resource_name: str = "Resource"
|
||||
) -> Tuple[Optional[ObjectId], Optional[Response]]:
|
||||
"""
|
||||
Validate and convert string to ObjectId.
|
||||
|
||||
Args:
|
||||
id_string: String to convert
|
||||
resource_name: Name of resource for error message
|
||||
|
||||
Returns:
|
||||
Tuple of (ObjectId or None, error_response or None)
|
||||
|
||||
Example:
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
try:
|
||||
return ObjectId(id_string), None
|
||||
except (InvalidId, TypeError):
|
||||
return None, error_response(f"Invalid {resource_name} ID format")
|
||||
|
||||
|
||||
def validate_pagination(
|
||||
default_limit: int = 20, max_limit: int = 100
|
||||
) -> Tuple[int, int, Optional[Response]]:
|
||||
"""
|
||||
Extract and validate pagination parameters from request.
|
||||
|
||||
Args:
|
||||
default_limit: Default items per page
|
||||
max_limit: Maximum allowed items per page
|
||||
|
||||
Returns:
|
||||
Tuple of (limit, skip, error_response or None)
|
||||
|
||||
Example:
|
||||
limit, skip, error = validate_pagination()
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
try:
|
||||
limit = min(int(request.args.get("limit", default_limit)), max_limit)
|
||||
skip = int(request.args.get("skip", 0))
|
||||
if limit < 1 or skip < 0:
|
||||
return 0, 0, error_response("Invalid pagination parameters")
|
||||
return limit, skip, None
|
||||
except ValueError:
|
||||
return 0, 0, error_response("Invalid pagination parameters")
|
||||
|
||||
|
||||
def check_resource_ownership(
|
||||
collection: Collection,
|
||||
resource_id: ObjectId,
|
||||
user_id: str,
|
||||
resource_name: str = "Resource",
|
||||
) -> Tuple[Optional[Dict], Optional[Response]]:
|
||||
"""
|
||||
Check if resource exists and belongs to user.
|
||||
|
||||
Args:
|
||||
collection: MongoDB collection
|
||||
resource_id: Resource ObjectId
|
||||
user_id: User ID string
|
||||
resource_name: Name of resource for error messages
|
||||
|
||||
Returns:
|
||||
Tuple of (resource_dict or None, error_response or None)
|
||||
|
||||
Example:
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection,
|
||||
workflow_id,
|
||||
user_id,
|
||||
"Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
resource = collection.find_one({"_id": resource_id, "user": user_id})
|
||||
if not resource:
|
||||
return None, error_response(f"{resource_name} not found", 404)
|
||||
return resource, None
|
||||
|
||||
|
||||
def serialize_object_id(
|
||||
obj: Dict[str, Any], id_field: str = "_id", new_field: str = "id"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert ObjectId to string in a dictionary.
|
||||
|
||||
Args:
|
||||
obj: Dictionary containing ObjectId
|
||||
id_field: Field name containing ObjectId
|
||||
new_field: New field name for string ID
|
||||
|
||||
Returns:
|
||||
Modified dictionary
|
||||
|
||||
Example:
|
||||
user = serialize_object_id(user_doc)
|
||||
# user["id"] = "507f1f77bcf86cd799439011"
|
||||
"""
|
||||
if id_field in obj:
|
||||
obj[new_field] = str(obj[id_field])
|
||||
if id_field != new_field:
|
||||
obj.pop(id_field, None)
|
||||
return obj
|
||||
|
||||
|
||||
def serialize_list(items: List[Dict], serializer: Callable[[Dict], Dict]) -> List[Dict]:
|
||||
"""
|
||||
Apply serializer function to list of items.
|
||||
|
||||
Args:
|
||||
items: List of dictionaries
|
||||
serializer: Function to apply to each item
|
||||
|
||||
Returns:
|
||||
List of serialized items
|
||||
|
||||
Example:
|
||||
workflows = serialize_list(workflow_docs, serialize_workflow)
|
||||
"""
|
||||
return [serializer(item) for item in items]
|
||||
|
||||
|
||||
def paginated_response(
|
||||
collection: Collection,
|
||||
query: Dict[str, Any],
|
||||
serializer: Callable[[Dict], Dict],
|
||||
limit: int,
|
||||
skip: int,
|
||||
sort_field: str = "created_at",
|
||||
sort_order: int = -1,
|
||||
response_key: str = "items",
|
||||
) -> Response:
|
||||
"""
|
||||
Create paginated response for collection query.
|
||||
|
||||
Args:
|
||||
collection: MongoDB collection
|
||||
query: Query dictionary
|
||||
serializer: Function to serialize each item
|
||||
limit: Items per page
|
||||
skip: Number of items to skip
|
||||
sort_field: Field to sort by
|
||||
sort_order: Sort order (1=asc, -1=desc)
|
||||
response_key: Key name for items in response
|
||||
|
||||
Returns:
|
||||
Flask Response with paginated data
|
||||
|
||||
Example:
|
||||
return paginated_response(
|
||||
workflows_collection,
|
||||
{"user": user_id},
|
||||
serialize_workflow,
|
||||
limit, skip,
|
||||
response_key="workflows"
|
||||
)
|
||||
"""
|
||||
items = list(
|
||||
collection.find(query).sort(sort_field, sort_order).skip(skip).limit(limit)
|
||||
)
|
||||
total = collection.count_documents(query)
|
||||
|
||||
return success_response(
|
||||
{
|
||||
response_key: serialize_list(items, serializer),
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"skip": skip,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def require_fields(required: List[str]) -> Callable:
|
||||
"""
|
||||
Decorator to validate required fields in request JSON.
|
||||
|
||||
Args:
|
||||
required: List of required field names
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
@require_fields(["name", "description"])
|
||||
def post(self):
|
||||
data = request.get_json()
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
@@ -65,11 +294,94 @@ def require_fields(required: list) -> Callable:
|
||||
return error_response("Request body required")
|
||||
missing = [field for field in required if not data.get(field)]
|
||||
if missing:
|
||||
return error_response(
|
||||
f"Missing required fields: {', '.join(missing)}"
|
||||
)
|
||||
return error_response(f"Missing required fields: {', '.join(missing)}")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def safe_db_operation(
|
||||
operation: Callable, error_message: str = "Database operation failed"
|
||||
) -> Tuple[Any, Optional[Response]]:
|
||||
"""
|
||||
Safely execute database operation with error handling.
|
||||
|
||||
Args:
|
||||
operation: Function to execute
|
||||
error_message: Error message if operation fails
|
||||
|
||||
Returns:
|
||||
Tuple of (result or None, error_response or None)
|
||||
|
||||
Example:
|
||||
result, error = safe_db_operation(
|
||||
lambda: collection.insert_one(doc),
|
||||
"Failed to create resource"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
try:
|
||||
result = operation()
|
||||
return result, None
|
||||
except Exception as err:
|
||||
if has_app_context():
|
||||
current_app.logger.error(f"{error_message}: {err}", exc_info=True)
|
||||
return None, error_response(error_message)
|
||||
|
||||
|
||||
def validate_enum(
|
||||
value: Any, allowed: List[Any], field_name: str
|
||||
) -> Optional[Response]:
|
||||
"""
|
||||
Validate that value is in allowed list.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
allowed: List of allowed values
|
||||
field_name: Field name for error message
|
||||
|
||||
Returns:
|
||||
error_response if invalid, None if valid
|
||||
|
||||
Example:
|
||||
error = validate_enum(status, ["draft", "published"], "status")
|
||||
if error:
|
||||
return error
|
||||
"""
|
||||
if value not in allowed:
|
||||
allowed_str = ", ".join(f"'{v}'" for v in allowed)
|
||||
return error_response(f"Invalid {field_name}. Must be one of: {allowed_str}")
|
||||
return None
|
||||
|
||||
|
||||
def extract_sort_params(
|
||||
default_field: str = "created_at",
|
||||
default_order: str = "desc",
|
||||
allowed_fields: Optional[List[str]] = None,
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
Extract and validate sort parameters from request.
|
||||
|
||||
Args:
|
||||
default_field: Default sort field
|
||||
default_order: Default sort order ("asc" or "desc")
|
||||
allowed_fields: List of allowed sort fields (None = no validation)
|
||||
|
||||
Returns:
|
||||
Tuple of (sort_field, sort_order)
|
||||
|
||||
Example:
|
||||
sort_field, sort_order = extract_sort_params(
|
||||
allowed_fields=["name", "date", "status"]
|
||||
)
|
||||
"""
|
||||
sort_field = request.args.get("sort", default_field)
|
||||
sort_order_str = request.args.get("order", default_order).lower()
|
||||
|
||||
if allowed_fields and sort_field not in allowed_fields:
|
||||
sort_field = default_field
|
||||
sort_order = -1 if sort_order_str == "desc" else 1
|
||||
return sort_field, sort_order
|
||||
|
||||
@@ -1,26 +1,30 @@
|
||||
"""Workflow management routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_restx import Namespace, Resource
|
||||
|
||||
from application.storage.db.base_repository import looks_like_uuid
|
||||
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||
from application.storage.db.session import db_readonly, db_session
|
||||
from application.api.user.base import (
|
||||
workflow_edges_collection,
|
||||
workflow_nodes_collection,
|
||||
workflows_collection,
|
||||
)
|
||||
from application.core.json_schema_utils import (
|
||||
JsonSchemaValidationError,
|
||||
normalize_json_schema_payload,
|
||||
)
|
||||
from application.core.model_utils import get_model_capabilities
|
||||
from application.api.user.utils import (
|
||||
check_resource_ownership,
|
||||
error_response,
|
||||
get_user_id,
|
||||
require_auth,
|
||||
require_fields,
|
||||
safe_db_operation,
|
||||
success_response,
|
||||
validate_object_id,
|
||||
)
|
||||
|
||||
workflows_ns = Namespace("workflows", path="/api")
|
||||
@@ -31,112 +35,33 @@ def _workflow_error_response(message: str, err: Exception):
|
||||
return error_response(message)
|
||||
|
||||
|
||||
def _resolve_workflow(repo: WorkflowsRepository, workflow_id: str, user_id: str):
|
||||
"""Resolve a workflow by UUID or legacy Mongo id, scoped to user."""
|
||||
if not workflow_id:
|
||||
return None
|
||||
if looks_like_uuid(workflow_id):
|
||||
row = repo.get(workflow_id, user_id)
|
||||
if row is not None:
|
||||
return row
|
||||
return repo.get_by_legacy_id(workflow_id, user_id)
|
||||
|
||||
|
||||
def _write_graph(
|
||||
conn,
|
||||
pg_workflow_id: str,
|
||||
graph_version: int,
|
||||
nodes_data: List[Dict],
|
||||
edges_data: List[Dict],
|
||||
) -> List[Dict]:
|
||||
"""Bulk-create nodes + edges for one graph version. Uses ON CONFLICT upsert.
|
||||
|
||||
Edges arrive with source/target as user-provided node-id strings. We
|
||||
insert nodes first, capture their ``node_id → UUID`` map, then
|
||||
translate edges before insertion. Edges referencing missing nodes are
|
||||
dropped with a warning.
|
||||
"""
|
||||
nodes_repo = WorkflowNodesRepository(conn)
|
||||
edges_repo = WorkflowEdgesRepository(conn)
|
||||
|
||||
if nodes_data:
|
||||
created_nodes = nodes_repo.bulk_create(
|
||||
pg_workflow_id, graph_version,
|
||||
[
|
||||
{
|
||||
"node_id": n["id"],
|
||||
"node_type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
}
|
||||
for n in nodes_data
|
||||
],
|
||||
)
|
||||
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
|
||||
else:
|
||||
created_nodes = []
|
||||
node_uuid_by_str = {}
|
||||
|
||||
if edges_data:
|
||||
translated_edges: List[Dict] = []
|
||||
for e in edges_data:
|
||||
src = e.get("source")
|
||||
tgt = e.get("target")
|
||||
from_uuid = node_uuid_by_str.get(src)
|
||||
to_uuid = node_uuid_by_str.get(tgt)
|
||||
if not from_uuid or not to_uuid:
|
||||
current_app.logger.warning(
|
||||
"Workflow graph write: dropping edge %s; node refs unresolved "
|
||||
"(source=%s, target=%s)",
|
||||
e.get("id"), src, tgt,
|
||||
)
|
||||
continue
|
||||
translated_edges.append({
|
||||
"edge_id": e["id"],
|
||||
"from_node_id": from_uuid,
|
||||
"to_node_id": to_uuid,
|
||||
"source_handle": e.get("sourceHandle"),
|
||||
"target_handle": e.get("targetHandle"),
|
||||
})
|
||||
if translated_edges:
|
||||
edges_repo.bulk_create(
|
||||
pg_workflow_id, graph_version, translated_edges,
|
||||
)
|
||||
|
||||
return created_nodes
|
||||
|
||||
|
||||
def serialize_workflow(w: Dict) -> Dict:
|
||||
"""Serialize workflow row to API response format."""
|
||||
created_at = w.get("created_at")
|
||||
updated_at = w.get("updated_at")
|
||||
"""Serialize workflow document to API response format."""
|
||||
return {
|
||||
"id": str(w["id"]),
|
||||
"id": str(w["_id"]),
|
||||
"name": w.get("name"),
|
||||
"description": w.get("description"),
|
||||
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
|
||||
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
|
||||
"created_at": w["created_at"].isoformat() if w.get("created_at") else None,
|
||||
"updated_at": w["updated_at"].isoformat() if w.get("updated_at") else None,
|
||||
}
|
||||
|
||||
|
||||
def serialize_node(n: Dict) -> Dict:
|
||||
"""Serialize workflow node row to API response format."""
|
||||
"""Serialize workflow node document to API response format."""
|
||||
return {
|
||||
"id": n["node_id"],
|
||||
"type": n["node_type"],
|
||||
"id": n["id"],
|
||||
"type": n["type"],
|
||||
"title": n.get("title"),
|
||||
"description": n.get("description"),
|
||||
"position": n.get("position"),
|
||||
"data": n.get("config", {}) or {},
|
||||
"data": n.get("config", {}),
|
||||
}
|
||||
|
||||
|
||||
def serialize_edge(e: Dict) -> Dict:
|
||||
"""Serialize workflow edge row to API response format."""
|
||||
"""Serialize workflow edge document to API response format."""
|
||||
return {
|
||||
"id": e["edge_id"],
|
||||
"id": e["id"],
|
||||
"source": e.get("source_id"),
|
||||
"target": e.get("target_id"),
|
||||
"sourceHandle": e.get("source_handle"),
|
||||
@@ -145,7 +70,7 @@ def serialize_edge(e: Dict) -> Dict:
|
||||
|
||||
|
||||
def get_workflow_graph_version(workflow: Dict) -> int:
|
||||
"""Get current graph version with fallback."""
|
||||
"""Get current graph version with legacy fallback."""
|
||||
raw_version = workflow.get("current_graph_version", 1)
|
||||
try:
|
||||
version = int(raw_version)
|
||||
@@ -154,6 +79,22 @@ def get_workflow_graph_version(workflow: Dict) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def fetch_graph_documents(collection, workflow_id: str, graph_version: int) -> List[Dict]:
|
||||
"""Fetch graph docs for active version, with fallback for legacy unversioned data."""
|
||||
docs = list(
|
||||
collection.find({"workflow_id": workflow_id, "graph_version": graph_version})
|
||||
)
|
||||
if docs:
|
||||
return docs
|
||||
if graph_version == 1:
|
||||
return list(
|
||||
collection.find(
|
||||
{"workflow_id": workflow_id, "graph_version": {"$exists": False}}
|
||||
)
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
def validate_json_schema_payload(
|
||||
json_schema: Any,
|
||||
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
@@ -374,6 +315,49 @@ def _can_reach_end(
|
||||
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
|
||||
|
||||
|
||||
def create_workflow_nodes(
|
||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
"""Insert workflow nodes into database."""
|
||||
if nodes_data:
|
||||
workflow_nodes_collection.insert_many(
|
||||
[
|
||||
{
|
||||
"id": n["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"type": n["type"],
|
||||
"title": n.get("title", ""),
|
||||
"description": n.get("description", ""),
|
||||
"position": n.get("position", {"x": 0, "y": 0}),
|
||||
"config": n.get("data", {}),
|
||||
}
|
||||
for n in nodes_data
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def create_workflow_edges(
|
||||
workflow_id: str, edges_data: List[Dict], graph_version: int
|
||||
) -> None:
|
||||
"""Insert workflow edges into database."""
|
||||
if edges_data:
|
||||
workflow_edges_collection.insert_many(
|
||||
[
|
||||
{
|
||||
"id": e["id"],
|
||||
"workflow_id": workflow_id,
|
||||
"graph_version": graph_version,
|
||||
"source_id": e.get("source"),
|
||||
"target_id": e.get("target"),
|
||||
"source_handle": e.get("sourceHandle"),
|
||||
"target_handle": e.get("targetHandle"),
|
||||
}
|
||||
for e in edges_data
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows")
|
||||
class WorkflowList(Resource):
|
||||
|
||||
@@ -385,7 +369,6 @@ class WorkflowList(Resource):
|
||||
data = request.get_json()
|
||||
|
||||
name = data.get("name", "").strip()
|
||||
description = data.get("description", "")
|
||||
nodes_data = data.get("nodes", [])
|
||||
edges_data = data.get("edges", [])
|
||||
|
||||
@@ -396,16 +379,35 @@ class WorkflowList(Resource):
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = repo.create(user_id, name, description=description)
|
||||
pg_workflow_id = str(workflow["id"])
|
||||
_write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to create workflow", err)
|
||||
now = datetime.now(timezone.utc)
|
||||
workflow_doc = {
|
||||
"name": name,
|
||||
"description": data.get("description", ""),
|
||||
"user": user_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"current_graph_version": 1,
|
||||
}
|
||||
|
||||
return success_response({"id": pg_workflow_id}, 201)
|
||||
result, error = safe_db_operation(
|
||||
lambda: workflows_collection.insert_one(workflow_doc),
|
||||
"Failed to create workflow",
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow_id = str(result.inserted_id)
|
||||
|
||||
try:
|
||||
create_workflow_nodes(workflow_id, nodes_data, 1)
|
||||
create_workflow_edges(workflow_id, edges_data, 1)
|
||||
except Exception as err:
|
||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflows_collection.delete_one({"_id": result.inserted_id})
|
||||
return _workflow_error_response("Failed to create workflow structure", err)
|
||||
|
||||
return success_response({"id": workflow_id}, 201)
|
||||
|
||||
|
||||
@workflows_ns.route("/workflows/<string:workflow_id>")
|
||||
@@ -415,22 +417,23 @@ class WorkflowDetail(Resource):
|
||||
def get(self, workflow_id: str):
|
||||
"""Get workflow details with nodes and edges."""
|
||||
user_id = get_user_id()
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||
if workflow is None:
|
||||
return error_response("Workflow not found", 404)
|
||||
pg_workflow_id = str(workflow["id"])
|
||||
graph_version = get_workflow_graph_version(workflow)
|
||||
nodes = WorkflowNodesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
edges = WorkflowEdgesRepository(conn).find_by_version(
|
||||
pg_workflow_id, graph_version,
|
||||
)
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to fetch workflow", err)
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection, obj_id, user_id, "Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
graph_version = get_workflow_graph_version(workflow)
|
||||
nodes = fetch_graph_documents(
|
||||
workflow_nodes_collection, workflow_id, graph_version
|
||||
)
|
||||
edges = fetch_graph_documents(
|
||||
workflow_edges_collection, workflow_id, graph_version
|
||||
)
|
||||
|
||||
return success_response(
|
||||
{
|
||||
@@ -445,9 +448,18 @@ class WorkflowDetail(Resource):
|
||||
def put(self, workflow_id: str):
|
||||
"""Update workflow and replace nodes/edges."""
|
||||
user_id = get_user_id()
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection, obj_id, user_id, "Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
data = request.get_json()
|
||||
name = data.get("name", "").strip()
|
||||
description = data.get("description", "")
|
||||
nodes_data = data.get("nodes", [])
|
||||
edges_data = data.get("edges", [])
|
||||
|
||||
@@ -458,36 +470,55 @@ class WorkflowDetail(Resource):
|
||||
)
|
||||
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||
if workflow is None:
|
||||
return error_response("Workflow not found", 404)
|
||||
pg_workflow_id = str(workflow["id"])
|
||||
current_graph_version = get_workflow_graph_version(workflow)
|
||||
next_graph_version = current_graph_version + 1
|
||||
|
||||
_write_graph(
|
||||
conn, pg_workflow_id, next_graph_version,
|
||||
nodes_data, edges_data,
|
||||
)
|
||||
repo.update(
|
||||
pg_workflow_id, user_id,
|
||||
{
|
||||
"name": name,
|
||||
"description": description,
|
||||
"current_graph_version": next_graph_version,
|
||||
},
|
||||
)
|
||||
WorkflowNodesRepository(conn).delete_other_versions(
|
||||
pg_workflow_id, next_graph_version,
|
||||
)
|
||||
WorkflowEdgesRepository(conn).delete_other_versions(
|
||||
pg_workflow_id, next_graph_version,
|
||||
)
|
||||
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
|
||||
create_workflow_edges(workflow_id, edges_data, next_graph_version)
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to update workflow", err)
|
||||
workflow_nodes_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
workflow_edges_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
return _workflow_error_response("Failed to update workflow structure", err)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
_, error = safe_db_operation(
|
||||
lambda: workflows_collection.update_one(
|
||||
{"_id": obj_id},
|
||||
{
|
||||
"$set": {
|
||||
"name": name,
|
||||
"description": data.get("description", ""),
|
||||
"updated_at": now,
|
||||
"current_graph_version": next_graph_version,
|
||||
}
|
||||
},
|
||||
),
|
||||
"Failed to update workflow",
|
||||
)
|
||||
if error:
|
||||
workflow_nodes_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
workflow_edges_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": next_graph_version}
|
||||
)
|
||||
return error
|
||||
|
||||
try:
|
||||
workflow_nodes_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
|
||||
)
|
||||
workflow_edges_collection.delete_many(
|
||||
{"workflow_id": workflow_id, "graph_version": {"$ne": next_graph_version}}
|
||||
)
|
||||
except Exception as cleanup_err:
|
||||
current_app.logger.warning(
|
||||
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
|
||||
)
|
||||
|
||||
return success_response()
|
||||
|
||||
@@ -495,14 +526,20 @@ class WorkflowDetail(Resource):
|
||||
def delete(self, workflow_id: str):
|
||||
"""Delete workflow and its graph."""
|
||||
user_id = get_user_id()
|
||||
obj_id, error = validate_object_id(workflow_id, "Workflow")
|
||||
if error:
|
||||
return error
|
||||
|
||||
workflow, error = check_resource_ownership(
|
||||
workflows_collection, obj_id, user_id, "Workflow"
|
||||
)
|
||||
if error:
|
||||
return error
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
repo = WorkflowsRepository(conn)
|
||||
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||
if workflow is None:
|
||||
return error_response("Workflow not found", 404)
|
||||
# ON DELETE CASCADE on workflow_nodes/edges cleans children.
|
||||
repo.delete(str(workflow["id"]), user_id)
|
||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflow_edges_collection.delete_many({"workflow_id": workflow_id})
|
||||
workflows_collection.delete_one({"_id": workflow["_id"], "user": user_id})
|
||||
except Exception as err:
|
||||
return _workflow_error_response("Failed to delete workflow", err)
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from application.api.v1.routes import v1_bp
|
||||
|
||||
__all__ = ["v1_bp"]
|
||||
@@ -1,331 +0,0 @@
|
||||
"""Standard chat completions API routes.
|
||||
|
||||
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
|
||||
follow the widely-adopted chat completions protocol so external tools
|
||||
(opencode, continue, etc.) can connect to DocsGPT agents.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from flask import Blueprint, jsonify, make_response, request, Response
|
||||
|
||||
from application.api.answer.routes.base import BaseAnswerResource
|
||||
from application.api.answer.services.stream_processor import StreamProcessor
|
||||
from application.api.v1.translator import (
|
||||
translate_request,
|
||||
translate_response,
|
||||
translate_stream_event,
|
||||
)
|
||||
from application.storage.db.repositories.agents import AgentsRepository
|
||||
from application.storage.db.session import db_readonly
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
|
||||
|
||||
|
||||
def _extract_bearer_token() -> Optional[str]:
|
||||
"""Extract API key from Authorization: Bearer header."""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if auth.startswith("Bearer "):
|
||||
return auth[7:].strip()
|
||||
return None
|
||||
|
||||
|
||||
def _lookup_agent(api_key: str) -> Optional[Dict]:
|
||||
"""Look up the agent document for this API key."""
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
return AgentsRepository(conn).find_by_key(api_key)
|
||||
except Exception:
|
||||
logger.warning("Failed to look up agent for API key", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
|
||||
"""Return agent name for display as model name."""
|
||||
if agent:
|
||||
return agent.get("name", api_key)
|
||||
return api_key
|
||||
|
||||
|
||||
class _V1AnswerHelper(BaseAnswerResource):
|
||||
"""Thin wrapper to access complete_stream / process_response_stream."""
|
||||
pass
|
||||
|
||||
|
||||
@v1_bp.route("/chat/completions", methods=["POST"])
|
||||
def chat_completions():
|
||||
"""Handle POST /v1/chat/completions."""
|
||||
api_key = _extract_bearer_token()
|
||||
if not api_key:
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
|
||||
401,
|
||||
)
|
||||
|
||||
data = request.get_json()
|
||||
if not data or not data.get("messages"):
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
|
||||
400,
|
||||
)
|
||||
|
||||
is_stream = data.get("stream", False)
|
||||
agent_doc = _lookup_agent(api_key)
|
||||
model_name = _get_model_name(agent_doc, api_key)
|
||||
|
||||
try:
|
||||
internal_data = translate_request(data, api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
|
||||
400,
|
||||
)
|
||||
|
||||
# Link decoded_token to the agent's owner so continuation state,
|
||||
# logs, and tool execution use the correct user identity. The PG
|
||||
# ``agents`` row exposes the owner via ``user_id`` (``user`` is the
|
||||
# legacy Mongo field name kept in ``row_to_dict`` only for the
|
||||
# mapping ``id``/``_id``).
|
||||
agent_user = (
|
||||
(agent_doc.get("user_id") or agent_doc.get("user"))
|
||||
if agent_doc else None
|
||||
)
|
||||
decoded_token = {"sub": agent_user or "api_key_user"}
|
||||
|
||||
try:
|
||||
processor = StreamProcessor(internal_data, decoded_token)
|
||||
|
||||
if internal_data.get("tool_actions"):
|
||||
# Continuation mode
|
||||
conversation_id = internal_data.get("conversation_id")
|
||||
if not conversation_id:
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
|
||||
400,
|
||||
)
|
||||
(
|
||||
agent,
|
||||
messages,
|
||||
tools_dict,
|
||||
pending_tool_calls,
|
||||
tool_actions,
|
||||
) = processor.resume_from_tool_actions(
|
||||
internal_data["tool_actions"], conversation_id
|
||||
)
|
||||
continuation = {
|
||||
"messages": messages,
|
||||
"tools_dict": tools_dict,
|
||||
"pending_tool_calls": pending_tool_calls,
|
||||
"tool_actions": tool_actions,
|
||||
}
|
||||
question = ""
|
||||
else:
|
||||
# Normal mode
|
||||
question = internal_data.get("question", "")
|
||||
agent = processor.build_agent(question)
|
||||
continuation = None
|
||||
|
||||
if not processor.decoded_token:
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
|
||||
401,
|
||||
)
|
||||
|
||||
helper = _V1AnswerHelper()
|
||||
usage_error = helper.check_usage(processor.agent_config)
|
||||
if usage_error:
|
||||
return usage_error
|
||||
|
||||
should_save_conversation = bool(internal_data.get("save_conversation", False))
|
||||
|
||||
if is_stream:
|
||||
return Response(
|
||||
_stream_response(
|
||||
helper,
|
||||
question,
|
||||
agent,
|
||||
processor,
|
||||
model_name,
|
||||
continuation,
|
||||
should_save_conversation,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
else:
|
||||
return _non_stream_response(
|
||||
helper,
|
||||
question,
|
||||
agent,
|
||||
processor,
|
||||
model_name,
|
||||
continuation,
|
||||
should_save_conversation,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
|
||||
extra={"error": str(e)},
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
|
||||
400,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
|
||||
extra={"error": str(e)},
|
||||
)
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
|
||||
500,
|
||||
)
|
||||
|
||||
|
||||
def _stream_response(
|
||||
helper: _V1AnswerHelper,
|
||||
question: str,
|
||||
agent: Any,
|
||||
processor: StreamProcessor,
|
||||
model_name: str,
|
||||
continuation: Optional[Dict],
|
||||
should_save_conversation: bool,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Generate translated SSE chunks for streaming response."""
|
||||
completion_id = f"chatcmpl-{int(time.time())}"
|
||||
|
||||
internal_stream = helper.complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
|
||||
for line in internal_stream:
|
||||
if not line.strip():
|
||||
continue
|
||||
# Parse the internal SSE event
|
||||
event_str = line.replace("data: ", "").strip()
|
||||
try:
|
||||
event_data = json.loads(event_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
|
||||
# Update completion_id when we get the conversation id
|
||||
if event_data.get("type") == "id":
|
||||
conv_id = event_data.get("id", "")
|
||||
if conv_id:
|
||||
completion_id = f"chatcmpl-{conv_id}"
|
||||
|
||||
# Translate to standard format
|
||||
translated = translate_stream_event(event_data, completion_id, model_name)
|
||||
for chunk in translated:
|
||||
yield chunk
|
||||
|
||||
|
||||
def _non_stream_response(
|
||||
helper: _V1AnswerHelper,
|
||||
question: str,
|
||||
agent: Any,
|
||||
processor: StreamProcessor,
|
||||
model_name: str,
|
||||
continuation: Optional[Dict],
|
||||
should_save_conversation: bool,
|
||||
) -> Response:
|
||||
"""Collect full response and return as single JSON."""
|
||||
stream = helper.complete_stream(
|
||||
question=question,
|
||||
agent=agent,
|
||||
conversation_id=processor.conversation_id,
|
||||
user_api_key=processor.agent_config.get("user_api_key"),
|
||||
decoded_token=processor.decoded_token,
|
||||
agent_id=processor.agent_id,
|
||||
model_id=processor.model_id,
|
||||
should_save_conversation=should_save_conversation,
|
||||
_continuation=continuation,
|
||||
)
|
||||
|
||||
result = helper.process_response_stream(stream)
|
||||
|
||||
if result["error"]:
|
||||
return make_response(
|
||||
jsonify({"error": {"message": result["error"], "type": "server_error"}}),
|
||||
500,
|
||||
)
|
||||
|
||||
extra = result.get("extra")
|
||||
pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
|
||||
|
||||
response = translate_response(
|
||||
conversation_id=result["conversation_id"],
|
||||
answer=result["answer"] or "",
|
||||
sources=result["sources"],
|
||||
tool_calls=result["tool_calls"],
|
||||
thought=result["thought"] or "",
|
||||
model_name=model_name,
|
||||
pending_tool_calls=pending,
|
||||
)
|
||||
return make_response(jsonify(response), 200)
|
||||
|
||||
|
||||
@v1_bp.route("/models", methods=["GET"])
|
||||
def list_models():
|
||||
"""Handle GET /v1/models — return agents as models."""
|
||||
api_key = _extract_bearer_token()
|
||||
if not api_key:
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
|
||||
401,
|
||||
)
|
||||
|
||||
try:
|
||||
with db_readonly() as conn:
|
||||
agents_repo = AgentsRepository(conn)
|
||||
agent = agents_repo.find_by_key(api_key)
|
||||
if not agent:
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
|
||||
401,
|
||||
)
|
||||
|
||||
created = agent.get("created_at") or agent.get("createdAt")
|
||||
created_ts = (
|
||||
int(created.timestamp()) if hasattr(created, "timestamp")
|
||||
else int(time.time())
|
||||
)
|
||||
model_id = str(agent.get("id") or agent.get("_id") or "")
|
||||
model = {
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": created_ts,
|
||||
"owned_by": "docsgpt",
|
||||
"name": agent.get("name", ""),
|
||||
"description": agent.get("description", ""),
|
||||
}
|
||||
|
||||
return make_response(
|
||||
jsonify({"object": "list", "data": [model]}),
|
||||
200,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"/v1/models error: {e}", exc_info=True)
|
||||
return make_response(
|
||||
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
|
||||
500,
|
||||
)
|
||||
@@ -1,433 +0,0 @@
|
||||
"""Translate between standard chat completions format and DocsGPT internals.
|
||||
|
||||
This module handles:
|
||||
- Request translation (chat completions -> DocsGPT internal format)
|
||||
- Response translation (DocsGPT response -> chat completions format)
|
||||
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
def _get_client_tool_name(tc: Dict) -> str:
|
||||
"""Return the original tool name for client-facing responses.
|
||||
|
||||
For client-side tools the ``tool_name`` field carries the name the
|
||||
client originally registered. Fall back to ``action_name`` (which
|
||||
is now the clean LLM-visible name) or ``name``.
|
||||
"""
|
||||
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request translation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_continuation(messages: List[Dict]) -> bool:
|
||||
"""Check if messages represent a tool-call continuation.
|
||||
|
||||
A continuation is detected when the last message(s) have ``role: "tool"``
|
||||
immediately after an assistant message with ``tool_calls``.
|
||||
"""
|
||||
if not messages:
|
||||
return False
|
||||
# Walk backwards: if we see tool messages before hitting a non-tool, non-assistant message
|
||||
# and there's an assistant message with tool_calls, it's a continuation.
|
||||
i = len(messages) - 1
|
||||
while i >= 0 and messages[i].get("role") == "tool":
|
||||
i -= 1
|
||||
if i < 0:
|
||||
return False
|
||||
return (
|
||||
messages[i].get("role") == "assistant"
|
||||
and bool(messages[i].get("tool_calls"))
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
|
||||
"""Extract tool results from trailing tool messages for continuation.
|
||||
|
||||
Returns a list of ``tool_actions`` dicts with ``call_id`` and ``result``.
|
||||
"""
|
||||
results = []
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") != "tool":
|
||||
break
|
||||
call_id = msg.get("tool_call_id", "")
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
results.append({"call_id": call_id, "result": content})
|
||||
results.reverse()
|
||||
return results
|
||||
|
||||
|
||||
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
|
||||
"""Try to extract conversation_id from the assistant message before tool results.
|
||||
|
||||
The conversation_id may be stored in a custom field on the assistant message
|
||||
from a previous response cycle.
|
||||
"""
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "assistant":
|
||||
# Check docsgpt extension
|
||||
return msg.get("docsgpt", {}).get("conversation_id")
|
||||
return None
|
||||
|
||||
|
||||
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
|
||||
"""Extract the first system message content from the messages array.
|
||||
|
||||
Returns None if no system message is present.
|
||||
"""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
return msg.get("content", "")
|
||||
return None
|
||||
|
||||
|
||||
def convert_history(messages: List[Dict]) -> List[Dict]:
|
||||
"""Convert chat completions messages array to DocsGPT history format.
|
||||
|
||||
DocsGPT history is a list of ``{prompt, response}`` dicts.
|
||||
Excludes the last user message (that becomes the ``question``).
|
||||
"""
|
||||
history = []
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
if msg.get("role") == "system":
|
||||
i += 1
|
||||
continue
|
||||
if msg.get("role") == "user":
|
||||
# Look ahead for assistant response
|
||||
if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
|
||||
content = messages[i + 1].get("content") or ""
|
||||
history.append({
|
||||
"prompt": msg.get("content", ""),
|
||||
"response": content,
|
||||
})
|
||||
i += 2
|
||||
continue
|
||||
# Last user message without response — skip (it's the question)
|
||||
i += 1
|
||||
continue
|
||||
i += 1
|
||||
return history
|
||||
|
||||
|
||||
def translate_request(
|
||||
data: Dict[str, Any], api_key: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Translate a chat completions request to DocsGPT internal format.
|
||||
|
||||
Args:
|
||||
data: The incoming request body.
|
||||
api_key: Agent API key from the Authorization header.
|
||||
|
||||
Returns:
|
||||
Dict suitable for passing to ``StreamProcessor``.
|
||||
"""
|
||||
messages = data.get("messages", [])
|
||||
|
||||
# Check for continuation (tool results after assistant tool_calls)
|
||||
if is_continuation(messages):
|
||||
tool_actions = extract_tool_results(messages)
|
||||
conversation_id = extract_conversation_id(messages)
|
||||
if not conversation_id:
|
||||
conversation_id = data.get("conversation_id")
|
||||
result = {
|
||||
"conversation_id": conversation_id,
|
||||
"tool_actions": tool_actions,
|
||||
"api_key": api_key,
|
||||
}
|
||||
# Carry tools forward for next iteration
|
||||
if data.get("tools"):
|
||||
result["client_tools"] = data["tools"]
|
||||
return result
|
||||
|
||||
# Normal request — extract question from last user message
|
||||
question = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
question = msg.get("content", "")
|
||||
break
|
||||
|
||||
history = convert_history(messages)
|
||||
system_prompt_override = extract_system_prompt(messages)
|
||||
|
||||
docsgpt = data.get("docsgpt", {})
|
||||
|
||||
result = {
|
||||
"question": question,
|
||||
"api_key": api_key,
|
||||
"history": json.dumps(history),
|
||||
# Conversations are NOT persisted by default on the v1 endpoint.
|
||||
# Callers opt in via ``docsgpt.save_conversation: true``.
|
||||
"save_conversation": bool(docsgpt.get("save_conversation", False)),
|
||||
}
|
||||
|
||||
if system_prompt_override is not None:
|
||||
result["system_prompt_override"] = system_prompt_override
|
||||
|
||||
# Client tools
|
||||
if data.get("tools"):
|
||||
result["client_tools"] = data["tools"]
|
||||
|
||||
# DocsGPT extensions
|
||||
if docsgpt.get("attachments"):
|
||||
result["attachments"] = docsgpt["attachments"]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response translation (non-streaming)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def translate_response(
|
||||
conversation_id: str,
|
||||
answer: str,
|
||||
sources: Optional[List[Dict]],
|
||||
tool_calls: Optional[List[Dict]],
|
||||
thought: str,
|
||||
model_name: str,
|
||||
pending_tool_calls: Optional[List[Dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Translate DocsGPT response to chat completions format.
|
||||
|
||||
Args:
|
||||
conversation_id: The DocsGPT conversation ID.
|
||||
answer: The assistant's text response.
|
||||
sources: RAG retrieval sources.
|
||||
tool_calls: Completed tool call results.
|
||||
thought: Reasoning/thinking tokens.
|
||||
model_name: Model/agent identifier.
|
||||
pending_tool_calls: Pending client-side tool calls (if paused).
|
||||
|
||||
Returns:
|
||||
Dict in the standard chat completions response format.
|
||||
"""
|
||||
created = int(time.time())
|
||||
completion_id = f"chatcmpl-{conversation_id}" if conversation_id else f"chatcmpl-{created}"
|
||||
|
||||
# Build message
|
||||
message: Dict[str, Any] = {"role": "assistant"}
|
||||
|
||||
if pending_tool_calls:
|
||||
# Tool calls pending — return them for client execution
|
||||
message["content"] = None
|
||||
message["tool_calls"] = [
|
||||
{
|
||||
"id": tc.get("call_id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": _get_client_tool_name(tc),
|
||||
"arguments": (
|
||||
json.dumps(tc["arguments"])
|
||||
if isinstance(tc.get("arguments"), dict)
|
||||
else tc.get("arguments", "{}")
|
||||
),
|
||||
},
|
||||
}
|
||||
for tc in pending_tool_calls
|
||||
]
|
||||
finish_reason = "tool_calls"
|
||||
else:
|
||||
message["content"] = answer
|
||||
if thought:
|
||||
message["reasoning_content"] = thought
|
||||
finish_reason = "stop"
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
# DocsGPT extensions
|
||||
docsgpt: Dict[str, Any] = {}
|
||||
if conversation_id:
|
||||
docsgpt["conversation_id"] = conversation_id
|
||||
if sources:
|
||||
docsgpt["sources"] = sources
|
||||
if tool_calls:
|
||||
docsgpt["tool_calls"] = tool_calls
|
||||
if docsgpt:
|
||||
result["docsgpt"] = docsgpt
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming event translation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
completion_id: str,
|
||||
model_name: str,
|
||||
delta: Dict[str, Any],
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build a single SSE chunk in the standard streaming format."""
|
||||
chunk = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
return f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
|
||||
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
|
||||
"""Build a DocsGPT extension SSE chunk."""
|
||||
return f"data: {json.dumps({'docsgpt': data})}\n\n"
|
||||
|
||||
|
||||
def translate_stream_event(
|
||||
event_data: Dict[str, Any],
|
||||
completion_id: str,
|
||||
model_name: str,
|
||||
) -> List[str]:
|
||||
"""Translate a DocsGPT SSE event dict to standard streaming chunks.
|
||||
|
||||
May return 0, 1, or 2 chunks per input event. For example, a completed
|
||||
tool call produces both a docsgpt extension chunk and nothing on the
|
||||
standard side (since server-side tool calls aren't surfaced in standard
|
||||
format).
|
||||
|
||||
Args:
|
||||
event_data: Parsed DocsGPT event dict.
|
||||
completion_id: The completion ID for this response.
|
||||
model_name: Model/agent identifier.
|
||||
|
||||
Returns:
|
||||
List of SSE-formatted strings to send to the client.
|
||||
"""
|
||||
event_type = event_data.get("type")
|
||||
chunks: List[str] = []
|
||||
|
||||
if event_type == "answer":
|
||||
chunks.append(
|
||||
_make_chunk(completion_id, model_name, {"content": event_data.get("answer", "")})
|
||||
)
|
||||
|
||||
elif event_type == "thought":
|
||||
chunks.append(
|
||||
_make_chunk(
|
||||
completion_id, model_name,
|
||||
{"reasoning_content": event_data.get("thought", "")},
|
||||
)
|
||||
)
|
||||
|
||||
elif event_type == "source":
|
||||
chunks.append(
|
||||
_make_docsgpt_chunk({
|
||||
"type": "source",
|
||||
"sources": event_data.get("source", []),
|
||||
})
|
||||
)
|
||||
|
||||
elif event_type == "tool_call":
|
||||
tc_data = event_data.get("data", {})
|
||||
status = tc_data.get("status")
|
||||
|
||||
if status == "requires_client_execution":
|
||||
# Standard: stream as tool_calls delta
|
||||
args = tc_data.get("arguments", {})
|
||||
args_str = json.dumps(args) if isinstance(args, dict) else str(args)
|
||||
chunks.append(
|
||||
_make_chunk(completion_id, model_name, {
|
||||
"tool_calls": [{
|
||||
"index": 0,
|
||||
"id": tc_data.get("call_id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": _get_client_tool_name(tc_data),
|
||||
"arguments": args_str,
|
||||
},
|
||||
}],
|
||||
})
|
||||
)
|
||||
elif status == "awaiting_approval":
|
||||
# Extension: approval needed
|
||||
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
|
||||
elif status in ("completed", "pending", "error", "denied", "skipped"):
|
||||
# Extension: tool call progress
|
||||
chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
|
||||
|
||||
elif event_type == "tool_calls_pending":
|
||||
# Standard: finish_reason = tool_calls
|
||||
chunks.append(
|
||||
_make_chunk(completion_id, model_name, {}, finish_reason="tool_calls")
|
||||
)
|
||||
# Also emit as docsgpt extension
|
||||
chunks.append(
|
||||
_make_docsgpt_chunk({
|
||||
"type": "tool_calls_pending",
|
||||
"pending_tool_calls": event_data.get("data", {}).get("pending_tool_calls", []),
|
||||
})
|
||||
)
|
||||
|
||||
elif event_type == "end":
|
||||
chunks.append(
|
||||
_make_chunk(completion_id, model_name, {}, finish_reason="stop")
|
||||
)
|
||||
chunks.append("data: [DONE]\n\n")
|
||||
|
||||
elif event_type == "id":
|
||||
chunks.append(
|
||||
_make_docsgpt_chunk({
|
||||
"type": "id",
|
||||
"conversation_id": event_data.get("id", ""),
|
||||
})
|
||||
)
|
||||
|
||||
elif event_type == "error":
|
||||
# Emit as standard error (non-standard but widely supported)
|
||||
error_data = {
|
||||
"error": {
|
||||
"message": event_data.get("error", "An error occurred"),
|
||||
"type": "server_error",
|
||||
}
|
||||
}
|
||||
chunks.append(f"data: {json.dumps(error_data)}\n\n")
|
||||
|
||||
elif event_type == "structured_answer":
|
||||
chunks.append(
|
||||
_make_chunk(
|
||||
completion_id, model_name,
|
||||
{"content": event_data.get("answer", "")},
|
||||
)
|
||||
)
|
||||
|
||||
# Skip: tool_calls (redundant), research_plan, research_progress
|
||||
|
||||
return chunks
|
||||
@@ -1,10 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import uuid
|
||||
|
||||
import dotenv
|
||||
from flask import Flask, Response, jsonify, redirect, request
|
||||
from flask import Flask, jsonify, redirect, request
|
||||
from jose import jwt
|
||||
|
||||
from application.auth import handle_auth
|
||||
@@ -18,10 +17,8 @@ from application.api.answer import answer # noqa: E402
|
||||
from application.api.internal.routes import internal # noqa: E402
|
||||
from application.api.user.routes import user # noqa: E402
|
||||
from application.api.connector.routes import connector # noqa: E402
|
||||
from application.api.v1 import v1_bp # noqa: E402
|
||||
from application.celery_init import celery # noqa: E402
|
||||
from application.core.settings import settings # noqa: E402
|
||||
from application.storage.db.bootstrap import ensure_database_ready # noqa: E402
|
||||
from application.stt.upload_limits import ( # noqa: E402
|
||||
build_stt_file_size_limit_message,
|
||||
should_reject_stt_request,
|
||||
@@ -34,23 +31,11 @@ if platform.system() == "Windows":
|
||||
pathlib.PosixPath = pathlib.WindowsPath
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Self-bootstrap the user-data Postgres DB. Runs before any blueprint or
|
||||
# repository touches the engine, so the first request can't race the
|
||||
# schema being created. Gated by AUTO_CREATE_DB / AUTO_MIGRATE settings
|
||||
# (default ON for dev; disable in prod if schema is managed out-of-band).
|
||||
ensure_database_ready(
|
||||
settings.POSTGRES_URI,
|
||||
create_db=settings.AUTO_CREATE_DB,
|
||||
migrate=settings.AUTO_MIGRATE,
|
||||
logger=logging.getLogger("application.app"),
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
app.register_blueprint(user)
|
||||
app.register_blueprint(answer)
|
||||
app.register_blueprint(internal)
|
||||
app.register_blueprint(connector)
|
||||
app.register_blueprint(v1_bp)
|
||||
app.config.update(
|
||||
UPLOAD_FOLDER="inputs",
|
||||
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
||||
@@ -133,12 +118,6 @@ def enforce_stt_request_size_limits():
|
||||
def authenticate_request():
|
||||
if request.method == "OPTIONS":
|
||||
return "", 200
|
||||
# OpenAI-compatible routes authenticate via opaque agent API keys in the
|
||||
# Authorization header, which the JWT decoder below would reject. Defer
|
||||
# auth to the route handlers (see application/api/v1/routes.py).
|
||||
if request.path.startswith("/v1/"):
|
||||
request.decoded_token = None
|
||||
return None
|
||||
decoded_token = handle_auth(request)
|
||||
if not decoded_token:
|
||||
request.decoded_token = None
|
||||
@@ -149,11 +128,12 @@ def authenticate_request():
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response: Response) -> Response:
|
||||
"""Add CORS headers for the pure Flask development entrypoint."""
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
|
||||
def after_request(response):
|
||||
response.headers.add("Access-Control-Allow-Origin", "*")
|
||||
response.headers.add("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
response.headers.add(
|
||||
"Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
"""ASGI entrypoint: Flask (WSGI) + FastMCP on the same process."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from a2wsgi import WSGIMiddleware
|
||||
from starlette.applications import Starlette
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.routing import Mount
|
||||
|
||||
from application.app import app as flask_app
|
||||
from application.mcp_server import mcp
|
||||
|
||||
_WSGI_THREADPOOL = 32
|
||||
|
||||
mcp_app = mcp.http_app(path="/")
|
||||
|
||||
asgi_app = Starlette(
|
||||
routes=[
|
||||
Mount("/mcp", app=mcp_app),
|
||||
Mount("/", app=WSGIMiddleware(flask_app, workers=_WSGI_THREADPOOL)),
|
||||
],
|
||||
middleware=[
|
||||
Middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "Mcp-Session-Id"],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
),
|
||||
],
|
||||
lifespan=mcp_app.lifespan,
|
||||
)
|
||||
@@ -1,8 +1,6 @@
|
||||
import threading
|
||||
|
||||
from celery import Celery
|
||||
from application.core.settings import settings
|
||||
from celery.signals import setup_logging, worker_process_init, worker_ready
|
||||
from celery.signals import setup_logging
|
||||
|
||||
|
||||
def make_celery(app_name=__name__):
|
||||
@@ -22,44 +20,5 @@ def config_loggers(*args, **kwargs):
|
||||
setup_logging()
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def _dispose_db_engine_on_fork(*args, **kwargs):
|
||||
"""Dispose the SQLAlchemy engine pool in each forked Celery worker.
|
||||
|
||||
SQLAlchemy connection pools are not fork-safe: file descriptors shared
|
||||
between the parent and a forked worker will corrupt the pool. Disposing
|
||||
on ``worker_process_init`` gives every worker its own fresh pool on
|
||||
first use.
|
||||
|
||||
Imported lazily so Celery workers that don't touch Postgres (or where
|
||||
``POSTGRES_URI`` is unset) don't fail at startup.
|
||||
"""
|
||||
try:
|
||||
from application.storage.db.engine import dispose_engine
|
||||
except Exception:
|
||||
return
|
||||
dispose_engine()
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def _run_version_check(*args, **kwargs):
|
||||
"""Kick off the anonymous version check on worker startup.
|
||||
|
||||
Runs in a daemon thread so a slow endpoint or bad DNS never holds
|
||||
up the worker becoming ready for tasks. The check itself is
|
||||
fail-silent (see ``application.updates.version_check.run_check``);
|
||||
this handler's only job is to launch it and get out of the way.
|
||||
|
||||
Import is lazy so the symbol resolution never fires at module
|
||||
import time — consistent with the ``_dispose_db_engine_on_fork``
|
||||
pattern above.
|
||||
"""
|
||||
try:
|
||||
from application.updates.version_check import run_check
|
||||
except Exception:
|
||||
return
|
||||
threading.Thread(target=run_check, name="version-check", daemon=True).start()
|
||||
|
||||
|
||||
celery = make_celery()
|
||||
celery.config_from_object("application.celeryconfig")
|
||||
|
||||
@@ -9,8 +9,3 @@ accept_content = ['json']
|
||||
|
||||
# Autodiscover tasks
|
||||
imports = ('application.api.user.tasks',)
|
||||
|
||||
beat_scheduler = "redbeat.RedBeatScheduler"
|
||||
redbeat_redis_url = broker_url
|
||||
redbeat_key_prefix = "redbeat:docsgpt:"
|
||||
redbeat_lock_timeout = 90
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Normalize user-supplied Postgres URIs for different drivers.
|
||||
|
||||
DocsGPT has two Postgres connection strings pointing at potentially
|
||||
different databases:
|
||||
|
||||
* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
|
||||
``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
|
||||
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
|
||||
(via libpq) in ``application/vectorstore/pgvector.py``. libpq only
|
||||
understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
|
||||
dialect prefix is an invalid URI from its point of view.
|
||||
|
||||
The two fields therefore need opposite normalization so operators don't
|
||||
have to know which driver a given field feeds. Each normalizer also
|
||||
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
|
||||
psycopg2 is no longer in the project.
|
||||
|
||||
This module is deliberately separate from ``application/core/settings.py``
|
||||
so the Settings class stays focused on field declarations, and the
|
||||
URI-rewriting logic can be unit-tested without triggering ``.env``
|
||||
file loading from importing Settings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _rewrite_uri_prefixes(v, rewrites):
|
||||
"""Shared URI prefix rewriter used by both normalizers below.
|
||||
|
||||
Strips whitespace, returns ``None`` for empty / ``"none"`` values,
|
||||
applies the first matching rewrite, and passes unrecognised input
|
||||
through so downstream consumers (SQLAlchemy, libpq) can produce
|
||||
their own error messages rather than us silently eating a
|
||||
misconfiguration.
|
||||
"""
|
||||
if v is None:
|
||||
return None
|
||||
if not isinstance(v, str):
|
||||
return v
|
||||
v = v.strip()
|
||||
if not v or v.lower() == "none":
|
||||
return None
|
||||
for prefix, target in rewrites:
|
||||
if v.startswith(prefix):
|
||||
return target + v[len(prefix):]
|
||||
return v
|
||||
|
||||
|
||||
# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
|
||||
# dialect prefix to select the psycopg v3 driver. Normalize the
|
||||
# operator-friendly forms TOWARD that dialect.
|
||||
_POSTGRES_URI_REWRITES = (
|
||||
("postgresql+psycopg2://", "postgresql+psycopg://"),
|
||||
("postgresql://", "postgresql+psycopg://"),
|
||||
("postgres://", "postgresql+psycopg://"),
|
||||
)
|
||||
|
||||
|
||||
# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
|
||||
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
|
||||
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
|
||||
# dialect prefix is an invalid URI from libpq's point of view. Strip it
|
||||
# if the operator accidentally copied their POSTGRES_URI value here.
|
||||
_PGVECTOR_CONNECTION_STRING_REWRITES = (
|
||||
("postgresql+psycopg2://", "postgresql://"),
|
||||
("postgresql+psycopg://", "postgresql://"),
|
||||
)
|
||||
|
||||
|
||||
def normalize_postgres_uri(v):
|
||||
"""Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.
|
||||
|
||||
Accepts the forms operators naturally write (``postgres://``,
|
||||
``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
|
||||
Unknown schemes pass through unchanged so SQLAlchemy can produce its
|
||||
own dialect-not-found error.
|
||||
"""
|
||||
return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)
|
||||
|
||||
|
||||
def normalize_pgvector_connection_string(v):
|
||||
"""Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.
|
||||
|
||||
Strips the SQLAlchemy dialect prefix if the operator accidentally
|
||||
copied their POSTGRES_URI value here — libpq can't parse it.
|
||||
User-friendly forms (``postgres://``, ``postgresql://``) pass
|
||||
through unchanged since libpq accepts them natively.
|
||||
"""
|
||||
return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)
|
||||
@@ -1,45 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.config import dictConfig
|
||||
|
||||
|
||||
def _otlp_logs_enabled() -> bool:
|
||||
"""Return True when the user has opted in to OTLP log export.
|
||||
|
||||
Gated by the standard OTEL env vars so no project-specific knob is needed:
|
||||
set ``OTEL_LOGS_EXPORTER=otlp`` (and leave ``OTEL_SDK_DISABLED`` unset or
|
||||
false) to flip it on. When false, ``setup_logging`` keeps its original
|
||||
console-only behavior.
|
||||
"""
|
||||
exporter = os.getenv("OTEL_LOGS_EXPORTER", "").strip().lower()
|
||||
disabled = os.getenv("OTEL_SDK_DISABLED", "false").strip().lower() == "true"
|
||||
return exporter == "otlp" and not disabled
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Configure the root logger with a stdout console handler.
|
||||
|
||||
When OTLP log export is enabled, ``opentelemetry-instrument`` attaches a
|
||||
``LoggingHandler`` to the root logger before this function runs. The
|
||||
``dictConfig`` call below replaces ``root.handlers`` with the console
|
||||
handler, which would silently drop the OTEL handler. To make OTLP log
|
||||
export work without forcing every contributor to opt in, snapshot the
|
||||
OTEL handlers up front and re-attach them after ``dictConfig``.
|
||||
"""
|
||||
preserved_handlers: list[logging.Handler] = []
|
||||
if _otlp_logs_enabled():
|
||||
preserved_handlers = [
|
||||
h
|
||||
for h in logging.getLogger().handlers
|
||||
if h.__class__.__module__.startswith("opentelemetry")
|
||||
]
|
||||
|
||||
def setup_logging():
|
||||
dictConfig({
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "[%(asctime)s] %(levelname)s in %(module)s: %(message)s",
|
||||
'version': 1,
|
||||
'formatters': {
|
||||
'default': {
|
||||
'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
@@ -49,14 +15,8 @@ def setup_logging() -> None:
|
||||
"formatter": "default",
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"level": "INFO",
|
||||
"handlers": ["console"],
|
||||
'root': {
|
||||
'level': 'INFO',
|
||||
'handlers': ['console'],
|
||||
},
|
||||
})
|
||||
|
||||
if preserved_handlers:
|
||||
root = logging.getLogger()
|
||||
for handler in preserved_handlers:
|
||||
if handler not in root.handlers:
|
||||
root.addHandler(handler)
|
||||
})
|
||||
224
application/core/model_configs.py
Normal file
224
application/core/model_configs.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Model configurations for all supported LLM providers.
|
||||
"""
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
# Base image attachment types supported by most vision-capable LLMs
|
||||
IMAGE_ATTACHMENTS = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/jpg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
]
|
||||
|
||||
# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
|
||||
# When excluded, PDFs are synthetically processed by converting pages to images.
|
||||
OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
|
||||
|
||||
ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
|
||||
|
||||
|
||||
OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="gpt-5.1",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5.1",
|
||||
description="Flagship model with enhanced reasoning, coding, and agentic capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gpt-5-mini",
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name="GPT-5 Mini",
|
||||
description="Faster, cost-effective variant of GPT-5.1",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
ANTHROPIC_MODELS = [
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet-20241022",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet (Latest)",
|
||||
description="Latest Claude 3.5 Sonnet with enhanced capabilities",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-5-sonnet",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3.5 Sonnet",
|
||||
description="Balanced performance and capability",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-opus",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Opus",
|
||||
description="Most capable Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="claude-3-haiku",
|
||||
provider=ModelProvider.ANTHROPIC,
|
||||
display_name="Claude 3 Haiku",
|
||||
description="Fastest Claude model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=ANTHROPIC_ATTACHMENTS,
|
||||
context_window=200000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GOOGLE_MODELS = [
|
||||
AvailableModel(
|
||||
id="gemini-flash-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash (Latest)",
|
||||
description="Latest experimental Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-flash-lite-latest",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini Flash Lite (Latest)",
|
||||
description="Fast with huge context window",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=int(1e6),
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="gemini-3-pro-preview",
|
||||
provider=ModelProvider.GOOGLE,
|
||||
display_name="Gemini 3 Pro",
|
||||
description="Most capable Gemini model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=GOOGLE_ATTACHMENTS,
|
||||
context_window=2000000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
GROQ_MODELS = [
|
||||
AvailableModel(
|
||||
id="llama-3.3-70b-versatile",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="Llama 3.3 70B",
|
||||
description="Latest Llama model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="openai/gpt-oss-120b",
|
||||
provider=ModelProvider.GROQ,
|
||||
display_name="GPT-OSS 120B",
|
||||
description="Open-source GPT model optimized for speed",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
OPENROUTER_MODELS = [
|
||||
AvailableModel(
|
||||
id="qwen/qwen3-coder:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Qwen 3 Coder",
|
||||
description="Latest Qwen model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
AvailableModel(
|
||||
id="google/gemma-3-27b-it:free",
|
||||
provider=ModelProvider.OPENROUTER,
|
||||
display_name="Gemma 3 27B",
|
||||
description="Latest Gemma model with high-speed inference",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
context_window=128000,
|
||||
supported_attachment_types=OPENROUTER_ATTACHMENTS
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
AvailableModel(
|
||||
id="azure-gpt-4",
|
||||
provider=ModelProvider.AZURE_OPENAI,
|
||||
display_name="Azure OpenAI GPT-4",
|
||||
description="Azure-hosted GPT model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supports_structured_output=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
context_window=8192,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
|
||||
"""Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
|
||||
return AvailableModel(
|
||||
id=model_name,
|
||||
provider=ModelProvider.OPENAI,
|
||||
display_name=model_name,
|
||||
description=f"Custom OpenAI-compatible model at {base_url}",
|
||||
base_url=base_url,
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=True,
|
||||
supported_attachment_types=OPENAI_ATTACHMENTS,
|
||||
),
|
||||
)
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Layered model registry.
|
||||
|
||||
Loads model catalogs from YAML files (built-in + operator-supplied),
|
||||
groups them by provider name, then for each registered provider plugin
|
||||
calls ``get_models`` to produce the final per-provider model list.
|
||||
|
||||
The ``user_id`` parameter on lookup methods is reserved for the future
|
||||
end-user BYOM (per-user model records in Postgres). It is currently
|
||||
ignored — defaulted to ``None`` everywhere — so call sites can be
|
||||
threaded through without a wide refactor when BYOM lands.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from application.core.model_settings import AvailableModel
|
||||
from application.core.model_yaml import (
|
||||
BUILTIN_MODELS_DIR,
|
||||
ProviderCatalog,
|
||||
load_model_yamls,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""Singleton registry of available models."""
|
||||
|
||||
_instance: Optional["ModelRegistry"] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
"""Clear the singleton. Intended for test fixtures."""
|
||||
cls._instance = None
|
||||
cls._initialized = False
|
||||
|
||||
def _load_models(self) -> None:
|
||||
from pathlib import Path
|
||||
|
||||
from application.core.settings import settings
|
||||
from application.llm.providers import ALL_PROVIDERS
|
||||
|
||||
directories = [BUILTIN_MODELS_DIR]
|
||||
operator_dir = getattr(settings, "MODELS_CONFIG_DIR", None)
|
||||
if operator_dir:
|
||||
op_path = Path(operator_dir)
|
||||
if not op_path.exists():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s does not exist; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
elif not op_path.is_dir():
|
||||
logger.warning(
|
||||
"MODELS_CONFIG_DIR=%s is not a directory; no operator "
|
||||
"model YAMLs will be loaded.",
|
||||
operator_dir,
|
||||
)
|
||||
else:
|
||||
directories.append(op_path)
|
||||
|
||||
catalogs = load_model_yamls(directories)
|
||||
|
||||
# Validate every catalog targets a known plugin before doing any
|
||||
# registry work, so an unknown provider name in YAML aborts boot
|
||||
# with a clear error.
|
||||
plugin_names = {p.name for p in ALL_PROVIDERS}
|
||||
for c in catalogs:
|
||||
if c.provider not in plugin_names:
|
||||
raise ValueError(
|
||||
f"{c.source_path}: YAML declares unknown provider "
|
||||
f"{c.provider!r}; no Provider plugin is registered "
|
||||
f"under that name. Known: {sorted(plugin_names)}"
|
||||
)
|
||||
|
||||
catalogs_by_provider: Dict[str, List[ProviderCatalog]] = defaultdict(list)
|
||||
for c in catalogs:
|
||||
catalogs_by_provider[c.provider].append(c)
|
||||
|
||||
self.models.clear()
|
||||
for provider in ALL_PROVIDERS:
|
||||
if not provider.is_enabled(settings):
|
||||
continue
|
||||
for model in provider.get_models(
|
||||
settings, catalogs_by_provider.get(provider.name, [])
|
||||
):
|
||||
self.models[model.id] = model
|
||||
|
||||
self.default_model_id = self._resolve_default(settings)
|
||||
|
||||
logger.info(
|
||||
"ModelRegistry loaded %d models, default: %s",
|
||||
len(self.models),
|
||||
self.default_model_id,
|
||||
)
|
||||
|
||||
def _resolve_default(self, settings) -> Optional[str]:
|
||||
if settings.LLM_NAME:
|
||||
for name in self._parse_model_names(settings.LLM_NAME):
|
||||
if name in self.models:
|
||||
return name
|
||||
if settings.LLM_NAME in self.models:
|
||||
return settings.LLM_NAME
|
||||
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
return model_id
|
||||
|
||||
if self.models:
|
||||
return next(iter(self.models.keys()))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_model_names(llm_name: str) -> List[str]:
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lookup API. ``user_id`` is reserved for the future BYOM and
|
||||
# is ignored today — but threading it through every call site now
|
||||
# means BYOM doesn't require a wide refactor when we build it.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_model(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(
|
||||
self, user_id: Optional[str] = None
|
||||
) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(
|
||||
self, model_id: str, user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
return model_id in self.models
|
||||
@@ -5,16 +5,9 @@ from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Re-exported here so existing call sites (and tests) that do
|
||||
# ``from application.core.model_settings import ModelRegistry`` keep
|
||||
# working. The implementation lives in ``application/core/model_registry.py``.
|
||||
# Imported lazily inside ``__getattr__`` to avoid an import cycle with
|
||||
# ``model_yaml`` → ``model_settings`` (this file).
|
||||
|
||||
|
||||
class ModelProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
OPENROUTER = "openrouter"
|
||||
AZURE_OPENAI = "azure_openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
@@ -48,20 +41,11 @@ class AvailableModel:
|
||||
capabilities: ModelCapabilities = field(default_factory=ModelCapabilities)
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
# User-facing label distinct from the dispatch ``provider``. Used by
|
||||
# openai_compatible YAMLs so a Mistral model shows "mistral" in the
|
||||
# API response while still routing through the OpenAI wire format.
|
||||
display_provider: Optional[str] = None
|
||||
# Per-record API key. Operator YAMLs leave this None; populated for
|
||||
# openai_compatible models (resolved from the YAML's ``api_key_env``)
|
||||
# and reserved for the future end-user BYOM phase. Never serialized
|
||||
# into to_dict().
|
||||
api_key: Optional[str] = field(default=None, repr=False, compare=False)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
result = {
|
||||
"id": self.id,
|
||||
"provider": self.display_provider or self.provider.value,
|
||||
"provider": self.provider.value,
|
||||
"display_name": self.display_name,
|
||||
"description": self.description,
|
||||
"supported_attachment_types": self.capabilities.supported_attachment_types,
|
||||
@@ -76,14 +60,236 @@ class AvailableModel:
|
||||
return result
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
"""Lazy re-export of ``ModelRegistry`` from ``model_registry.py``.
|
||||
class ModelRegistry:
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
Done lazily to avoid an import cycle: ``model_registry`` imports
|
||||
``model_yaml`` which imports the dataclasses from this file.
|
||||
"""
|
||||
if name == "ModelRegistry":
|
||||
from application.core.model_registry import ModelRegistry as _MR
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
return _MR
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
def __init__(self):
|
||||
if not ModelRegistry._initialized:
|
||||
self.models: Dict[str, AvailableModel] = {}
|
||||
self.default_model_id: Optional[str] = None
|
||||
self._load_models()
|
||||
ModelRegistry._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ModelRegistry":
|
||||
return cls()
|
||||
|
||||
def _load_models(self):
|
||||
from application.core.settings import settings
|
||||
|
||||
self.models.clear()
|
||||
|
||||
# Skip DocsGPT model if using custom OpenAI-compatible endpoint
|
||||
if not settings.OPENAI_BASE_URL:
|
||||
self._add_docsgpt_models(settings)
|
||||
if (
|
||||
settings.OPENAI_API_KEY
|
||||
or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
|
||||
or settings.OPENAI_BASE_URL
|
||||
):
|
||||
self._add_openai_models(settings)
|
||||
if settings.OPENAI_API_BASE or (
|
||||
settings.LLM_PROVIDER == "azure_openai" and settings.API_KEY
|
||||
):
|
||||
self._add_azure_openai_models(settings)
|
||||
if settings.ANTHROPIC_API_KEY or (
|
||||
settings.LLM_PROVIDER == "anthropic" and settings.API_KEY
|
||||
):
|
||||
self._add_anthropic_models(settings)
|
||||
if settings.GOOGLE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "google" and settings.API_KEY
|
||||
):
|
||||
self._add_google_models(settings)
|
||||
if settings.GROQ_API_KEY or (
|
||||
settings.LLM_PROVIDER == "groq" and settings.API_KEY
|
||||
):
|
||||
self._add_groq_models(settings)
|
||||
if settings.OPEN_ROUTER_API_KEY or (
|
||||
settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
|
||||
):
|
||||
self._add_openrouter_models(settings)
|
||||
if settings.HUGGINGFACE_API_KEY or (
|
||||
settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
|
||||
):
|
||||
self._add_huggingface_models(settings)
|
||||
# Default model selection
|
||||
if settings.LLM_NAME:
|
||||
# Parse LLM_NAME (may be comma-separated)
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
# First model in the list becomes default
|
||||
for model_name in model_names:
|
||||
if model_name in self.models:
|
||||
self.default_model_id = model_name
|
||||
break
|
||||
# Backward compat: try exact match if no parsed model found
|
||||
if not self.default_model_id and settings.LLM_NAME in self.models:
|
||||
self.default_model_id = settings.LLM_NAME
|
||||
|
||||
if not self.default_model_id:
|
||||
if settings.LLM_PROVIDER and settings.API_KEY:
|
||||
for model_id, model in self.models.items():
|
||||
if model.provider.value == settings.LLM_PROVIDER:
|
||||
self.default_model_id = model_id
|
||||
break
|
||||
|
||||
if not self.default_model_id and self.models:
|
||||
self.default_model_id = next(iter(self.models.keys()))
|
||||
logger.info(
|
||||
f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
|
||||
)
|
||||
|
||||
def _add_openai_models(self, settings):
|
||||
from application.core.model_configs import (
|
||||
OPENAI_MODELS,
|
||||
create_custom_openai_model,
|
||||
)
|
||||
|
||||
# Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
|
||||
using_local_endpoint = bool(
|
||||
settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
|
||||
)
|
||||
|
||||
if using_local_endpoint:
|
||||
# When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
|
||||
# Do NOT add standard OpenAI models (gpt-5.1, etc.)
|
||||
if settings.LLM_NAME:
|
||||
model_names = self._parse_model_names(settings.LLM_NAME)
|
||||
for model_name in model_names:
|
||||
custom_model = create_custom_openai_model(
|
||||
model_name, settings.OPENAI_BASE_URL
|
||||
)
|
||||
self.models[model_name] = custom_model
|
||||
logger.info(
|
||||
f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
|
||||
)
|
||||
else:
|
||||
# Standard OpenAI API usage - add standard models if API key is valid
|
||||
if settings.OPENAI_API_KEY:
|
||||
for model in OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_azure_openai_models(self, settings):
|
||||
from application.core.model_configs import AZURE_OPENAI_MODELS
|
||||
|
||||
if settings.LLM_PROVIDER == "azure_openai" and settings.LLM_NAME:
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in AZURE_OPENAI_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_anthropic_models(self, settings):
|
||||
from application.core.model_configs import ANTHROPIC_MODELS
|
||||
|
||||
if settings.ANTHROPIC_API_KEY:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "anthropic" and settings.LLM_NAME:
|
||||
for model in ANTHROPIC_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in ANTHROPIC_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_google_models(self, settings):
|
||||
from application.core.model_configs import GOOGLE_MODELS
|
||||
|
||||
if settings.GOOGLE_API_KEY:
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "google" and settings.LLM_NAME:
|
||||
for model in GOOGLE_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GOOGLE_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_groq_models(self, settings):
|
||||
from application.core.model_configs import GROQ_MODELS
|
||||
|
||||
if settings.GROQ_API_KEY:
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "groq" and settings.LLM_NAME:
|
||||
for model in GROQ_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in GROQ_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_openrouter_models(self, settings):
|
||||
from application.core.model_configs import OPENROUTER_MODELS
|
||||
|
||||
if settings.OPEN_ROUTER_API_KEY:
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
|
||||
for model in OPENROUTER_MODELS:
|
||||
if model.id == settings.LLM_NAME:
|
||||
self.models[model.id] = model
|
||||
return
|
||||
for model in OPENROUTER_MODELS:
|
||||
self.models[model.id] = model
|
||||
|
||||
def _add_docsgpt_models(self, settings):
|
||||
model_id = "docsgpt-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.DOCSGPT,
|
||||
display_name="DocsGPT Model",
|
||||
description="Local model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _add_huggingface_models(self, settings):
|
||||
model_id = "huggingface-local"
|
||||
model = AvailableModel(
|
||||
id=model_id,
|
||||
provider=ModelProvider.HUGGINGFACE,
|
||||
display_name="Hugging Face Model",
|
||||
description="Local Hugging Face model",
|
||||
capabilities=ModelCapabilities(
|
||||
supports_tools=False,
|
||||
supported_attachment_types=[],
|
||||
),
|
||||
)
|
||||
self.models[model_id] = model
|
||||
|
||||
def _parse_model_names(self, llm_name: str) -> List[str]:
|
||||
"""
|
||||
Parse LLM_NAME which may contain comma-separated model names.
|
||||
E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
|
||||
"""
|
||||
if not llm_name:
|
||||
return []
|
||||
return [name.strip() for name in llm_name.split(",") if name.strip()]
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[AvailableModel]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
def get_all_models(self) -> List[AvailableModel]:
|
||||
return list(self.models.values())
|
||||
|
||||
def get_enabled_models(self) -> List[AvailableModel]:
|
||||
return [m for m in self.models.values() if m.enabled]
|
||||
|
||||
def model_exists(self, model_id: str) -> bool:
|
||||
return model_id in self.models
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from application.core.model_registry import ModelRegistry
|
||||
from application.core.model_settings import ModelRegistry
|
||||
|
||||
|
||||
def get_api_key_for_provider(provider: str) -> Optional[str]:
|
||||
"""Get the appropriate API key for a provider.
|
||||
|
||||
Delegates to the provider plugin's ``get_api_key``. Falls back to the
|
||||
generic ``settings.API_KEY`` for unknown providers.
|
||||
"""
|
||||
"""Get the appropriate API key for a provider"""
|
||||
from application.core.settings import settings
|
||||
from application.llm.providers import PROVIDERS_BY_NAME
|
||||
|
||||
plugin = PROVIDERS_BY_NAME.get(provider)
|
||||
if plugin is not None:
|
||||
key = plugin.get_api_key(settings)
|
||||
if key:
|
||||
return key
|
||||
provider_key_map = {
|
||||
"openai": settings.OPENAI_API_KEY,
|
||||
"openrouter": settings.OPEN_ROUTER_API_KEY,
|
||||
"anthropic": settings.ANTHROPIC_API_KEY,
|
||||
"google": settings.GOOGLE_API_KEY,
|
||||
"groq": settings.GROQ_API_KEY,
|
||||
"huggingface": settings.HUGGINGFACE_API_KEY,
|
||||
"azure_openai": settings.API_KEY,
|
||||
"docsgpt": None,
|
||||
"llama.cpp": None,
|
||||
}
|
||||
|
||||
provider_key = provider_key_map.get(provider)
|
||||
if provider_key:
|
||||
return provider_key
|
||||
return settings.API_KEY
|
||||
|
||||
|
||||
@@ -85,21 +90,3 @@ def get_base_url_for_model(model_id: str) -> Optional[str]:
|
||||
if model:
|
||||
return model.base_url
|
||||
return None
|
||||
|
||||
|
||||
def get_api_key_for_model(model_id: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve the API key to use when invoking ``model_id``.
|
||||
|
||||
Priority:
|
||||
1. The model record's own ``api_key`` (reserved for future end-user
|
||||
BYOM where credentials travel with the record).
|
||||
2. The provider plugin's settings-based key.
|
||||
"""
|
||||
registry = ModelRegistry.get_instance()
|
||||
model = registry.get_model(model_id)
|
||||
if model is not None and model.api_key:
|
||||
return model.api_key
|
||||
if model is not None:
|
||||
return get_api_key_for_provider(model.provider.value)
|
||||
return None
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
"""YAML loader for model catalog files under ``application/core/models/``.
|
||||
|
||||
Each ``*.yaml`` file declares one provider's static model catalog. Files
|
||||
are validated with Pydantic at load time; any parse, schema, or alias
|
||||
error aborts startup with the offending file path in the message.
|
||||
|
||||
For most providers, one YAML maps to one catalog. The
|
||||
``openai_compatible`` provider is special: each YAML file represents a
|
||||
distinct logical endpoint (Mistral, Together, Ollama, ...) with its own
|
||||
``api_key_env`` and ``base_url``. The loader returns a flat list so the
|
||||
registry can distinguish multiple files with the same ``provider:`` value.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from application.core.model_settings import (
|
||||
AvailableModel,
|
||||
ModelCapabilities,
|
||||
ModelProvider,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BUILTIN_MODELS_DIR = Path(__file__).parent / "models"
|
||||
DEFAULTS_FILENAME = "_defaults.yaml"
|
||||
|
||||
|
||||
class _DefaultsFile(BaseModel):
|
||||
"""Schema for ``_defaults.yaml``. Currently just attachment aliases."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
attachment_aliases: Dict[str, List[str]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class _CapabilityFields(BaseModel):
|
||||
"""Capability fields shared between provider ``defaults:`` and per-model overrides.
|
||||
|
||||
All fields are optional so a per-model override can selectively replace
|
||||
a single field from the provider-level defaults.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
supports_tools: Optional[bool] = None
|
||||
supports_structured_output: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
attachments: Optional[List[str]] = None
|
||||
context_window: Optional[int] = None
|
||||
input_cost_per_token: Optional[float] = None
|
||||
output_cost_per_token: Optional[float] = None
|
||||
|
||||
|
||||
class _ModelEntry(_CapabilityFields):
|
||||
"""Schema for one model row inside a YAML's ``models:`` list."""
|
||||
|
||||
id: str
|
||||
display_name: Optional[str] = None
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
base_url: Optional[str] = None
|
||||
aliases: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _id_nonempty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("model id must be a non-empty string")
|
||||
return v
|
||||
|
||||
|
||||
class _ProviderFile(BaseModel):
|
||||
"""Schema for one ``<provider>.yaml`` catalog file."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
provider: str
|
||||
defaults: _CapabilityFields = Field(default_factory=_CapabilityFields)
|
||||
models: List[_ModelEntry] = Field(default_factory=list)
|
||||
# openai_compatible metadata. Optional for other providers.
|
||||
display_provider: Optional[str] = None
|
||||
api_key_env: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
class ProviderCatalog(BaseModel):
|
||||
"""One YAML file's parsed contents, ready for the registry.
|
||||
|
||||
For most providers, multiple catalogs with the same ``provider`` get
|
||||
merged later by the registry. The ``openai_compatible`` provider is
|
||||
the exception: each catalog is treated as a distinct endpoint, with
|
||||
its own ``api_key_env`` and ``base_url``.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
models: List[AvailableModel]
|
||||
source_path: Optional[Path] = None
|
||||
display_provider: Optional[str] = None
|
||||
api_key_env: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ModelYAMLError(ValueError):
|
||||
"""Raised when a model YAML fails parsing, schema, or alias validation."""
|
||||
|
||||
|
||||
def _expand_attachments(
|
||||
attachments: Sequence[str], aliases: Dict[str, List[str]], source: str
|
||||
) -> List[str]:
|
||||
"""Resolve attachment shorthands (``image``, ``pdf``) to MIME types.
|
||||
|
||||
Raw MIME-typed entries (containing ``/``) pass through unchanged.
|
||||
Unknown aliases raise ``ModelYAMLError``.
|
||||
"""
|
||||
expanded: List[str] = []
|
||||
seen: set = set()
|
||||
for entry in attachments:
|
||||
if "/" in entry:
|
||||
if entry not in seen:
|
||||
expanded.append(entry)
|
||||
seen.add(entry)
|
||||
continue
|
||||
if entry not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ModelYAMLError(
|
||||
f"{source}: unknown attachment alias '{entry}'. "
|
||||
f"Valid aliases: {valid}. "
|
||||
"(Or use a raw MIME type like 'image/png'.)"
|
||||
)
|
||||
for mime in aliases[entry]:
|
||||
if mime not in seen:
|
||||
expanded.append(mime)
|
||||
seen.add(mime)
|
||||
return expanded
|
||||
|
||||
|
||||
def _load_defaults(directory: Path) -> Dict[str, List[str]]:
|
||||
"""Load ``_defaults.yaml`` from ``directory`` if it exists."""
|
||||
path = directory / DEFAULTS_FILENAME
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
|
||||
try:
|
||||
parsed = _DefaultsFile.model_validate(raw)
|
||||
except Exception as e:
|
||||
raise ModelYAMLError(f"{path}: schema error: {e}") from e
|
||||
return parsed.attachment_aliases
|
||||
|
||||
|
||||
def _resolve_provider_enum(name: str, source: Path) -> ModelProvider:
|
||||
try:
|
||||
return ModelProvider(name)
|
||||
except ValueError as e:
|
||||
valid = ", ".join(p.value for p in ModelProvider)
|
||||
raise ModelYAMLError(
|
||||
f"{source}: unknown provider '{name}'. Valid: {valid}"
|
||||
) from e
|
||||
|
||||
|
||||
def _build_model(
|
||||
entry: _ModelEntry,
|
||||
defaults: _CapabilityFields,
|
||||
provider: ModelProvider,
|
||||
aliases: Dict[str, List[str]],
|
||||
source: Path,
|
||||
display_provider: Optional[str] = None,
|
||||
) -> AvailableModel:
|
||||
"""Merge defaults + per-model overrides into a final ``AvailableModel``."""
|
||||
|
||||
def pick(field_name: str, fallback):
|
||||
v = getattr(entry, field_name)
|
||||
if v is not None:
|
||||
return v
|
||||
d = getattr(defaults, field_name)
|
||||
if d is not None:
|
||||
return d
|
||||
return fallback
|
||||
|
||||
raw_attachments = entry.attachments
|
||||
if raw_attachments is None:
|
||||
raw_attachments = defaults.attachments
|
||||
if raw_attachments is None:
|
||||
raw_attachments = []
|
||||
expanded = _expand_attachments(
|
||||
raw_attachments, aliases, f"{source} [model={entry.id}]"
|
||||
)
|
||||
|
||||
caps = ModelCapabilities(
|
||||
supports_tools=pick("supports_tools", False),
|
||||
supports_structured_output=pick("supports_structured_output", False),
|
||||
supports_streaming=pick("supports_streaming", True),
|
||||
supported_attachment_types=expanded,
|
||||
context_window=pick("context_window", 128000),
|
||||
input_cost_per_token=pick("input_cost_per_token", None),
|
||||
output_cost_per_token=pick("output_cost_per_token", None),
|
||||
)
|
||||
|
||||
return AvailableModel(
|
||||
id=entry.id,
|
||||
provider=provider,
|
||||
display_name=entry.display_name or entry.id,
|
||||
description=entry.description,
|
||||
capabilities=caps,
|
||||
enabled=entry.enabled,
|
||||
base_url=entry.base_url,
|
||||
display_provider=display_provider,
|
||||
)
|
||||
|
||||
|
||||
def _load_one_yaml(
|
||||
path: Path, aliases: Dict[str, List[str]]
|
||||
) -> ProviderCatalog:
|
||||
try:
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError as e:
|
||||
raise ModelYAMLError(f"{path}: invalid YAML: {e}") from e
|
||||
try:
|
||||
parsed = _ProviderFile.model_validate(raw)
|
||||
except Exception as e:
|
||||
raise ModelYAMLError(f"{path}: schema error: {e}") from e
|
||||
|
||||
provider_enum = _resolve_provider_enum(parsed.provider, path)
|
||||
models = [
|
||||
_build_model(
|
||||
entry,
|
||||
parsed.defaults,
|
||||
provider_enum,
|
||||
aliases,
|
||||
path,
|
||||
display_provider=parsed.display_provider,
|
||||
)
|
||||
for entry in parsed.models
|
||||
]
|
||||
|
||||
return ProviderCatalog(
|
||||
provider=parsed.provider,
|
||||
models=models,
|
||||
source_path=path,
|
||||
display_provider=parsed.display_provider,
|
||||
api_key_env=parsed.api_key_env,
|
||||
base_url=parsed.base_url,
|
||||
)
|
||||
|
||||
|
||||
_BUILTIN_ALIASES_CACHE: Optional[Dict[str, List[str]]] = None
|
||||
|
||||
|
||||
def builtin_attachment_aliases() -> Dict[str, List[str]]:
|
||||
"""Return the built-in attachment alias map from ``_defaults.yaml``.
|
||||
|
||||
Cached after first read so repeat calls are cheap.
|
||||
"""
|
||||
global _BUILTIN_ALIASES_CACHE
|
||||
if _BUILTIN_ALIASES_CACHE is None:
|
||||
_BUILTIN_ALIASES_CACHE = _load_defaults(BUILTIN_MODELS_DIR)
|
||||
return _BUILTIN_ALIASES_CACHE
|
||||
|
||||
|
||||
def resolve_attachment_alias(alias: str) -> List[str]:
|
||||
"""Resolve a single attachment alias (e.g. ``"image"``) to its
|
||||
canonical MIME-type list. Raises ``ModelYAMLError`` if unknown.
|
||||
"""
|
||||
aliases = builtin_attachment_aliases()
|
||||
if alias not in aliases:
|
||||
valid = ", ".join(sorted(aliases.keys())) or "<none defined>"
|
||||
raise ModelYAMLError(
|
||||
f"Unknown attachment alias '{alias}'. Valid: {valid}"
|
||||
)
|
||||
return list(aliases[alias])
|
||||
|
||||
|
||||
def load_model_yamls(directories: Sequence[Path]) -> List[ProviderCatalog]:
|
||||
"""Load every ``*.yaml`` file (excluding ``_defaults.yaml``) under each
|
||||
directory in order and return a flat list of catalogs.
|
||||
|
||||
Caller is responsible for merging multiple catalogs that target the
|
||||
same provider plugin. The flat-list shape lets ``openai_compatible``
|
||||
keep each file separate (one logical endpoint per file).
|
||||
|
||||
When the same model ``id`` appears in more than one YAML across the
|
||||
directory list, a warning is logged. Order in the returned list
|
||||
preserves load order, so the registry's "later wins" merge gives the
|
||||
later directory's definition.
|
||||
"""
|
||||
catalogs: List[ProviderCatalog] = []
|
||||
seen_ids: Dict[str, Path] = {}
|
||||
|
||||
aliases: Dict[str, List[str]] = {}
|
||||
for d in directories:
|
||||
if not d or not d.exists():
|
||||
continue
|
||||
aliases.update(_load_defaults(d))
|
||||
|
||||
for d in directories:
|
||||
if not d or not d.exists():
|
||||
continue
|
||||
for path in sorted(d.glob("*.yaml")):
|
||||
if path.name == DEFAULTS_FILENAME:
|
||||
continue
|
||||
catalog = _load_one_yaml(path, aliases)
|
||||
catalogs.append(catalog)
|
||||
for m in catalog.models:
|
||||
prior = seen_ids.get(m.id)
|
||||
if prior is not None and prior != path:
|
||||
logger.warning(
|
||||
"Model id %r redefined: %s overrides %s (later wins)",
|
||||
m.id,
|
||||
path,
|
||||
prior,
|
||||
)
|
||||
seen_ids[m.id] = path
|
||||
|
||||
return catalogs
|
||||
@@ -1,213 +0,0 @@
|
||||
# Model catalogs
|
||||
|
||||
Each `*.yaml` file in this directory declares one provider's model
|
||||
catalog. The registry loads every YAML at boot and joins it to the
|
||||
matching provider plugin under `application/llm/providers/`.
|
||||
|
||||
To add or edit models, you almost always only touch a YAML here — no
|
||||
Python code required.
|
||||
|
||||
## Add a model to an existing provider
|
||||
|
||||
Open the provider's YAML (e.g. `anthropic.yaml`) and append two lines
|
||||
under `models:`:
|
||||
|
||||
```yaml
|
||||
models:
|
||||
- id: claude-3-7-sonnet
|
||||
display_name: Claude 3.7 Sonnet
|
||||
```
|
||||
|
||||
Capabilities default to the provider's `defaults:` block. Override
|
||||
per-model only when needed:
|
||||
|
||||
```yaml
|
||||
- id: claude-3-7-sonnet
|
||||
display_name: Claude 3.7 Sonnet
|
||||
context_window: 500000
|
||||
```
|
||||
|
||||
Restart the app. The new model appears in `/api/models`.
|
||||
|
||||
> The model `id` is what gets stored in agent / workflow records. Once
|
||||
> users start picking the model, **don't rename it** — agent and
|
||||
> workflow rows reference it as a free-form string and silently fall
|
||||
> back to the system default if the id disappears.
|
||||
|
||||
## Add an OpenAI-compatible provider (zero Python)
|
||||
|
||||
Drop a YAML in this directory (or in your `MODELS_CONFIG_DIR`) that uses
|
||||
the `openai_compatible` plugin. Set the env var named in `api_key_env`
|
||||
and you're done — no Python, no settings.py edit, no LLMCreator change:
|
||||
|
||||
```yaml
|
||||
# mistral.yaml
|
||||
provider: openai_compatible
|
||||
display_provider: mistral # shown in /api/models response
|
||||
api_key_env: MISTRAL_API_KEY # env var the plugin reads at boot
|
||||
base_url: https://api.mistral.ai/v1
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
```
|
||||
|
||||
`MISTRAL_API_KEY=sk-... ; restart` — Mistral models appear in
|
||||
`/api/models` with `provider: "mistral"`. They route through the OpenAI
|
||||
wire format (it's `OpenAILLM` under the hood) but with Mistral's
|
||||
endpoint and key.
|
||||
|
||||
Multiple `openai_compatible` YAMLs coexist: each file is one logical
|
||||
endpoint with its own `api_key_env` and `base_url`. Drop in
|
||||
`together.yaml`, `fireworks.yaml`, etc. side by side. If an env var
|
||||
isn't set, that catalog is silently skipped at boot (logged at INFO) —
|
||||
no error.
|
||||
|
||||
Working example: `examples/mistral.yaml.example`. Files inside
|
||||
`examples/` aren't loaded by the registry; the glob only picks up
|
||||
`*.yaml` at the top level.
|
||||
|
||||
## Add a provider with its own SDK
|
||||
|
||||
For a provider that doesn't speak OpenAI's wire format, add one Python
|
||||
file to `application/llm/providers/<name>.py`:
|
||||
|
||||
```python
|
||||
from application.llm.providers.base import Provider
|
||||
from application.llm.my_provider import MyLLM
|
||||
|
||||
class MyProvider(Provider):
|
||||
name = "my_provider"
|
||||
llm_class = MyLLM
|
||||
|
||||
def get_api_key(self, settings):
|
||||
return settings.MY_PROVIDER_API_KEY
|
||||
```
|
||||
|
||||
Register it in `application/llm/providers/__init__.py` (one line in
|
||||
`ALL_PROVIDERS`), add `MY_PROVIDER_API_KEY` to `settings.py`, and create
|
||||
`my_provider.yaml` here with the model catalog.
|
||||
|
||||
## Schema reference
|
||||
|
||||
```yaml
|
||||
provider: <string, required> # matches the Provider plugin's `name`
|
||||
|
||||
# openai_compatible only — required for that provider, ignored for others
|
||||
display_provider: <string> # label shown in /api/models response
|
||||
api_key_env: <string> # name of the env var carrying the key
|
||||
base_url: <string> # endpoint URL
|
||||
|
||||
defaults: # optional, applied to every model below
|
||||
supports_tools: bool # default false
|
||||
supports_structured_output: bool # default false
|
||||
supports_streaming: bool # default true
|
||||
attachments: [<alias-or-mime>, ...] # default []
|
||||
context_window: int # default 128000
|
||||
input_cost_per_token: float # default null
|
||||
output_cost_per_token: float # default null
|
||||
|
||||
models: # required
|
||||
- id: <string, required> # the value persisted in agent records
|
||||
display_name: <string> # default: id
|
||||
description: <string> # default: ""
|
||||
enabled: bool # default true; false hides from /api/models
|
||||
base_url: <string> # optional custom endpoint for this model
|
||||
# All `defaults:` fields above can be overridden here per-model.
|
||||
```
|
||||
|
||||
### Attachment aliases
|
||||
|
||||
The `attachments:` list can mix human-readable aliases with raw MIME
|
||||
types. Aliases are defined in `_defaults.yaml`:
|
||||
|
||||
| Alias | Expands to |
|
||||
|---|---|
|
||||
| `image` | `image/png`, `image/jpeg`, `image/jpg`, `image/webp`, `image/gif` |
|
||||
| `pdf` | `application/pdf` |
|
||||
| `audio` | `audio/mpeg`, `audio/wav`, `audio/ogg` |
|
||||
|
||||
Use raw MIME types when you need surgical control:
|
||||
|
||||
```yaml
|
||||
attachments: [image/png, image/webp] # only these two
|
||||
```
|
||||
|
||||
## Operator-supplied YAMLs (`MODELS_CONFIG_DIR`)
|
||||
|
||||
Set the `MODELS_CONFIG_DIR` env var (or `.env` entry) to a directory
|
||||
path. Every `*.yaml` in that directory is loaded **after** the built-in
|
||||
catalog under `application/core/models/`. Operators use this to:
|
||||
|
||||
- Add new `openai_compatible` providers (Mistral, Together, Fireworks,
|
||||
Ollama, ...) without forking the repo.
|
||||
- Extend an existing provider's catalog with extra models — append
|
||||
models under `provider: anthropic` and they show up alongside the
|
||||
built-ins.
|
||||
- Override a built-in model's capabilities — declare the same `id`
|
||||
with different fields (e.g. a higher `context_window`). Later wins;
|
||||
the override is logged as a `WARNING` so you can audit it.
|
||||
|
||||
Things you cannot do via `MODELS_CONFIG_DIR`:
|
||||
|
||||
- Add a brand-new non-OpenAI provider — that needs a Python plugin
|
||||
under `application/llm/providers/` (see "Add a provider with its own
|
||||
SDK" above). Operator YAMLs may only target a `provider:` value that
|
||||
already has a registered plugin.
|
||||
|
||||
### Example: Docker
|
||||
|
||||
Mount your model YAMLs into the container and point the env var at the
|
||||
mount path:
|
||||
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
services:
|
||||
app:
|
||||
image: arc53/docsgpt
|
||||
environment:
|
||||
MODELS_CONFIG_DIR: /etc/docsgpt/models
|
||||
MISTRAL_API_KEY: ${MISTRAL_API_KEY}
|
||||
volumes:
|
||||
- ./my-models:/etc/docsgpt/models:ro
|
||||
```
|
||||
|
||||
Then `./my-models/mistral.yaml` (the file from
|
||||
`examples/mistral.yaml.example`) gets picked up at boot.
|
||||
|
||||
### Example: Kubernetes
|
||||
|
||||
Mount a `ConfigMap` containing your YAMLs at a known path and set
|
||||
`MODELS_CONFIG_DIR` on the deployment. The same `examples/mistral.yaml.example`
|
||||
becomes a key in the ConfigMap.
|
||||
|
||||
### Misconfiguration
|
||||
|
||||
If `MODELS_CONFIG_DIR` is set but the path doesn't exist (or isn't a
|
||||
directory), the app logs a `WARNING` at boot and continues with just
|
||||
the built-in catalog. The app does *not* fail to start — operators can
|
||||
ship config drift without taking down the service — but the warning is
|
||||
loud enough to surface in any reasonable log aggregator.
|
||||
|
||||
## Validation
|
||||
|
||||
YAMLs are parsed with Pydantic at boot. The app fails to start with a
|
||||
clear error message if:
|
||||
|
||||
- a top-level key is unknown
|
||||
- a model is missing `id`
|
||||
- an attachment alias isn't defined
|
||||
- the `provider:` value isn't registered as a plugin
|
||||
|
||||
This is intentional — silent fallbacks would mean users don't notice
|
||||
their model picks broke until they hit the API.
|
||||
|
||||
## Reserved fields (not yet implemented)
|
||||
|
||||
- `aliases:` on a model — old IDs that resolve to this model. Reserved
|
||||
for future renames; the schema accepts the field but it is not yet
|
||||
acted on.
|
||||
@@ -1,18 +0,0 @@
|
||||
# Global defaults applied across every model YAML in this directory.
|
||||
# Keep this file sparse — per-provider `defaults:` blocks are clearer
|
||||
# than a deep global default chain. This file is for things that
|
||||
# genuinely never vary, like the meaning of "image".
|
||||
|
||||
attachment_aliases:
|
||||
image:
|
||||
- image/png
|
||||
- image/jpeg
|
||||
- image/jpg
|
||||
- image/webp
|
||||
- image/gif
|
||||
pdf:
|
||||
- application/pdf
|
||||
audio:
|
||||
- audio/mpeg
|
||||
- audio/wav
|
||||
- audio/ogg
|
||||
@@ -1,23 +0,0 @@
|
||||
provider: anthropic
|
||||
defaults:
|
||||
supports_tools: true
|
||||
attachments: [image]
|
||||
context_window: 200000
|
||||
|
||||
models:
|
||||
- id: claude-opus-4-7
|
||||
display_name: Claude Opus 4.7
|
||||
description: Most capable Claude model for complex reasoning and agentic coding
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
|
||||
- id: claude-sonnet-4-6
|
||||
display_name: Claude Sonnet 4.6
|
||||
description: Best balance of speed and intelligence with extended thinking
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
|
||||
- id: claude-haiku-4-5
|
||||
display_name: Claude Haiku 4.5
|
||||
description: Fastest Claude model with near-frontier intelligence
|
||||
supports_structured_output: true
|
||||
@@ -1,31 +0,0 @@
|
||||
# Azure OpenAI catalog.
|
||||
#
|
||||
# IMPORTANT: For Azure OpenAI, the `id` field is the **deployment name**, not
|
||||
# a model name. Deployment names are arbitrary strings the operator chooses
|
||||
# in Azure portal (or via ARM/Bicep/Terraform) when they create a deployment
|
||||
# for a given underlying model + version.
|
||||
#
|
||||
# The IDs below are sensible defaults that mirror the underlying OpenAI
|
||||
# model name (prefixed with `azure-`). Operators almost always need to
|
||||
# override them via `MODELS_CONFIG_DIR` to match the deployment names that
|
||||
# actually exist in their Azure resource. The `display_name`, capability
|
||||
# flags, and `context_window` reflect the underlying OpenAI model.
|
||||
provider: azure_openai
|
||||
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [image]
|
||||
context_window: 400000
|
||||
|
||||
models:
|
||||
- id: azure-gpt-5.5
|
||||
display_name: Azure OpenAI GPT-5.5
|
||||
description: Azure-hosted flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
|
||||
context_window: 1050000
|
||||
- id: azure-gpt-5.4-mini
|
||||
display_name: Azure OpenAI GPT-5.4 Mini
|
||||
description: Azure-hosted cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
|
||||
- id: azure-gpt-5.4-nano
|
||||
display_name: Azure OpenAI GPT-5.4 Nano
|
||||
description: Azure-hosted cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
|
||||
@@ -1,7 +0,0 @@
|
||||
provider: docsgpt
|
||||
|
||||
models:
|
||||
- id: docsgpt-local
|
||||
display_name: DocsGPT Model
|
||||
description: Local model
|
||||
supports_tools: false
|
||||
@@ -1,31 +0,0 @@
|
||||
# EXAMPLE — copy this file to ../mistral.yaml (or to your
|
||||
# MODELS_CONFIG_DIR) and set MISTRAL_API_KEY in your environment.
|
||||
#
|
||||
# This is the entire integration. No Python required: the
|
||||
# `openai_compatible` plugin reads `api_key_env` and `base_url` from
|
||||
# the file and routes calls through the OpenAI wire format.
|
||||
#
|
||||
# Files in this `examples/` directory are NOT loaded by the registry
|
||||
# (the loader globs *.yaml at the top level only).
|
||||
|
||||
provider: openai_compatible
|
||||
display_provider: mistral # shown in /api/models response
|
||||
api_key_env: MISTRAL_API_KEY # env var the plugin reads
|
||||
base_url: https://api.mistral.ai/v1 # OpenAI-compatible endpoint
|
||||
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 128000
|
||||
|
||||
models:
|
||||
- id: mistral-large-latest
|
||||
display_name: Mistral Large
|
||||
description: Top-tier reasoning model
|
||||
|
||||
- id: mistral-small-latest
|
||||
display_name: Mistral Small
|
||||
description: Fast, cost-efficient
|
||||
|
||||
- id: codestral-latest
|
||||
display_name: Codestral
|
||||
description: Code-specialized model
|
||||
@@ -1,17 +0,0 @@
|
||||
provider: google
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [pdf, image]
|
||||
context_window: 1048576
|
||||
|
||||
models:
|
||||
- id: gemini-3.1-pro-preview
|
||||
display_name: Gemini 3.1 Pro
|
||||
description: Most capable Gemini 3 model with advanced reasoning and agentic coding (preview)
|
||||
- id: gemini-3-flash-preview
|
||||
display_name: Gemini 3 Flash
|
||||
description: Frontier-class performance for low-latency, high-volume tasks (preview)
|
||||
- id: gemini-3.1-flash-lite-preview
|
||||
display_name: Gemini 3.1 Flash-Lite
|
||||
description: Cost-efficient frontier-class multimodal model for high-throughput workloads (preview)
|
||||
@@ -1,16 +0,0 @@
|
||||
provider: groq
|
||||
defaults:
|
||||
supports_tools: true
|
||||
context_window: 131072
|
||||
|
||||
models:
|
||||
- id: openai/gpt-oss-120b
|
||||
display_name: GPT-OSS 120B
|
||||
description: OpenAI's open-weight 120B flagship served on Groq's LPU hardware; strong general reasoning with strict structured output support
|
||||
supports_structured_output: true
|
||||
- id: llama-3.3-70b-versatile
|
||||
display_name: Llama 3.3 70B Versatile
|
||||
description: Meta's Llama 3.3 70B for general-purpose chat with parallel tool use
|
||||
- id: llama-3.1-8b-instant
|
||||
display_name: Llama 3.1 8B Instant
|
||||
description: Small, very low-latency Llama model (~560 tok/s) with parallel tool use
|
||||
@@ -1,7 +0,0 @@
|
||||
provider: huggingface
|
||||
|
||||
models:
|
||||
- id: huggingface-local
|
||||
display_name: Hugging Face Model
|
||||
description: Local Hugging Face model
|
||||
supports_tools: false
|
||||
@@ -1,21 +0,0 @@
|
||||
provider: novita
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
|
||||
models:
|
||||
- id: deepseek/deepseek-v4-pro
|
||||
display_name: DeepSeek V4 Pro
|
||||
description: 1.6T MoE (49B active) with 1M context, hybrid CSA/HCA attention, top-tier reasoning and agentic coding
|
||||
context_window: 1048576
|
||||
|
||||
- id: moonshotai/kimi-k2.6
|
||||
display_name: Kimi K2.6
|
||||
description: 1T-parameter open-weight MoE with native vision/video, multi-step tool calling, and agentic long-horizon execution
|
||||
attachments: [image]
|
||||
context_window: 262144
|
||||
|
||||
- id: zai-org/glm-5
|
||||
display_name: GLM-5
|
||||
description: Z.AI 754B-parameter MoE with strong general reasoning, function calling, and structured output
|
||||
context_window: 202800
|
||||
@@ -1,18 +0,0 @@
|
||||
provider: openai
|
||||
defaults:
|
||||
supports_tools: true
|
||||
supports_structured_output: true
|
||||
attachments: [image]
|
||||
context_window: 400000
|
||||
|
||||
models:
|
||||
- id: gpt-5.5
|
||||
display_name: GPT-5.5
|
||||
description: Flagship frontier model for complex reasoning, coding, and agentic work with a 1M-token context window
|
||||
context_window: 1050000
|
||||
- id: gpt-5.4-mini
|
||||
display_name: GPT-5.4 Mini
|
||||
description: Cost-efficient GPT-5.4-class model for high-volume coding, computer use, and subagent workloads
|
||||
- id: gpt-5.4-nano
|
||||
display_name: GPT-5.4 Nano
|
||||
description: Cheapest GPT-5.4-class model, optimized for simple high-volume tasks where speed and cost matter most
|
||||
@@ -1,25 +0,0 @@
|
||||
provider: openrouter
|
||||
defaults:
|
||||
supports_tools: true
|
||||
attachments: [image]
|
||||
context_window: 128000
|
||||
|
||||
models:
|
||||
- id: qwen/qwen3-coder:free
|
||||
display_name: Qwen3 Coder (free)
|
||||
description: Free-tier 480B MoE coder model with strong agentic tool use; rate-limited
|
||||
context_window: 262000
|
||||
attachments: []
|
||||
|
||||
- id: deepseek/deepseek-v3.2
|
||||
display_name: DeepSeek V3.2
|
||||
description: Open-weights reasoning model, very low cost (~$0.25 in / $0.38 out per 1M)
|
||||
context_window: 131072
|
||||
attachments: []
|
||||
supports_structured_output: true
|
||||
|
||||
- id: anthropic/claude-sonnet-4.6
|
||||
display_name: Claude Sonnet 4.6 (via OpenRouter)
|
||||
description: Frontier Sonnet-class model with 1M context, vision, and extended thinking
|
||||
context_window: 1000000
|
||||
supports_structured_output: true
|
||||
24
application/core/mongo_db.py
Normal file
24
application/core/mongo_db.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from application.core.settings import settings
|
||||
from pymongo import MongoClient
|
||||
|
||||
|
||||
class MongoDB:
|
||||
_client = None
|
||||
|
||||
@classmethod
|
||||
def get_client(cls):
|
||||
"""
|
||||
Get the MongoDB client instance, creating it if necessary.
|
||||
"""
|
||||
if cls._client is None:
|
||||
cls._client = MongoClient(settings.MONGO_URI)
|
||||
return cls._client
|
||||
|
||||
@classmethod
|
||||
def close_client(cls):
|
||||
"""
|
||||
Close the MongoDB client connection.
|
||||
"""
|
||||
if cls._client is not None:
|
||||
cls._client.close()
|
||||
cls._client = None
|
||||
@@ -5,12 +5,8 @@ from typing import Optional
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
from application.core.db_uri import ( # noqa: E402
|
||||
normalize_pgvector_connection_string,
|
||||
normalize_postgres_uri,
|
||||
current_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
|
||||
|
||||
@@ -19,25 +15,19 @@ class Settings(BaseSettings):
|
||||
|
||||
AUTH_TYPE: Optional[str] = None # simple_jwt, session_jwt, or None
|
||||
LLM_PROVIDER: str = "docsgpt"
|
||||
LLM_NAME: Optional[str] = None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
LLM_NAME: Optional[str] = (
|
||||
None # if LLM_PROVIDER is openai, LLM_NAME can be gpt-4 or gpt-3.5-turbo
|
||||
)
|
||||
EMBEDDINGS_NAME: str = "huggingface_sentence-transformers/all-mpnet-base-v2"
|
||||
EMBEDDINGS_BASE_URL: Optional[str] = None # Remote embeddings API URL (OpenAI-compatible)
|
||||
EMBEDDINGS_KEY: Optional[str] = None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
# Optional directory of operator-supplied model YAMLs, loaded after the
|
||||
# built-in catalog under application/core/models/. Later wins on
|
||||
# duplicate model id. See application/core/models/README.md.
|
||||
MODELS_CONFIG_DIR: Optional[str] = None
|
||||
|
||||
EMBEDDINGS_KEY: Optional[str] = (
|
||||
None # api key for embeddings (if using openai, just copy API_KEY)
|
||||
)
|
||||
|
||||
CELERY_BROKER_URL: str = "redis://localhost:6379/0"
|
||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||
# Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
|
||||
MONGO_URI: Optional[str] = None
|
||||
# User-data Postgres DB.
|
||||
POSTGRES_URI: Optional[str] = None
|
||||
# On app startup, apply pending Alembic migrations. Default ON for dev; disable in prod if you manage schema out-of-band.
|
||||
AUTO_MIGRATE: bool = True
|
||||
# On app startup, create the target Postgres database if it's missing (requires CREATEDB privilege). Dev-friendly default.
|
||||
AUTO_CREATE_DB: bool = True
|
||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||
MONGO_DB_NAME: str = "docsgpt"
|
||||
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||
DEFAULT_MAX_HISTORY: int = 150
|
||||
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
|
||||
@@ -55,7 +45,9 @@ class Settings(BaseSettings):
|
||||
PARSE_IMAGE_REMOTE: bool = False
|
||||
DOCLING_OCR_ENABLED: bool = False # Enable OCR for docling parsers (PDF, images)
|
||||
DOCLING_OCR_ATTACHMENTS_ENABLED: bool = False # Enable OCR for docling when parsing attachments
|
||||
VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
VECTOR_STORE: str = (
|
||||
"faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" or "pgvector"
|
||||
)
|
||||
RETRIEVERS_ENABLED: list = ["classic_rag"]
|
||||
AGENT_NAME: str = "classic"
|
||||
FALLBACK_LLM_PROVIDER: Optional[str] = None # provider for fallback llm
|
||||
@@ -63,8 +55,12 @@ class Settings(BaseSettings):
|
||||
FALLBACK_LLM_API_KEY: Optional[str] = None # api key for fallback llm
|
||||
|
||||
# Google Drive integration
|
||||
GOOGLE_CLIENT_ID: Optional[str] = None # Replace with your actual Google OAuth client ID
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = None # Replace with your actual Google OAuth client secret
|
||||
GOOGLE_CLIENT_ID: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client ID
|
||||
)
|
||||
GOOGLE_CLIENT_SECRET: Optional[str] = (
|
||||
None # Replace with your actual Google OAuth client secret
|
||||
)
|
||||
CONNECTOR_REDIRECT_BASE_URI: Optional[str] = (
|
||||
"http://127.0.0.1:7091/api/connectors/callback" ##add redirect url as it is to your provider's console(gcp)
|
||||
)
|
||||
@@ -75,12 +71,8 @@ class Settings(BaseSettings):
|
||||
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||
|
||||
# Confluence Cloud integration
|
||||
CONFLUENCE_CLIENT_ID: Optional[str] = None
|
||||
CONFLUENCE_CLIENT_SECRET: Optional[str] = None
|
||||
|
||||
# GitHub source
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||
|
||||
# LLM Cache
|
||||
CACHE_REDIS_URL: str = "redis://localhost:6379/2"
|
||||
@@ -98,13 +90,16 @@ class Settings(BaseSettings):
|
||||
GROQ_API_KEY: Optional[str] = None
|
||||
HUGGINGFACE_API_KEY: Optional[str] = None
|
||||
OPEN_ROUTER_API_KEY: Optional[str] = None
|
||||
NOVITA_API_KEY: Optional[str] = None
|
||||
|
||||
OPENAI_API_BASE: Optional[str] = None # azure openai api base url
|
||||
OPENAI_API_VERSION: Optional[str] = None # azure openai api version
|
||||
AZURE_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for answering
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = None # azure deployment name for embeddings
|
||||
OPENAI_BASE_URL: Optional[str] = None # openai base url for open ai compatable models
|
||||
AZURE_EMBEDDINGS_DEPLOYMENT_NAME: Optional[str] = (
|
||||
None # azure deployment name for embeddings
|
||||
)
|
||||
OPENAI_BASE_URL: Optional[str] = (
|
||||
None # openai base url for open ai compatable models
|
||||
)
|
||||
|
||||
# elasticsearch
|
||||
ELASTIC_CLOUD_ID: Optional[str] = None # cloud id for elasticsearch
|
||||
@@ -137,10 +132,7 @@ class Settings(BaseSettings):
|
||||
QDRANT_PATH: Optional[str] = None
|
||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||
|
||||
# PGVector vectorstore config. Write the URI in whichever form you
|
||||
# prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy
|
||||
# dialect form (``postgresql+psycopg://``) are all accepted and
|
||||
# normalized internally for ``psycopg.connect()``.
|
||||
# PGVector vectorstore config
|
||||
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||
# Milvus vectorstore config
|
||||
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
||||
@@ -149,13 +141,12 @@ class Settings(BaseSettings):
|
||||
|
||||
# LanceDB vectorstore config
|
||||
LANCEDB_PATH: str = "./data/lancedb" # Path where LanceDB stores its local data
|
||||
LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors
|
||||
LANCEDB_TABLE_NAME: Optional[str] = (
|
||||
"docsgpts" # Name of the table to use for storing vectors
|
||||
)
|
||||
|
||||
FLASK_DEBUG_MODE: bool = False
|
||||
STORAGE_TYPE: str = "local" # local or s3
|
||||
|
||||
# Anonymous startup version check for security issues.
|
||||
VERSION_CHECK: bool = True
|
||||
URL_STRATEGY: str = "backend" # backend or s3
|
||||
|
||||
JWT_SECRET_KEY: str = ""
|
||||
@@ -182,16 +173,6 @@ class Settings(BaseSettings):
|
||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||
|
||||
@field_validator("POSTGRES_URI", mode="before")
|
||||
@classmethod
|
||||
def _normalize_postgres_uri_validator(cls, v):
|
||||
return normalize_postgres_uri(v)
|
||||
|
||||
@field_validator("PGVECTOR_CONNECTION_STRING", mode="before")
|
||||
@classmethod
|
||||
def _normalize_pgvector_connection_string_validator(cls, v):
|
||||
return normalize_pgvector_connection_string(v)
|
||||
|
||||
@field_validator(
|
||||
"API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
@@ -199,7 +180,6 @@ class Settings(BaseSettings):
|
||||
"GOOGLE_API_KEY",
|
||||
"GROQ_API_KEY",
|
||||
"HUGGINGFACE_API_KEY",
|
||||
"NOVITA_API_KEY",
|
||||
"EMBEDDINGS_KEY",
|
||||
"FALLBACK_LLM_API_KEY",
|
||||
"QDRANT_API_KEY",
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Gunicorn config — keeps uvicorn's access log in NCSA format."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import logging.config
|
||||
|
||||
# NCSA common log format:
|
||||
# %(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"
|
||||
# Uvicorn's access formatter exposes a ``client_addr``/``request_line``/
|
||||
# ``status_code`` trio but not the full NCSA field set, so we re-derive
|
||||
# what we can.
|
||||
_NCSA_FMT = (
|
||||
'%(client_addr)s - - [%(asctime)s] "%(request_line)s" %(status_code)s'
|
||||
)
|
||||
|
||||
logconfig_dict = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"ncsa_access": {
|
||||
"()": "uvicorn.logging.AccessFormatter",
|
||||
"fmt": _NCSA_FMT,
|
||||
"datefmt": "%d/%b/%Y:%H:%M:%S %z",
|
||||
"use_colors": False,
|
||||
},
|
||||
"default": {
|
||||
"format": "[%(asctime)s] [%(process)d] [%(levelname)s] %(name)s: %(message)s",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"access": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "ncsa_access",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"default": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "default",
|
||||
"stream": "ext://sys.stderr",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
"uvicorn": {"handlers": ["default"], "level": "INFO", "propagate": False},
|
||||
"uvicorn.error": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"uvicorn.access": {
|
||||
"handlers": ["access"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"gunicorn.error": {
|
||||
"handlers": ["default"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
"gunicorn.access": {
|
||||
"handlers": ["access"],
|
||||
"level": "INFO",
|
||||
"propagate": False,
|
||||
},
|
||||
},
|
||||
"root": {"handlers": ["default"], "level": "INFO"},
|
||||
}
|
||||
|
||||
|
||||
def on_starting(server): # pragma: no cover — gunicorn hook
|
||||
"""Ensure gunicorn's own loggers use the configured handlers."""
|
||||
logging.config.dictConfig(logconfig_dict)
|
||||
@@ -127,33 +127,15 @@ class GoogleLLM(BaseLLM):
|
||||
).uri,
|
||||
)
|
||||
|
||||
# Cache the Google file URI on the attachment row so we don't
|
||||
# re-upload on the next LLM call. Accept either a PG UUID
|
||||
# (``id``) or a legacy Mongo ObjectId (``_id``). Opened per
|
||||
# write — this runs mid-LLM-call, so we don't wrap the
|
||||
# surrounding generator in a long-lived session.
|
||||
attachment_id = attachment.get("id") or attachment.get("_id")
|
||||
if attachment_id:
|
||||
user_id = None
|
||||
decoded = getattr(self, "decoded_token", None)
|
||||
if isinstance(decoded, dict):
|
||||
user_id = decoded.get("sub")
|
||||
from application.storage.db.repositories.attachments import (
|
||||
AttachmentsRepository,
|
||||
)
|
||||
from application.storage.db.session import db_session
|
||||
from application.core.mongo_db import MongoDB
|
||||
|
||||
try:
|
||||
with db_session() as conn:
|
||||
AttachmentsRepository(conn).update_any(
|
||||
str(attachment_id),
|
||||
user_id,
|
||||
{"google_file_uri": file_uri},
|
||||
)
|
||||
except Exception as cache_err:
|
||||
logging.warning(
|
||||
f"Failed to cache google_file_uri on attachment {attachment_id}: {cache_err}"
|
||||
)
|
||||
mongo = MongoDB.get_client()
|
||||
db = mongo[settings.MONGO_DB_NAME]
|
||||
attachments_collection = db["attachments"]
|
||||
if "_id" in attachment:
|
||||
attachments_collection.update_one(
|
||||
{"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
|
||||
)
|
||||
return file_uri
|
||||
except Exception as e:
|
||||
logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
|
||||
@@ -185,8 +167,6 @@ class GoogleLLM(BaseLLM):
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
import json as _json
|
||||
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
@@ -200,66 +180,9 @@ class GoogleLLM(BaseLLM):
|
||||
|
||||
if role == "assistant":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
|
||||
# Standard format: assistant message with tool_calls array
|
||||
msg_tool_calls = message.get("tool_calls")
|
||||
if msg_tool_calls and role == "model":
|
||||
for tc in msg_tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = _json.loads(args)
|
||||
except (_json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
cleaned_args = self._remove_null_values(args)
|
||||
thought_sig = tc.get("thought_signature")
|
||||
if thought_sig:
|
||||
parts.append(
|
||||
types.Part(
|
||||
functionCall=types.FunctionCall(
|
||||
name=func.get("name", ""),
|
||||
args=cleaned_args,
|
||||
),
|
||||
thoughtSignature=thought_sig,
|
||||
)
|
||||
)
|
||||
else:
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=func.get("name", ""),
|
||||
args=cleaned_args,
|
||||
)
|
||||
)
|
||||
if parts:
|
||||
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||
continue
|
||||
|
||||
# Standard format: tool message with tool_call_id
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if role == "tool" and tool_call_id is not None:
|
||||
result_content = content
|
||||
if isinstance(result_content, str):
|
||||
try:
|
||||
result_content = _json.loads(result_content)
|
||||
except (_json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
# Google expects function_response name — extract from tool_call_id context
|
||||
# We use a placeholder name since Google API doesn't require exact match
|
||||
parts.append(
|
||||
types.Part.from_function_response(
|
||||
name="tool_result",
|
||||
response={"result": result_content},
|
||||
)
|
||||
)
|
||||
cleaned_messages.append(types.Content(role="model", parts=parts))
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
elif role == "tool":
|
||||
role = "model"
|
||||
|
||||
parts = []
|
||||
if role and content is not None:
|
||||
if isinstance(content, str):
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
@@ -268,11 +191,15 @@ class GoogleLLM(BaseLLM):
|
||||
if "text" in item:
|
||||
parts.append(types.Part.from_text(text=item["text"]))
|
||||
elif "function_call" in item:
|
||||
# Legacy format support
|
||||
# Remove null values from args to avoid API errors
|
||||
|
||||
cleaned_args = self._remove_null_values(
|
||||
item["function_call"]["args"]
|
||||
)
|
||||
# Create function call part with thought_signature if present
|
||||
# For Gemini 3 models, we need to include thought_signature
|
||||
if "thought_signature" in item:
|
||||
# Use Part constructor with functionCall and thoughtSignature
|
||||
parts.append(
|
||||
types.Part(
|
||||
functionCall=types.FunctionCall(
|
||||
@@ -283,6 +210,7 @@ class GoogleLLM(BaseLLM):
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Use helper method when no thought_signature
|
||||
parts.append(
|
||||
types.Part.from_function_call(
|
||||
name=item["function_call"]["name"],
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -316,34 +315,10 @@ class LLMHandler(ABC):
|
||||
current_prompt = self._extract_text_from_content(content)
|
||||
|
||||
elif role in {"assistant", "model"}:
|
||||
# Standard format: tool_calls array on assistant message
|
||||
msg_tool_calls = message.get("tool_calls")
|
||||
if msg_tool_calls:
|
||||
for tc in msg_tool_calls:
|
||||
call_id = tc.get("id") or str(uuid.uuid4())
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": func.get("name"),
|
||||
"arguments": args,
|
||||
"result": None,
|
||||
"status": "called",
|
||||
"call_id": call_id,
|
||||
}
|
||||
continue
|
||||
|
||||
# Legacy format: function_call/function_response in content list
|
||||
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
|
||||
if isinstance(content, list):
|
||||
has_fc = False
|
||||
for item in content:
|
||||
if "function_call" in item:
|
||||
has_fc = True
|
||||
fc = item["function_call"]
|
||||
call_id = fc.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
@@ -354,30 +329,37 @@ class LLMHandler(ABC):
|
||||
"status": "called",
|
||||
"call_id": call_id,
|
||||
}
|
||||
if has_fc:
|
||||
continue
|
||||
elif "function_response" in item:
|
||||
fr = item["function_response"]
|
||||
call_id = fr.get("call_id") or str(uuid.uuid4())
|
||||
current_tool_calls[call_id] = {
|
||||
"tool_name": "unknown_tool",
|
||||
"action_name": fr.get("name"),
|
||||
"arguments": None,
|
||||
"result": fr.get("response", {}).get("result"),
|
||||
"status": "completed",
|
||||
"call_id": call_id,
|
||||
}
|
||||
# No direct assistant text here; continue to next message
|
||||
continue
|
||||
|
||||
response_text = self._extract_text_from_content(content)
|
||||
_commit_query(response_text)
|
||||
|
||||
elif role == "tool":
|
||||
# Standard format: tool_call_id on tool message
|
||||
call_id = message.get("tool_call_id")
|
||||
# Attach tool outputs to the latest pending tool call if possible
|
||||
tool_text = self._extract_text_from_content(content)
|
||||
|
||||
# Attempt to parse function_response style
|
||||
call_id = None
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_response" in item and item["function_response"].get("call_id"):
|
||||
call_id = item["function_response"]["call_id"]
|
||||
break
|
||||
if call_id and call_id in current_tool_calls:
|
||||
current_tool_calls[call_id]["result"] = tool_text
|
||||
current_tool_calls[call_id]["status"] = "completed"
|
||||
# Legacy: function_response in content list
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if "function_response" in item:
|
||||
legacy_id = item["function_response"].get("call_id")
|
||||
if legacy_id and legacy_id in current_tool_calls:
|
||||
current_tool_calls[legacy_id]["result"] = tool_text
|
||||
current_tool_calls[legacy_id]["status"] = "completed"
|
||||
break
|
||||
elif call_id is None and queries:
|
||||
elif queries:
|
||||
queries[-1].setdefault("tool_calls", []).append(
|
||||
{
|
||||
"tool_name": "unknown_tool",
|
||||
@@ -666,13 +648,6 @@ class LLMHandler(ABC):
|
||||
"""
|
||||
Execute tool calls and update conversation history.
|
||||
|
||||
When a tool requires approval or client-side execution, it is
|
||||
collected as a pending action instead of being executed. The
|
||||
generator returns ``(updated_messages, pending_actions)`` where
|
||||
*pending_actions* is ``None`` when every tool was executed
|
||||
normally, or a list of dicts describing actions the client must
|
||||
resolve before the LLM loop can continue.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
tool_calls: List of tool calls to execute
|
||||
@@ -680,11 +655,9 @@ class LLMHandler(ABC):
|
||||
messages: Current conversation history
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_messages, pending_actions).
|
||||
pending_actions is None if all tools executed, otherwise a list.
|
||||
Updated messages list
|
||||
"""
|
||||
updated_messages = messages.copy()
|
||||
pending_actions: List[Dict] = []
|
||||
|
||||
for i, call in enumerate(tool_calls):
|
||||
# Check context limit before executing tool call
|
||||
@@ -790,29 +763,6 @@ class LLMHandler(ABC):
|
||||
# Set flag on agent
|
||||
agent.context_limit_reached = True
|
||||
break
|
||||
|
||||
# ---- Pause check: approval / client-side execution ----
|
||||
llm_class = agent.llm.__class__.__name__
|
||||
pause_info = agent.tool_executor.check_pause(
|
||||
tools_dict, call, llm_class
|
||||
)
|
||||
if pause_info:
|
||||
# Yield pause event so the client knows this tool is waiting
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
"tool_name": pause_info["tool_name"],
|
||||
"call_id": pause_info["call_id"],
|
||||
"action_name": pause_info.get("llm_name", pause_info["name"]),
|
||||
"arguments": pause_info["arguments"],
|
||||
"status": pause_info["pause_type"],
|
||||
},
|
||||
}
|
||||
pending_actions.append(pause_info)
|
||||
# Do NOT add messages for pending tools here.
|
||||
# They will be added on resume to keep call/result pairs together.
|
||||
continue
|
||||
|
||||
try:
|
||||
self.tool_calls.append(call)
|
||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||
@@ -822,30 +772,25 @@ class LLMHandler(ABC):
|
||||
except StopIteration as e:
|
||||
tool_response, call_id = e.value
|
||||
break
|
||||
|
||||
# Standard internal format: assistant message with tool_calls array
|
||||
args_str = (
|
||||
json.dumps(call.arguments)
|
||||
if isinstance(call.arguments, dict)
|
||||
else call.arguments
|
||||
)
|
||||
tool_call_obj = {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
||||
function_call_content = {
|
||||
"function_call": {
|
||||
"name": call.name,
|
||||
"arguments": args_str,
|
||||
},
|
||||
"args": call.arguments,
|
||||
"call_id": call_id,
|
||||
}
|
||||
}
|
||||
# Preserve thought_signature for Google Gemini 3 models
|
||||
# Include thought_signature for Google Gemini 3 models
|
||||
# It should be at the same level as function_call, not inside it
|
||||
if call.thought_signature:
|
||||
tool_call_obj["thought_signature"] = call.thought_signature
|
||||
function_call_content["thought_signature"] = call.thought_signature
|
||||
updated_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [function_call_content],
|
||||
}
|
||||
)
|
||||
|
||||
updated_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [tool_call_obj],
|
||||
})
|
||||
|
||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||
except Exception as e:
|
||||
@@ -857,15 +802,16 @@ class LLMHandler(ABC):
|
||||
error_message = self.create_tool_message(error_call, error_response)
|
||||
updated_messages.append(error_message)
|
||||
|
||||
mapping = agent.tool_executor._name_to_tool
|
||||
if call.name in mapping:
|
||||
resolved_tool_id, _ = mapping[call.name]
|
||||
tool_name = tools_dict.get(resolved_tool_id, {}).get(
|
||||
"name", "unknown_tool"
|
||||
)
|
||||
call_parts = call.name.split("_")
|
||||
if len(call_parts) >= 2:
|
||||
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
||||
action_name = "_".join(call_parts[:-1])
|
||||
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
||||
full_action_name = f"{action_name}_{tool_id}"
|
||||
else:
|
||||
tool_name = "unknown_tool"
|
||||
full_action_name = call.name
|
||||
action_name = call.name
|
||||
full_action_name = call.name
|
||||
yield {
|
||||
"type": "tool_call",
|
||||
"data": {
|
||||
@@ -877,7 +823,7 @@ class LLMHandler(ABC):
|
||||
"status": "error",
|
||||
},
|
||||
}
|
||||
return updated_messages, pending_actions if pending_actions else None
|
||||
return updated_messages
|
||||
|
||||
def handle_non_streaming(
|
||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||
@@ -905,22 +851,8 @@ class LLMHandler(ABC):
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages, pending_actions = e.value
|
||||
messages = e.value
|
||||
break
|
||||
|
||||
# If tools need approval or client execution, pause the loop
|
||||
if pending_actions:
|
||||
agent._pending_continuation = {
|
||||
"messages": messages,
|
||||
"pending_tool_calls": pending_actions,
|
||||
"tools_dict": tools_dict,
|
||||
}
|
||||
yield {
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": pending_actions},
|
||||
}
|
||||
return ""
|
||||
|
||||
response = agent.llm.gen(
|
||||
model=agent.model_id, messages=messages, tools=agent.tools
|
||||
)
|
||||
@@ -981,23 +913,10 @@ class LLMHandler(ABC):
|
||||
try:
|
||||
yield next(tool_handler_gen)
|
||||
except StopIteration as e:
|
||||
messages, pending_actions = e.value
|
||||
messages = e.value
|
||||
break
|
||||
tool_calls = {}
|
||||
|
||||
# If tools need approval or client execution, pause the loop
|
||||
if pending_actions:
|
||||
agent._pending_continuation = {
|
||||
"messages": messages,
|
||||
"pending_tool_calls": pending_actions,
|
||||
"tools_dict": tools_dict,
|
||||
}
|
||||
yield {
|
||||
"type": "tool_calls_pending",
|
||||
"data": {"pending_tool_calls": pending_actions},
|
||||
}
|
||||
return
|
||||
|
||||
# Check if context limit was reached during tool execution
|
||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||
# Add system message warning about context limit
|
||||
|
||||
@@ -67,18 +67,18 @@ class GoogleLLMHandler(LLMHandler):
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
"""Create Google-style tool message."""
|
||||
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": content,
|
||||
"role": "model",
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
|
||||
@@ -7,7 +7,6 @@ class LLMHandlerCreator:
|
||||
handlers = {
|
||||
"openai": OpenAILLMHandler,
|
||||
"google": GoogleLLMHandler,
|
||||
"novita": OpenAILLMHandler, # Novita uses OpenAI-compatible API
|
||||
"default": OpenAILLMHandler,
|
||||
}
|
||||
|
||||
|
||||
@@ -37,18 +37,18 @@ class OpenAILLMHandler(LLMHandler):
|
||||
)
|
||||
|
||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||
"""Create a tool result message in the standard internal format."""
|
||||
import json as _json
|
||||
|
||||
content = (
|
||||
_json.dumps(result)
|
||||
if not isinstance(result, str)
|
||||
else result
|
||||
)
|
||||
"""Create OpenAI-style tool message."""
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"content": content,
|
||||
"content": [
|
||||
{
|
||||
"function_response": {
|
||||
"name": tool_call.name,
|
||||
"response": {"result": result},
|
||||
"call_id": tool_call.id,
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _iterate_stream(self, response: Any) -> Generator:
|
||||
|
||||
@@ -1,11 +1,34 @@
|
||||
import logging
|
||||
|
||||
from application.llm.providers import PROVIDERS_BY_NAME
|
||||
from application.llm.anthropic import AnthropicLLM
|
||||
from application.llm.docsgpt_provider import DocsGPTAPILLM
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
from application.llm.groq import GroqLLM
|
||||
from application.llm.llama_cpp import LlamaCpp
|
||||
from application.llm.novita import NovitaLLM
|
||||
from application.llm.openai import AzureOpenAILLM, OpenAILLM
|
||||
from application.llm.premai import PremAILLM
|
||||
from application.llm.sagemaker import SagemakerAPILLM
|
||||
from application.llm.open_router import OpenRouterLLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMCreator:
|
||||
llms = {
|
||||
"openai": OpenAILLM,
|
||||
"azure_openai": AzureOpenAILLM,
|
||||
"sagemaker": SagemakerAPILLM,
|
||||
"llama.cpp": LlamaCpp,
|
||||
"anthropic": AnthropicLLM,
|
||||
"docsgpt": DocsGPTAPILLM,
|
||||
"premai": PremAILLM,
|
||||
"groq": GroqLLM,
|
||||
"google": GoogleLLM,
|
||||
"novita": NovitaLLM,
|
||||
"openrouter": OpenRouterLLM,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def create_llm(
|
||||
cls,
|
||||
@@ -19,27 +42,18 @@ class LLMCreator:
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
from application.core.model_registry import ModelRegistry
|
||||
from application.core.model_utils import get_base_url_for_model
|
||||
|
||||
plugin = PROVIDERS_BY_NAME.get(type.lower())
|
||||
if plugin is None or plugin.llm_class is None:
|
||||
llm_class = cls.llms.get(type.lower())
|
||||
if not llm_class:
|
||||
raise ValueError(f"No LLM class found for type {type}")
|
||||
|
||||
# Prefer per-model endpoint config from the registry. This is what
|
||||
# makes openai_compatible (and the future end-user BYOM phase)
|
||||
# work without changing every call site: if the registered
|
||||
# AvailableModel carries its own api_key / base_url, they win
|
||||
# over whatever the caller resolved via the provider plugin.
|
||||
# Extract base_url from model configuration if model_id is provided
|
||||
base_url = None
|
||||
if model_id:
|
||||
model = ModelRegistry.get_instance().get_model(model_id)
|
||||
if model is not None:
|
||||
if model.api_key:
|
||||
api_key = model.api_key
|
||||
if model.base_url:
|
||||
base_url = model.base_url
|
||||
base_url = get_base_url_for_model(model_id)
|
||||
|
||||
return plugin.llm_class(
|
||||
return llm_class(
|
||||
api_key,
|
||||
user_api_key,
|
||||
decoded_token=decoded_token,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user