mirror of
https://github.com/arc53/DocsGPT.git
synced 2026-05-10 12:31:21 +00:00
Compare commits
129 Commits
tests-util
...
codex/draf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c18f85a050 | ||
|
|
5ecb174567 | ||
|
|
ed7212d016 | ||
|
|
f82acdab5d | ||
|
|
361aebc34c | ||
|
|
bf194c1a0f | ||
|
|
54c396750b | ||
|
|
9adebfec69 | ||
|
|
92c321f163 | ||
|
|
e3d36b9e52 | ||
|
|
8950e11208 | ||
|
|
5de0132a65 | ||
|
|
b92ca91512 | ||
|
|
8e0b2844a2 | ||
|
|
0969db5e30 | ||
|
|
af335a27e8 | ||
|
|
0ae3139284 | ||
|
|
7529ca3dd6 | ||
|
|
1b813320f1 | ||
|
|
02012e9a0b | ||
|
|
c2f027265a | ||
|
|
0ae615c10e | ||
|
|
881d0da344 | ||
|
|
1376de6bae | ||
|
|
362ebfcc0a | ||
|
|
bc77eed3d8 | ||
|
|
1f346588e7 | ||
|
|
2fed5c882b | ||
|
|
aa938d76d7 | ||
|
|
2940628aa6 | ||
|
|
7f23928134 | ||
|
|
20e17c84c7 | ||
|
|
389ddf6068 | ||
|
|
1e2443fb90 | ||
|
|
6387bd1892 | ||
|
|
7d22724d1c | ||
|
|
f6f12f6895 | ||
|
|
934127f323 | ||
|
|
1780e3cc91 | ||
|
|
5e7fab2f34 | ||
|
|
92ae76f95e | ||
|
|
18755bdd9b | ||
|
|
0f20adcbf4 | ||
|
|
18e2a829c9 | ||
|
|
cd44501a71 | ||
|
|
f8ebdf3fd4 | ||
|
|
7c6fca18ad | ||
|
|
5fab798707 | ||
|
|
cb30a24e05 | ||
|
|
530761d08c | ||
|
|
73fbc28744 | ||
|
|
b5b6538762 | ||
|
|
a9761061fc | ||
|
|
9388996a15 | ||
|
|
875868b7e5 | ||
|
|
502819ae52 | ||
|
|
cada1a44fc | ||
|
|
6192767451 | ||
|
|
5c3e6eca54 | ||
|
|
59d9d4ac50 | ||
|
|
3931ccccee | ||
|
|
55717043f6 | ||
|
|
ececcb8b17 | ||
|
|
420e9d3dd5 | ||
|
|
749eed3d0b | ||
|
|
bd03a513e3 | ||
|
|
fcdb4fb5e8 | ||
|
|
e787c896eb | ||
|
|
23aeaff5db | ||
|
|
689dd79597 | ||
|
|
0c15af90b1 | ||
|
|
cdd6ff6557 | ||
|
|
72b3d94453 | ||
|
|
7e88d09e5d | ||
|
|
74a4a237dc | ||
|
|
c3f01c6619 | ||
|
|
6b408823d4 | ||
|
|
3fc81ac5d8 | ||
|
|
2652f8a5b0 | ||
|
|
d711eefe96 | ||
|
|
79206f3919 | ||
|
|
de971d9452 | ||
|
|
1b4d5ca0dd | ||
|
|
81989e8258 | ||
|
|
dc262d1698 | ||
|
|
69f9c93869 | ||
|
|
74bf80b25c | ||
|
|
d9a92a7208 | ||
|
|
02e93d993d | ||
|
|
6b6495f48c | ||
|
|
249dd9ce37 | ||
|
|
9134ab0478 | ||
|
|
10ef68c9d0 | ||
|
|
7d65cf1c2b | ||
|
|
13c6cc59c1 | ||
|
|
6381f7dd4e | ||
|
|
e6ac4008fe | ||
|
|
1af09f114d | ||
|
|
be7da983e7 | ||
|
|
8b9e595d85 | ||
|
|
398f3acc8d | ||
|
|
e04baa7ed8 | ||
|
|
e5586b6f20 | ||
|
|
addf57cab7 | ||
|
|
648b3f1d20 | ||
|
|
a75a9e23f9 | ||
|
|
73256389cf | ||
|
|
d609efca49 | ||
|
|
772860b667 | ||
|
|
ea2fd8b04a | ||
|
|
2c73deac20 | ||
|
|
47f3907e5e | ||
|
|
727495c553 | ||
|
|
a3b08a5b44 | ||
|
|
81532ada2a | ||
|
|
43f71374e5 | ||
|
|
d5c0322e2a | ||
|
|
3b66a3176c | ||
|
|
dc6db847ca | ||
|
|
ed0063aada | ||
|
|
9a6a55b6da | ||
|
|
12a8368216 | ||
|
|
3f6d6f15ea | ||
|
|
193ca6fd63 | ||
|
|
174dee0fe6 | ||
|
|
844167ba06 | ||
|
|
6fa3acb1ca | ||
|
|
9fd063266b | ||
|
|
324a8cd4cf |
@@ -34,3 +34,9 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
|
|||||||
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
|
||||||
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
|
||||||
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
|
||||||
|
|
||||||
|
# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
|
||||||
|
# Standard Postgres URI — `postgres://` and `postgresql://` both work.
|
||||||
|
# Leave unset while the migration is still being rolled out; the app will
|
||||||
|
# fall back to MongoDB for user data until POSTGRES_URI is configured.
|
||||||
|
# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
|
||||||
|
|||||||
99
.github/INCIDENT_RESPONSE.md
vendored
Normal file
99
.github/INCIDENT_RESPONSE.md
vendored
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# DocsGPT Incident Response Plan (IRP)
|
||||||
|
|
||||||
|
This playbook describes how maintainers respond to confirmed or suspected security incidents.
|
||||||
|
|
||||||
|
- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
|
||||||
|
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)
|
||||||
|
|
||||||
|
## Severity
|
||||||
|
|
||||||
|
| Severity | Definition | Typical examples |
|
||||||
|
|---|---|---|
|
||||||
|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
|
||||||
|
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | key leakage; prompt injection enabling cross-tenant access. |
|
||||||
|
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
|
||||||
|
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |
|
||||||
|
|
||||||
|
|
||||||
|
## Response workflow
|
||||||
|
|
||||||
|
### 1) Triage (target: initial response within 48 hours)
|
||||||
|
|
||||||
|
1. Acknowledge report.
|
||||||
|
2. Validate on latest release and `main`.
|
||||||
|
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
|
||||||
|
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
|
||||||
|
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.
|
||||||
|
|
||||||
|
### 2) Investigation
|
||||||
|
|
||||||
|
1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
|
||||||
|
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
|
||||||
|
3. Request a CVE through GHSA for **Medium+** issues.
|
||||||
|
|
||||||
|
### 3) Containment, fix, and disclosure
|
||||||
|
|
||||||
|
1. Implement and test fix in private security workflow (GHSA private fork/branch).
|
||||||
|
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
|
||||||
|
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
|
||||||
|
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
|
||||||
|
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
|
||||||
|
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.
|
||||||
|
|
||||||
|
### 4) Post-incident
|
||||||
|
|
||||||
|
1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
|
||||||
|
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
|
||||||
|
3. Track follow-up hardening actions with owners/dates.
|
||||||
|
4. Update this IRP and related runbooks as needed.
|
||||||
|
|
||||||
|
## Scenario playbooks
|
||||||
|
|
||||||
|
### Supply-chain compromise
|
||||||
|
|
||||||
|
1. Freeze releases and investigate blast radius.
|
||||||
|
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
|
||||||
|
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
|
||||||
|
4. Publish advisory with exact affected versions and required user actions.
|
||||||
|
|
||||||
|
### Data exposure
|
||||||
|
|
||||||
|
1. Determine scope (users, documents, keys, logs, time window).
|
||||||
|
2. Disable affected path or hotfix immediately for managed cloud.
|
||||||
|
3. Notify affected users with concrete remediation steps (for example, rotate keys).
|
||||||
|
4. Continue through standard fix/disclosure workflow.
|
||||||
|
|
||||||
|
### Critical regression with security impact
|
||||||
|
|
||||||
|
1. Identify introducing change (`git bisect` if needed).
|
||||||
|
2. Publish workaround within 24 hours (for example, pin to known-good version).
|
||||||
|
3. Ship patch release with regression test and close incident with public summary.
|
||||||
|
|
||||||
|
## AI-specific guidance
|
||||||
|
|
||||||
|
Treat confirmed AI-specific abuse as security incidents:
|
||||||
|
|
||||||
|
- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
|
||||||
|
- Cross-tenant retrieval/isolation failure -> **High**
|
||||||
|
- API key disclosure in output -> **High**
|
||||||
|
|
||||||
|
## Secret rotation quick reference
|
||||||
|
|
||||||
|
| Secret | Standard rotation action |
|
||||||
|
|---|---|
|
||||||
|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
|
||||||
|
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
|
||||||
|
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
|
||||||
|
| Database credentials | Rotate in DB platform; redeploy with new secrets |
|
||||||
|
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
|
||||||
|
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
|
||||||
|
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |
|
||||||
|
|
||||||
|
## Maintenance
|
||||||
|
|
||||||
|
Review this document:
|
||||||
|
|
||||||
|
- after every **Critical/High** incident, and
|
||||||
|
- at least annually.
|
||||||
|
|
||||||
|
Changes should be proposed via pull request to `main`.
|
||||||
144
.github/THREAT_MODEL.md
vendored
Normal file
144
.github/THREAT_MODEL.md
vendored
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
# DocsGPT Public Threat Model
|
||||||
|
|
||||||
|
**Classification:** Public
|
||||||
|
**Last updated:** 2026-04-15
|
||||||
|
**Applies to:** Open-source and self-hosted DocsGPT deployments
|
||||||
|
|
||||||
|
## 1) Overview
|
||||||
|
|
||||||
|
DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.
|
||||||
|
|
||||||
|
Core components:
|
||||||
|
- Backend API (`application/`)
|
||||||
|
- Workers/ingestion (`application/worker.py` and related modules)
|
||||||
|
- Datastores (MongoDB/Redis/vector stores)
|
||||||
|
- Frontend (`frontend/`)
|
||||||
|
- Optional extensions/integrations (`extensions/`)
|
||||||
|
|
||||||
|
## 2) Scope and assumptions
|
||||||
|
|
||||||
|
In scope:
|
||||||
|
- Application-level threats in this repository.
|
||||||
|
- Local and internet-exposed self-hosted deployments.
|
||||||
|
|
||||||
|
Assumptions:
|
||||||
|
- Internet-facing instances enable auth and use strong secrets.
|
||||||
|
- Datastores/internal services are not publicly exposed.
|
||||||
|
|
||||||
|
Out of scope:
|
||||||
|
- Cloud hardware/provider compromise.
|
||||||
|
- Security guarantees of external LLM vendors.
|
||||||
|
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).
|
||||||
|
|
||||||
|
## 3) Security objectives
|
||||||
|
|
||||||
|
- Protect document/conversation confidentiality.
|
||||||
|
- Preserve integrity of prompts, agents, tools, and indexed data.
|
||||||
|
- Maintain API/worker availability.
|
||||||
|
- Enforce tenant isolation in authenticated deployments.
|
||||||
|
|
||||||
|
## 4) Assets
|
||||||
|
|
||||||
|
- Documents, attachments, chunks/embeddings, summaries.
|
||||||
|
- Conversations, agents, workflows, prompt templates.
|
||||||
|
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
|
||||||
|
- Operational capacity (worker throughput, queue depth, model quota/cost).
|
||||||
|
|
||||||
|
## 5) Trust boundaries and untrusted input
|
||||||
|
|
||||||
|
Trust boundaries:
|
||||||
|
- Internet ↔ Frontend
|
||||||
|
- Frontend ↔ Backend API
|
||||||
|
- Backend ↔ Workers/internal APIs
|
||||||
|
- Backend/workers ↔ Datastores
|
||||||
|
- Backend ↔ External LLM/connectors/remote URLs
|
||||||
|
|
||||||
|
Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.
|
||||||
|
|
||||||
|
## 6) Main attack surfaces
|
||||||
|
|
||||||
|
1. Auth/authz paths and sharing tokens.
|
||||||
|
2. File upload + parsing pipeline.
|
||||||
|
3. Remote URL fetching and connectors (SSRF risk).
|
||||||
|
4. Agent/tool execution from LLM output.
|
||||||
|
5. Template/workflow rendering.
|
||||||
|
6. Frontend rendering + token storage.
|
||||||
|
7. Internal service endpoints (`INTERNAL_KEY`).
|
||||||
|
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).
|
||||||
|
|
||||||
|
## 7) Key threats and expected mitigations
|
||||||
|
|
||||||
|
### A. Auth/authz misconfiguration
|
||||||
|
- Threat: weak/no auth or leaked tokens leads to broad data access.
|
||||||
|
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.
|
||||||
|
|
||||||
|
### B. Untrusted file ingestion
|
||||||
|
- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
|
||||||
|
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.
|
||||||
|
|
||||||
|
### C. SSRF/outbound abuse
|
||||||
|
- Threat: URL loaders/tools access private/internal/metadata endpoints.
|
||||||
|
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists.
|
||||||
|
|
||||||
|
### D. Prompt injection + tool abuse
|
||||||
|
- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
|
||||||
|
- Threat: never rely on the model to "choose correctly" under adversarial input.
|
||||||
|
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.
|
||||||
|
|
||||||
|
### E. Dangerous tool capability chaining (SQL/API/MCP)
|
||||||
|
- Threat: write-capable SQL credentials allow destructive queries.
|
||||||
|
- Threat: API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
|
||||||
|
- Threat: remote MCP tools may expose privileged operations.
|
||||||
|
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging.
|
||||||
|
|
||||||
|
### F. Frontend/XSS + token theft
|
||||||
|
- Threat: XSS can steal local tokens and call APIs.
|
||||||
|
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.
|
||||||
|
|
||||||
|
### G. Internal endpoint exposure
|
||||||
|
- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
|
||||||
|
- Mitigations: fail closed, require strong random keys, keep internal APIs private.
|
||||||
|
|
||||||
|
### H. DoS and cost abuse
|
||||||
|
- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
|
||||||
|
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.
|
||||||
|
|
||||||
|
## 8) Example attacker stories
|
||||||
|
|
||||||
|
- Internet-exposed deployment runs with weak/no auth and receives unauthorized data access/abuse.
|
||||||
|
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
|
||||||
|
- Crafted archive attempts path traversal during extraction.
|
||||||
|
- Malicious URL/redirect chain targets internal services.
|
||||||
|
- Poisoned document causes data exfiltration through tool calls.
|
||||||
|
- Over-privileged SQL/API/MCP tool performs destructive side effects.
|
||||||
|
|
||||||
|
## 9) Severity calibration
|
||||||
|
|
||||||
|
- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
|
||||||
|
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
|
||||||
|
- **Medium:** DoS/cost amplification and non-critical information disclosure.
|
||||||
|
- **Low:** minor hardening gaps with limited impact.
|
||||||
|
|
||||||
|
## 10) Baseline controls for public deployments
|
||||||
|
|
||||||
|
1. Enforce authentication and secure defaults.
|
||||||
|
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys).
|
||||||
|
3. Restrict CORS and front API with a hardened proxy.
|
||||||
|
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
|
||||||
|
5. Enforce URL+redirect SSRF protections and egress restrictions.
|
||||||
|
6. Apply upload/archive/parsing hardening.
|
||||||
|
7. Require least-privilege tool credentials and auditable tool execution.
|
||||||
|
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
|
||||||
|
9. Keep dependencies/images patched and scanned.
|
||||||
|
10. Validate multi-tenant isolation with explicit tests.
|
||||||
|
|
||||||
|
## 11) Maintenance
|
||||||
|
|
||||||
|
Review this model after major auth, ingestion, connector, tool, or workflow changes.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
|
||||||
|
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
|
||||||
|
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
|
||||||
|
- [DocsGPT SECURITY.md](../SECURITY.md)
|
||||||
@@ -1,46 +1,80 @@
|
|||||||
Ollama
|
Agentic
|
||||||
Qdrant
|
Anthropic's
|
||||||
Milvus
|
api
|
||||||
Chatwoot
|
|
||||||
Nextra
|
|
||||||
VSCode
|
|
||||||
npm
|
|
||||||
LLMs
|
|
||||||
APIs
|
APIs
|
||||||
Groq
|
Atlassian
|
||||||
SGLang
|
automations
|
||||||
LMDeploy
|
autoescaping
|
||||||
OAuth
|
Autoescaping
|
||||||
Vite
|
backfill
|
||||||
LLM
|
backfills
|
||||||
JSONPath
|
bool
|
||||||
UIs
|
boolean
|
||||||
|
brave_web_search
|
||||||
|
chatbot
|
||||||
|
Chatwoot
|
||||||
|
config
|
||||||
configs
|
configs
|
||||||
uncomment
|
CSVs
|
||||||
qdrant
|
dev
|
||||||
vectorstore
|
diarization
|
||||||
|
Docling
|
||||||
docsgpt
|
docsgpt
|
||||||
llm
|
docstrings
|
||||||
|
Entra
|
||||||
|
env
|
||||||
|
enqueues
|
||||||
|
EOL
|
||||||
|
ESLint
|
||||||
|
feedbacks
|
||||||
|
Figma
|
||||||
GPUs
|
GPUs
|
||||||
|
Groq
|
||||||
|
hardcode
|
||||||
|
hardcoding
|
||||||
|
Idempotency
|
||||||
|
JSONPath
|
||||||
kubectl
|
kubectl
|
||||||
Lightsail
|
Lightsail
|
||||||
enqueues
|
llama_cpp
|
||||||
chatbot
|
llm
|
||||||
VSCode's
|
LLM
|
||||||
Shareability
|
LLMs
|
||||||
feedbacks
|
LMDeploy
|
||||||
automations
|
Milvus
|
||||||
|
Mixtral
|
||||||
|
namespace
|
||||||
|
namespaces
|
||||||
|
needs_auth
|
||||||
|
Nextra
|
||||||
|
Novita
|
||||||
|
npm
|
||||||
|
OAuth
|
||||||
|
Ollama
|
||||||
|
opencode
|
||||||
|
parsable
|
||||||
|
passthrough
|
||||||
|
PDFs
|
||||||
|
pgvector
|
||||||
|
Postgres
|
||||||
Premade
|
Premade
|
||||||
Signup
|
Pydantic
|
||||||
|
pytest
|
||||||
|
Qdrant
|
||||||
|
qdrant
|
||||||
Repo
|
Repo
|
||||||
repo
|
repo
|
||||||
env
|
Sanitization
|
||||||
URl
|
|
||||||
agentic
|
|
||||||
llama_cpp
|
|
||||||
parsable
|
|
||||||
SDKs
|
SDKs
|
||||||
boolean
|
SGLang
|
||||||
bool
|
Shareability
|
||||||
hardcode
|
Signup
|
||||||
EOL
|
Supabase
|
||||||
|
UIs
|
||||||
|
uncomment
|
||||||
|
URl
|
||||||
|
vectorstore
|
||||||
|
Vite
|
||||||
|
VSCode
|
||||||
|
VSCode's
|
||||||
|
widget's
|
||||||
|
|||||||
25
.github/workflows/zizmor.yml
vendored
Normal file
25
.github/workflows/zizmor.yml
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
name: GitHub Actions Security Analysis
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: ["master"]
|
||||||
|
pull_request:
|
||||||
|
branches: ["**"]
|
||||||
|
|
||||||
|
permissions: {}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
zizmor:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||||
|
with:
|
||||||
|
persist-credentials: false
|
||||||
|
|
||||||
|
- name: Run zizmor 🌈
|
||||||
|
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -108,6 +108,8 @@ celerybeat.pid
|
|||||||
# Environments
|
# Environments
|
||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
|
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
|
||||||
|
CLAUDE.md
|
||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
ENV/
|
ENV/
|
||||||
@@ -181,5 +183,6 @@ application/vectors/
|
|||||||
|
|
||||||
node_modules/
|
node_modules/
|
||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
|
.vscode/sftp.json
|
||||||
/models/
|
/models/
|
||||||
model/
|
model/
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
MinAlertLevel = warning
|
MinAlertLevel = warning
|
||||||
StylesPath = .github/styles
|
StylesPath = .github/styles
|
||||||
|
Vocab = DocsGPT
|
||||||
|
|
||||||
[*.{md,mdx}]
|
[*.{md,mdx}]
|
||||||
BasedOnStyles = DocsGPT
|
BasedOnStyles = DocsGPT
|
||||||
|
|
||||||
|
|||||||
18
SECURITY.md
18
SECURITY.md
@@ -2,13 +2,21 @@
|
|||||||
|
|
||||||
## Supported Versions
|
## Supported Versions
|
||||||
|
|
||||||
Supported Versions:
|
Security patches target the latest release and the `main` branch. We recommend always running the most recent version.
|
||||||
|
|
||||||
Currently, we support security patches by committing changes and bumping the version published on Github.
|
|
||||||
|
|
||||||
## Reporting a Vulnerability
|
## Reporting a Vulnerability
|
||||||
|
|
||||||
Found a vulnerability? Please email us:
|
Preferred method: use GitHub's private vulnerability reporting flow:
|
||||||
|
https://github.com/arc53/DocsGPT/security
|
||||||
|
|
||||||
security@arc53.com
|
Then click **Report a vulnerability**.
|
||||||
|
|
||||||
|
|
||||||
|
Alternatively, email us at: security@arc53.com
|
||||||
|
|
||||||
|
We aim to acknowledge reports within 48 hours.
|
||||||
|
|
||||||
|
## Incident Handling
|
||||||
|
|
||||||
|
For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Generator, List, Optional
|
from typing import Any, Dict, Generator, List, Optional
|
||||||
|
|
||||||
from application.agents.tool_executor import ToolExecutor
|
from application.agents.tool_executor import ToolExecutor
|
||||||
from application.core.json_schema_utils import (
|
from application.core.json_schema_utils import (
|
||||||
@@ -9,6 +10,7 @@ from application.core.json_schema_utils import (
|
|||||||
normalize_json_schema_payload,
|
normalize_json_schema_payload,
|
||||||
)
|
)
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.llm.handlers.base import ToolCall
|
||||||
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
from application.logging import build_stack_data, log_activity, LogContext
|
from application.logging import build_stack_data, log_activity, LogContext
|
||||||
@@ -113,6 +115,153 @@ class BaseAgent(ABC):
|
|||||||
) -> Generator[Dict, None, None]:
|
) -> Generator[Dict, None, None]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def gen_continuation(
|
||||||
|
self,
|
||||||
|
messages: List[Dict],
|
||||||
|
tools_dict: Dict,
|
||||||
|
pending_tool_calls: List[Dict],
|
||||||
|
tool_actions: List[Dict],
|
||||||
|
) -> Generator[Dict, None, None]:
|
||||||
|
"""Resume generation after tool actions are resolved.
|
||||||
|
|
||||||
|
Processes the client-provided *tool_actions* (approvals, denials,
|
||||||
|
or client-side results), appends the resulting messages, then
|
||||||
|
hands back to the LLM to continue the conversation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: The saved messages array from the pause point.
|
||||||
|
tools_dict: The saved tools dictionary.
|
||||||
|
pending_tool_calls: The pending tool call descriptors from the pause.
|
||||||
|
tool_actions: Client-provided actions resolving the pending calls.
|
||||||
|
"""
|
||||||
|
self._prepare_tools(tools_dict)
|
||||||
|
|
||||||
|
actions_by_id = {a["call_id"]: a for a in tool_actions}
|
||||||
|
|
||||||
|
# Build a single assistant message containing all tool calls so
|
||||||
|
# the message history matches the format LLM providers expect
|
||||||
|
# (one assistant message with N tool_calls, followed by N tool results).
|
||||||
|
tc_objects: List[Dict[str, Any]] = []
|
||||||
|
for pending in pending_tool_calls:
|
||||||
|
call_id = pending["call_id"]
|
||||||
|
args = pending["arguments"]
|
||||||
|
args_str = (
|
||||||
|
json.dumps(args) if isinstance(args, dict) else (args or "{}")
|
||||||
|
)
|
||||||
|
tc_obj: Dict[str, Any] = {
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": pending["name"],
|
||||||
|
"arguments": args_str,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if pending.get("thought_signature"):
|
||||||
|
tc_obj["thought_signature"] = pending["thought_signature"]
|
||||||
|
tc_objects.append(tc_obj)
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": tc_objects,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Now process each pending call and append tool result messages
|
||||||
|
for pending in pending_tool_calls:
|
||||||
|
call_id = pending["call_id"]
|
||||||
|
args = pending["arguments"]
|
||||||
|
action = actions_by_id.get(call_id)
|
||||||
|
if not action:
|
||||||
|
action = {
|
||||||
|
"call_id": call_id,
|
||||||
|
"decision": "denied",
|
||||||
|
"comment": "No response provided",
|
||||||
|
}
|
||||||
|
|
||||||
|
if action.get("decision") == "approved":
|
||||||
|
# Execute the tool server-side
|
||||||
|
tc = ToolCall(
|
||||||
|
id=call_id,
|
||||||
|
name=pending["name"],
|
||||||
|
arguments=(
|
||||||
|
json.dumps(args) if isinstance(args, dict) else args
|
||||||
|
),
|
||||||
|
)
|
||||||
|
tool_gen = self._execute_tool_action(tools_dict, tc)
|
||||||
|
tool_response = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event = next(tool_gen)
|
||||||
|
yield event
|
||||||
|
except StopIteration as e:
|
||||||
|
tool_response, _ = e.value
|
||||||
|
break
|
||||||
|
messages.append(
|
||||||
|
self.llm_handler.create_tool_message(tc, tool_response)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action.get("decision") == "denied":
|
||||||
|
comment = action.get("comment", "")
|
||||||
|
denial = (
|
||||||
|
f"Tool execution denied by user. Reason: {comment}"
|
||||||
|
if comment
|
||||||
|
else "Tool execution denied by user."
|
||||||
|
)
|
||||||
|
tc = ToolCall(
|
||||||
|
id=call_id, name=pending["name"], arguments=args
|
||||||
|
)
|
||||||
|
messages.append(
|
||||||
|
self.llm_handler.create_tool_message(tc, denial)
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"type": "tool_call",
|
||||||
|
"data": {
|
||||||
|
"tool_name": pending.get("tool_name", "unknown"),
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": pending.get("llm_name", pending["name"]),
|
||||||
|
"arguments": args,
|
||||||
|
"status": "denied",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
elif "result" in action:
|
||||||
|
result = action["result"]
|
||||||
|
result_str = (
|
||||||
|
json.dumps(result)
|
||||||
|
if not isinstance(result, str)
|
||||||
|
else result
|
||||||
|
)
|
||||||
|
tc = ToolCall(
|
||||||
|
id=call_id, name=pending["name"], arguments=args
|
||||||
|
)
|
||||||
|
messages.append(
|
||||||
|
self.llm_handler.create_tool_message(tc, result_str)
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"type": "tool_call",
|
||||||
|
"data": {
|
||||||
|
"tool_name": pending.get("tool_name", "unknown"),
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": pending.get("llm_name", pending["name"]),
|
||||||
|
"arguments": args,
|
||||||
|
"result": (
|
||||||
|
result_str[:50] + "..."
|
||||||
|
if len(result_str) > 50
|
||||||
|
else result_str
|
||||||
|
),
|
||||||
|
"status": "completed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Resume the LLM loop with the updated messages
|
||||||
|
llm_response = self._llm_gen(messages)
|
||||||
|
yield from self._handle_response(
|
||||||
|
llm_response, tools_dict, messages, None
|
||||||
|
)
|
||||||
|
|
||||||
|
yield {"sources": self.retrieved_docs}
|
||||||
|
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||||
|
|
||||||
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -267,28 +416,35 @@ class BaseAgent(ABC):
|
|||||||
if "tool_calls" in i:
|
if "tool_calls" in i:
|
||||||
for tool_call in i["tool_calls"]:
|
for tool_call in i["tool_calls"]:
|
||||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||||
|
args = tool_call.get("arguments")
|
||||||
function_call_dict = {
|
args_str = (
|
||||||
"function_call": {
|
json.dumps(args)
|
||||||
"name": tool_call.get("action_name"),
|
if isinstance(args, dict)
|
||||||
"args": tool_call.get("arguments"),
|
else (args or "{}")
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
function_response_dict = {
|
|
||||||
"function_response": {
|
|
||||||
"name": tool_call.get("action_name"),
|
|
||||||
"response": {"result": tool_call.get("result")},
|
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
{"role": "assistant", "content": [function_call_dict]}
|
|
||||||
)
|
)
|
||||||
messages.append(
|
messages.append({
|
||||||
{"role": "tool", "content": [function_response_dict]}
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.get("action_name", ""),
|
||||||
|
"arguments": args_str,
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
result = tool_call.get("result")
|
||||||
|
result_str = (
|
||||||
|
json.dumps(result)
|
||||||
|
if not isinstance(result, str)
|
||||||
|
else (result or "")
|
||||||
)
|
)
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"content": result_str,
|
||||||
|
})
|
||||||
messages.append({"role": "user", "content": query})
|
messages.append({"role": "user", "content": query})
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|||||||
@@ -593,16 +593,22 @@ class ResearchAgent(BaseAgent):
|
|||||||
)
|
)
|
||||||
result = result_str
|
result = result_str
|
||||||
|
|
||||||
function_call_content = {
|
import json as _json
|
||||||
"function_call": {
|
|
||||||
"name": call.name,
|
args_str = (
|
||||||
"args": call.arguments,
|
_json.dumps(call.arguments)
|
||||||
"call_id": call_id,
|
if isinstance(call.arguments, dict)
|
||||||
}
|
else call.arguments
|
||||||
}
|
|
||||||
messages.append(
|
|
||||||
{"role": "assistant", "content": [function_call_content]}
|
|
||||||
)
|
)
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": call.name, "arguments": args_str},
|
||||||
|
}],
|
||||||
|
})
|
||||||
tool_message = self.llm_handler.create_tool_message(call, result)
|
tool_message = self.llm_handler.create_tool_message(call, result)
|
||||||
messages.append(tool_message)
|
messages.append(tool_message)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional
|
from collections import Counter
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
from bson.objectid import ObjectId
|
||||||
|
|
||||||
@@ -31,12 +32,23 @@ class ToolExecutor:
|
|||||||
self.tool_calls: List[Dict] = []
|
self.tool_calls: List[Dict] = []
|
||||||
self._loaded_tools: Dict[str, object] = {}
|
self._loaded_tools: Dict[str, object] = {}
|
||||||
self.conversation_id: Optional[str] = None
|
self.conversation_id: Optional[str] = None
|
||||||
|
self.client_tools: Optional[List[Dict]] = None
|
||||||
|
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||||
|
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||||
|
|
||||||
def get_tools(self) -> Dict[str, Dict]:
|
def get_tools(self) -> Dict[str, Dict]:
|
||||||
"""Load tool configs from DB based on user context."""
|
"""Load tool configs from DB based on user context.
|
||||||
|
|
||||||
|
If *client_tools* have been set on this executor, they are
|
||||||
|
automatically merged into the returned dict.
|
||||||
|
"""
|
||||||
if self.user_api_key:
|
if self.user_api_key:
|
||||||
return self._get_tools_by_api_key(self.user_api_key)
|
tools = self._get_tools_by_api_key(self.user_api_key)
|
||||||
return self._get_user_tools(self.user or "local")
|
else:
|
||||||
|
tools = self._get_user_tools(self.user or "local")
|
||||||
|
if self.client_tools:
|
||||||
|
self.merge_client_tools(tools, self.client_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
||||||
mongo = MongoDB.get_client()
|
mongo = MongoDB.get_client()
|
||||||
@@ -65,29 +77,123 @@ class ToolExecutor:
|
|||||||
user_tools = list(user_tools)
|
user_tools = list(user_tools)
|
||||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||||
|
|
||||||
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
|
def merge_client_tools(
|
||||||
"""Convert tool configs to LLM function schemas."""
|
self, tools_dict: Dict, client_tools: List[Dict]
|
||||||
return [
|
) -> Dict:
|
||||||
{
|
"""Merge client-provided tool definitions into tools_dict.
|
||||||
"type": "function",
|
|
||||||
"function": {
|
Client tools use the standard function-calling format::
|
||||||
"name": f"{action['name']}_{tool_id}",
|
|
||||||
"description": action["description"],
|
[{"type": "function", "function": {"name": "get_weather",
|
||||||
"parameters": self._build_tool_parameters(action),
|
"description": "...", "parameters": {...}}}]
|
||||||
},
|
|
||||||
|
They are stored in *tools_dict* with ``client_side: True`` so that
|
||||||
|
:meth:`check_pause` returns a pause signal instead of trying to
|
||||||
|
execute them server-side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools_dict: The mutable server tools dict (will be modified in place).
|
||||||
|
client_tools: List of tool definitions in function-calling format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated *tools_dict* (same reference, for convenience).
|
||||||
|
"""
|
||||||
|
for i, ct in enumerate(client_tools):
|
||||||
|
func = ct.get("function", ct) # tolerate bare {"name":..} too
|
||||||
|
name = func.get("name", f"clienttool{i}")
|
||||||
|
tool_id = f"ct{i}"
|
||||||
|
|
||||||
|
tools_dict[tool_id] = {
|
||||||
|
"name": name,
|
||||||
|
"client_side": True,
|
||||||
|
"actions": [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"description": func.get("description", ""),
|
||||||
|
"active": True,
|
||||||
|
"parameters": func.get("parameters", {}),
|
||||||
|
}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
for tool_id, tool in tools_dict.items()
|
return tools_dict
|
||||||
if (
|
|
||||||
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
|
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
|
||||||
or (tool["name"] != "api_tool" and "actions" in tool)
|
"""Convert tool configs to LLM function schemas.
|
||||||
)
|
|
||||||
for action in (
|
Action names are kept clean for the LLM:
|
||||||
|
- Unique action names appear as-is (e.g. ``get_weather``).
|
||||||
|
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
|
||||||
|
``search_2``).
|
||||||
|
|
||||||
|
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
|
||||||
|
can be routed back to the correct ``(tool_id, action_name)`` without
|
||||||
|
brittle string splitting.
|
||||||
|
"""
|
||||||
|
# Pass 1: collect entries and count action name occurrences
|
||||||
|
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
|
||||||
|
name_counts: Counter = Counter()
|
||||||
|
|
||||||
|
for tool_id, tool in tools_dict.items():
|
||||||
|
is_api = tool["name"] == "api_tool"
|
||||||
|
is_client = tool.get("client_side", False)
|
||||||
|
|
||||||
|
if is_api and "actions" not in tool.get("config", {}):
|
||||||
|
continue
|
||||||
|
if not is_api and "actions" not in tool:
|
||||||
|
continue
|
||||||
|
|
||||||
|
actions = (
|
||||||
tool["config"]["actions"].values()
|
tool["config"]["actions"].values()
|
||||||
if tool["name"] == "api_tool"
|
if is_api
|
||||||
else tool["actions"]
|
else tool["actions"]
|
||||||
)
|
)
|
||||||
if action.get("active", True)
|
|
||||||
]
|
for action in actions:
|
||||||
|
if not action.get("active", True):
|
||||||
|
continue
|
||||||
|
entries.append((tool_id, action["name"], action, is_client))
|
||||||
|
name_counts[action["name"]] += 1
|
||||||
|
|
||||||
|
# Pass 2: assign LLM-visible names and build mappings
|
||||||
|
self._name_to_tool = {}
|
||||||
|
self._tool_to_name = {}
|
||||||
|
collision_counters: Dict[str, int] = {}
|
||||||
|
all_llm_names: set = set()
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for tool_id, action_name, action, is_client in entries:
|
||||||
|
if name_counts[action_name] == 1:
|
||||||
|
llm_name = action_name
|
||||||
|
else:
|
||||||
|
counter = collision_counters.get(action_name, 1)
|
||||||
|
candidate = f"{action_name}_{counter}"
|
||||||
|
# Skip if candidate collides with a unique action name
|
||||||
|
while candidate in all_llm_names or (
|
||||||
|
candidate in name_counts and name_counts[candidate] == 1
|
||||||
|
):
|
||||||
|
counter += 1
|
||||||
|
candidate = f"{action_name}_{counter}"
|
||||||
|
collision_counters[action_name] = counter + 1
|
||||||
|
llm_name = candidate
|
||||||
|
|
||||||
|
all_llm_names.add(llm_name)
|
||||||
|
self._name_to_tool[llm_name] = (tool_id, action_name)
|
||||||
|
self._tool_to_name[(tool_id, action_name)] = llm_name
|
||||||
|
|
||||||
|
if is_client:
|
||||||
|
params = action.get("parameters", {})
|
||||||
|
else:
|
||||||
|
params = self._build_tool_parameters(action)
|
||||||
|
|
||||||
|
result.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": llm_name,
|
||||||
|
"description": action.get("description", ""),
|
||||||
|
"parameters": params,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
|
||||||
def _build_tool_parameters(self, action: Dict) -> Dict:
|
def _build_tool_parameters(self, action: Dict) -> Dict:
|
||||||
params = {"type": "object", "properties": {}, "required": []}
|
params = {"type": "object", "properties": {}, "required": []}
|
||||||
@@ -104,23 +210,81 @@ class ToolExecutor:
|
|||||||
params["required"].append(k)
|
params["required"].append(k)
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def check_pause(
|
||||||
|
self, tools_dict: Dict, call, llm_class_name: str
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
"""Check if a tool call requires pausing for approval or client execution.
|
||||||
|
|
||||||
|
Returns a dict describing the pending action if pause is needed, None otherwise.
|
||||||
|
"""
|
||||||
|
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||||
|
tool_id, action_name, call_args = parser.parse_args(call)
|
||||||
|
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||||
|
llm_name = getattr(call, "name", "")
|
||||||
|
|
||||||
|
if tool_id is None or action_name is None or tool_id not in tools_dict:
|
||||||
|
return None # Will be handled as error by execute()
|
||||||
|
|
||||||
|
tool_data = tools_dict[tool_id]
|
||||||
|
|
||||||
|
# Client-side tools
|
||||||
|
if tool_data.get("client_side"):
|
||||||
|
return {
|
||||||
|
"call_id": call_id,
|
||||||
|
"name": llm_name,
|
||||||
|
"tool_name": tool_data.get("name", "unknown"),
|
||||||
|
"tool_id": tool_id,
|
||||||
|
"action_name": action_name,
|
||||||
|
"llm_name": llm_name,
|
||||||
|
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||||
|
"pause_type": "requires_client_execution",
|
||||||
|
"thought_signature": getattr(call, "thought_signature", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Approval required
|
||||||
|
if tool_data["name"] == "api_tool":
|
||||||
|
action_data = tool_data.get("config", {}).get("actions", {}).get(
|
||||||
|
action_name, {}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
action_data = next(
|
||||||
|
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_data.get("require_approval"):
|
||||||
|
return {
|
||||||
|
"call_id": call_id,
|
||||||
|
"name": llm_name,
|
||||||
|
"tool_name": tool_data.get("name", "unknown"),
|
||||||
|
"tool_id": tool_id,
|
||||||
|
"action_name": action_name,
|
||||||
|
"llm_name": llm_name,
|
||||||
|
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||||
|
"pause_type": "awaiting_approval",
|
||||||
|
"thought_signature": getattr(call, "thought_signature", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
||||||
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
||||||
parser = ToolActionParser(llm_class_name)
|
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||||
tool_id, action_name, call_args = parser.parse_args(call)
|
tool_id, action_name, call_args = parser.parse_args(call)
|
||||||
|
llm_name = getattr(call, "name", "unknown")
|
||||||
|
|
||||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||||
|
|
||||||
if tool_id is None or action_name is None:
|
if tool_id is None or action_name is None:
|
||||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||||
logger.error(error_message)
|
logger.error(error_message)
|
||||||
|
|
||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": "unknown",
|
"tool_name": "unknown",
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
"action_name": getattr(call, "name", "unknown"),
|
"action_name": llm_name,
|
||||||
"arguments": call_args or {},
|
"arguments": call_args or {},
|
||||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
|
||||||
}
|
}
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
self.tool_calls.append(tool_call_data)
|
self.tool_calls.append(tool_call_data)
|
||||||
@@ -133,7 +297,7 @@ class ToolExecutor:
|
|||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": "unknown",
|
"tool_name": "unknown",
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
"action_name": f"{action_name}_{tool_id}",
|
"action_name": llm_name,
|
||||||
"arguments": call_args,
|
"arguments": call_args,
|
||||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||||
}
|
}
|
||||||
@@ -144,7 +308,7 @@ class ToolExecutor:
|
|||||||
tool_call_data = {
|
tool_call_data = {
|
||||||
"tool_name": tools_dict[tool_id]["name"],
|
"tool_name": tools_dict[tool_id]["name"],
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
"action_name": f"{action_name}_{tool_id}",
|
"action_name": llm_name,
|
||||||
"arguments": call_args,
|
"arguments": call_args,
|
||||||
}
|
}
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
|
|
||||||
class Tool(ABC):
|
class Tool(ABC):
|
||||||
|
internal: bool = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute_action(self, action_name: str, **kwargs):
|
def execute_action(self, action_name: str, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class BraveSearchTool(Tool):
|
|||||||
"X-Subscription-Token": self.token,
|
"X-Subscription-Token": self.token,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.get(url, params=params, headers=headers)
|
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return {
|
return {
|
||||||
@@ -118,7 +118,7 @@ class BraveSearchTool(Tool):
|
|||||||
"X-Subscription-Token": self.token,
|
"X-Subscription-Token": self.token,
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.get(url, params=params, headers=headers)
|
response = requests.get(url, params=params, headers=headers, timeout=100)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
|
|||||||
returns price in USD.
|
returns price in USD.
|
||||||
"""
|
"""
|
||||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||||
response = requests.get(url)
|
response = requests.get(url, timeout=100)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if currency.upper() in data:
|
if currency.upper() in data:
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ class InternalSearchTool(Tool):
|
|||||||
- list_files action: browse the file/folder structure
|
- list_files action: browse the file/folder structure
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
internal = True
|
||||||
|
|
||||||
def __init__(self, config: Dict):
|
def __init__(self, config: Dict):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.retrieved_docs: List[Dict] = []
|
self.retrieved_docs: List[Dict] = []
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
|
|||||||
from application.cache import get_redis_instance
|
from application.cache import get_redis_instance
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.core.url_validation import SSRFError, validate_url
|
||||||
from application.security.encryption import decrypt_credentials
|
from application.security.encryption import decrypt_credentials
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -61,7 +62,8 @@ class MCPTool(Tool):
|
|||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.server_url = config.get("server_url", "")
|
raw_url = config.get("server_url", "")
|
||||||
|
self.server_url = self._validate_server_url(raw_url) if raw_url else ""
|
||||||
self.transport_type = config.get("transport_type", "auto")
|
self.transport_type = config.get("transport_type", "auto")
|
||||||
self.auth_type = config.get("auth_type", "none")
|
self.auth_type = config.get("auth_type", "none")
|
||||||
self.timeout = config.get("timeout", 30)
|
self.timeout = config.get("timeout", 30)
|
||||||
@@ -87,6 +89,18 @@ class MCPTool(Tool):
|
|||||||
if self.server_url and self.auth_type != "oauth":
|
if self.server_url and self.auth_type != "oauth":
|
||||||
self._setup_client()
|
self._setup_client()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_server_url(server_url: str) -> str:
|
||||||
|
"""Validate server_url to prevent SSRF to internal networks.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the URL points to a private/internal address.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return validate_url(server_url)
|
||||||
|
except SSRFError as exc:
|
||||||
|
raise ValueError(f"Invalid MCP server URL: {exc}") from exc
|
||||||
|
|
||||||
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
|
def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
|
||||||
if configured_redirect_uri:
|
if configured_redirect_uri:
|
||||||
return configured_redirect_uri.rstrip("/")
|
return configured_redirect_uri.rstrip("/")
|
||||||
@@ -108,8 +122,9 @@ class MCPTool(Tool):
|
|||||||
auth_key = ""
|
auth_key = ""
|
||||||
if self.auth_type == "oauth":
|
if self.auth_type == "oauth":
|
||||||
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
|
||||||
|
oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
|
||||||
auth_key = (
|
auth_key = (
|
||||||
f"oauth:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
|
f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
|
||||||
)
|
)
|
||||||
elif self.auth_type in ["bearer"]:
|
elif self.auth_type in ["bearer"]:
|
||||||
token = self.auth_credentials.get(
|
token = self.auth_credentials.get(
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
|
|||||||
if self.token:
|
if self.token:
|
||||||
headers["Authorization"] = f"Basic {self.token}"
|
headers["Authorization"] = f"Basic {self.token}"
|
||||||
data = message.encode("utf-8")
|
data = message.encode("utf-8")
|
||||||
response = requests.post(url, headers=headers, data=data)
|
response = requests.post(url, headers=headers, data=data, timeout=100)
|
||||||
return {"status_code": response.status_code, "message": "Message sent"}
|
return {"status_code": response.status_code, "message": "Message sent"}
|
||||||
|
|
||||||
def get_actions_metadata(self):
|
def get_actions_metadata(self):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import psycopg2
|
import psycopg
|
||||||
|
|
||||||
from application.agents.tools.base import Tool
|
from application.agents.tools.base import Tool
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ class PostgresTool(Tool):
|
|||||||
"""
|
"""
|
||||||
conn = None
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = psycopg2.connect(self.connection_string)
|
conn = psycopg.connect(self.connection_string)
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
cur.execute(sql_query)
|
cur.execute(sql_query)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@@ -60,7 +60,7 @@ class PostgresTool(Tool):
|
|||||||
"response_data": response_data,
|
"response_data": response_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
except psycopg2.Error as e:
|
except psycopg.Error as e:
|
||||||
error_message = f"Database error: {e}"
|
error_message = f"Database error: {e}"
|
||||||
logger.error("PostgreSQL execute_sql error: %s", e)
|
logger.error("PostgreSQL execute_sql error: %s", e)
|
||||||
return {
|
return {
|
||||||
@@ -78,7 +78,7 @@ class PostgresTool(Tool):
|
|||||||
"""
|
"""
|
||||||
conn = None
|
conn = None
|
||||||
try:
|
try:
|
||||||
conn = psycopg2.connect(self.connection_string)
|
conn = psycopg.connect(self.connection_string)
|
||||||
cur = conn.cursor()
|
cur = conn.cursor()
|
||||||
|
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@@ -120,7 +120,7 @@ class PostgresTool(Tool):
|
|||||||
"schema": schema_data,
|
"schema": schema_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
except psycopg2.Error as e:
|
except psycopg.Error as e:
|
||||||
error_message = f"Database error: {e}"
|
error_message = f"Database error: {e}"
|
||||||
logger.error("PostgreSQL get_schema error: %s", e)
|
logger.error("PostgreSQL get_schema error: %s", e)
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -31,14 +31,14 @@ class TelegramTool(Tool):
|
|||||||
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
||||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||||
payload = {"chat_id": chat_id, "text": text}
|
payload = {"chat_id": chat_id, "text": text}
|
||||||
response = requests.post(url, data=payload)
|
response = requests.post(url, data=payload, timeout=100)
|
||||||
return {"status_code": response.status_code, "message": "Message sent"}
|
return {"status_code": response.status_code, "message": "Message sent"}
|
||||||
|
|
||||||
def _send_image(self, image_url, chat_id):
|
def _send_image(self, image_url, chat_id):
|
||||||
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
||||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||||
payload = {"chat_id": chat_id, "photo": image_url}
|
payload = {"chat_id": chat_id, "photo": image_url}
|
||||||
response = requests.post(url, data=payload)
|
response = requests.post(url, data=payload, timeout=100)
|
||||||
return {"status_code": response.status_code, "message": "Image sent"}
|
return {"status_code": response.status_code, "message": "Image sent"}
|
||||||
|
|
||||||
def get_actions_metadata(self):
|
def get_actions_metadata(self):
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ class ThinkTool(Tool):
|
|||||||
The reasoning content is captured in tool_call data for transparency.
|
The reasoning content is captured in tool_call data for transparency.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
internal = True
|
||||||
|
|
||||||
def __init__(self, config=None):
|
def __init__(self, config=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ToolActionParser:
|
class ToolActionParser:
|
||||||
def __init__(self, llm_type):
|
def __init__(self, llm_type, name_mapping=None):
|
||||||
self.llm_type = llm_type
|
self.llm_type = llm_type
|
||||||
|
self.name_mapping = name_mapping
|
||||||
self.parsers = {
|
self.parsers = {
|
||||||
"OpenAILLM": self._parse_openai_llm,
|
"OpenAILLM": self._parse_openai_llm,
|
||||||
"GoogleLLM": self._parse_google_llm,
|
"GoogleLLM": self._parse_google_llm,
|
||||||
@@ -16,22 +17,33 @@ class ToolActionParser:
|
|||||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||||
return parser(call)
|
return parser(call)
|
||||||
|
|
||||||
|
def _resolve_via_mapping(self, call_name):
|
||||||
|
"""Look up (tool_id, action_name) from the name mapping if available."""
|
||||||
|
if self.name_mapping and call_name in self.name_mapping:
|
||||||
|
return self.name_mapping[call_name]
|
||||||
|
return None
|
||||||
|
|
||||||
def _parse_openai_llm(self, call):
|
def _parse_openai_llm(self, call):
|
||||||
try:
|
try:
|
||||||
call_args = json.loads(call.arguments)
|
call_args = json.loads(call.arguments)
|
||||||
|
|
||||||
|
resolved = self._resolve_via_mapping(call.name)
|
||||||
|
if resolved:
|
||||||
|
return resolved[0], resolved[1], call_args
|
||||||
|
|
||||||
|
# Fallback: legacy split on "_" for backward compatibility
|
||||||
tool_parts = call.name.split("_")
|
tool_parts = call.name.split("_")
|
||||||
|
|
||||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
|
||||||
if len(tool_parts) < 2:
|
if len(tool_parts) < 2:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
f"Invalid tool name format: {call.name}. "
|
||||||
|
"Could not resolve via mapping or legacy parsing."
|
||||||
)
|
)
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
tool_id = tool_parts[-1]
|
tool_id = tool_parts[-1]
|
||||||
action_name = "_".join(tool_parts[:-1])
|
action_name = "_".join(tool_parts[:-1])
|
||||||
|
|
||||||
# Validate that tool_id looks like a numerical ID
|
|
||||||
if not tool_id.isdigit():
|
if not tool_id.isdigit():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||||
@@ -45,19 +57,24 @@ class ToolActionParser:
|
|||||||
def _parse_google_llm(self, call):
|
def _parse_google_llm(self, call):
|
||||||
try:
|
try:
|
||||||
call_args = call.arguments
|
call_args = call.arguments
|
||||||
|
|
||||||
|
resolved = self._resolve_via_mapping(call.name)
|
||||||
|
if resolved:
|
||||||
|
return resolved[0], resolved[1], call_args
|
||||||
|
|
||||||
|
# Fallback: legacy split on "_" for backward compatibility
|
||||||
tool_parts = call.name.split("_")
|
tool_parts = call.name.split("_")
|
||||||
|
|
||||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
|
||||||
if len(tool_parts) < 2:
|
if len(tool_parts) < 2:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
f"Invalid tool name format: {call.name}. "
|
||||||
|
"Could not resolve via mapping or legacy parsing."
|
||||||
)
|
)
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
tool_id = tool_parts[-1]
|
tool_id = tool_parts[-1]
|
||||||
action_name = "_".join(tool_parts[:-1])
|
action_name = "_".join(tool_parts[:-1])
|
||||||
|
|
||||||
# Validate that tool_id looks like a numerical ID
|
|
||||||
if not tool_id.isdigit():
|
if not tool_id.isdigit():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class ToolManager:
|
|||||||
continue
|
continue
|
||||||
module = importlib.import_module(f"application.agents.tools.{name}")
|
module = importlib.import_module(f"application.agents.tools.{name}")
|
||||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
if issubclass(obj, Tool) and obj is not Tool:
|
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
|
||||||
tool_config = self.config.get(name, {})
|
tool_config = self.config.get(name, {})
|
||||||
self.tools[name] = obj(tool_config)
|
self.tools[name] = obj(tool_config)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ from application.agents.workflows.workflow_engine import WorkflowEngine
|
|||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.logging import log_activity, LogContext
|
from application.logging import log_activity, LogContext
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||||
|
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -181,6 +184,9 @@ class WorkflowAgent(BaseAgent):
|
|||||||
def _save_workflow_run(self, query: str) -> None:
|
def _save_workflow_run(self, query: str) -> None:
|
||||||
if not self._engine:
|
if not self._engine:
|
||||||
return
|
return
|
||||||
|
owner_id = self.workflow_owner
|
||||||
|
if not owner_id and isinstance(self.decoded_token, dict):
|
||||||
|
owner_id = self.decoded_token.get("sub")
|
||||||
try:
|
try:
|
||||||
mongo = MongoDB.get_client()
|
mongo = MongoDB.get_client()
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
@@ -188,6 +194,7 @@ class WorkflowAgent(BaseAgent):
|
|||||||
|
|
||||||
run = WorkflowRun(
|
run = WorkflowRun(
|
||||||
workflow_id=self.workflow_id or "unknown",
|
workflow_id=self.workflow_id or "unknown",
|
||||||
|
user=owner_id,
|
||||||
status=self._determine_run_status(),
|
status=self._determine_run_status(),
|
||||||
inputs={"query": query},
|
inputs={"query": query},
|
||||||
outputs=self._serialize_state(self._engine.state),
|
outputs=self._serialize_state(self._engine.state),
|
||||||
@@ -196,7 +203,34 @@ class WorkflowAgent(BaseAgent):
|
|||||||
completed_at=datetime.now(timezone.utc),
|
completed_at=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_runs_coll.insert_one(run.to_mongo_doc())
|
result = workflow_runs_coll.insert_one(run.to_mongo_doc())
|
||||||
|
legacy_mongo_id = (
|
||||||
|
str(result.inserted_id)
|
||||||
|
if getattr(result, "inserted_id", None) is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _pg_write(repo: WorkflowRunsRepository) -> None:
|
||||||
|
if not self.workflow_id or not owner_id or not legacy_mongo_id:
|
||||||
|
return
|
||||||
|
workflow = WorkflowsRepository(repo._conn).get_by_legacy_id(
|
||||||
|
self.workflow_id, owner_id,
|
||||||
|
)
|
||||||
|
if workflow is None:
|
||||||
|
return
|
||||||
|
repo.create(
|
||||||
|
workflow["id"],
|
||||||
|
owner_id,
|
||||||
|
run.status.value,
|
||||||
|
inputs=run.inputs,
|
||||||
|
result=run.outputs,
|
||||||
|
steps=[step.model_dump(mode="json") for step in run.steps],
|
||||||
|
started_at=run.created_at,
|
||||||
|
ended_at=run.completed_at,
|
||||||
|
legacy_mongo_id=legacy_mongo_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
dual_write(WorkflowRunsRepository, _pg_write)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save workflow run: {e}")
|
logger.error(f"Failed to save workflow run: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -211,6 +211,7 @@ class WorkflowRun(BaseModel):
|
|||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
id: Optional[str] = Field(None, alias="_id")
|
id: Optional[str] = Field(None, alias="_id")
|
||||||
workflow_id: str
|
workflow_id: str
|
||||||
|
user: Optional[str] = None
|
||||||
status: ExecutionStatus = ExecutionStatus.PENDING
|
status: ExecutionStatus = ExecutionStatus.PENDING
|
||||||
inputs: Dict[str, str] = Field(default_factory=dict)
|
inputs: Dict[str, str] = Field(default_factory=dict)
|
||||||
outputs: Dict[str, Any] = Field(default_factory=dict)
|
outputs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
@@ -226,7 +227,7 @@ class WorkflowRun(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
def to_mongo_doc(self) -> Dict[str, Any]:
|
def to_mongo_doc(self) -> Dict[str, Any]:
|
||||||
return {
|
doc = {
|
||||||
"workflow_id": self.workflow_id,
|
"workflow_id": self.workflow_id,
|
||||||
"status": self.status.value,
|
"status": self.status.value,
|
||||||
"inputs": self.inputs,
|
"inputs": self.inputs,
|
||||||
@@ -235,3 +236,7 @@ class WorkflowRun(BaseModel):
|
|||||||
"created_at": self.created_at,
|
"created_at": self.created_at,
|
||||||
"completed_at": self.completed_at,
|
"completed_at": self.completed_at,
|
||||||
}
|
}
|
||||||
|
if self.user:
|
||||||
|
doc["user"] = self.user
|
||||||
|
doc["user_id"] = self.user
|
||||||
|
return doc
|
||||||
|
|||||||
52
application/alembic.ini
Normal file
52
application/alembic.ini
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# Alembic configuration for the DocsGPT user-data Postgres database.
|
||||||
|
#
|
||||||
|
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
|
||||||
|
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
|
||||||
|
# source serves the running app and migrations. To run from the project
|
||||||
|
# root::
|
||||||
|
#
|
||||||
|
# alembic -c application/alembic.ini upgrade head
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
script_location = %(here)s/alembic
|
||||||
|
prepend_sys_path = ..
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# sqlalchemy.url is intentionally left blank — env.py supplies it.
|
||||||
|
sqlalchemy.url =
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARNING
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARNING
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
82
application/alembic/env.py
Normal file
82
application/alembic/env.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Alembic environment for the DocsGPT user-data Postgres database.
|
||||||
|
|
||||||
|
The URL is pulled from ``application.core.settings`` rather than
|
||||||
|
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
|
||||||
|
running app and ``alembic`` CLI invocations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from logging.config import fileConfig
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Make the project root importable regardless of cwd. env.py lives at
|
||||||
|
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
if str(_PROJECT_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||||
|
|
||||||
|
from alembic import context # noqa: E402
|
||||||
|
from sqlalchemy import engine_from_config, pool # noqa: E402
|
||||||
|
|
||||||
|
from application.core.settings import settings # noqa: E402
|
||||||
|
from application.storage.db.models import metadata as target_metadata # noqa: E402
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Populate the runtime URL from settings.
|
||||||
|
if settings.POSTGRES_URI:
|
||||||
|
config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)
|
||||||
|
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode (emits SQL without a live DB)."""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
if not url:
|
||||||
|
raise RuntimeError(
|
||||||
|
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||||
|
"psycopg3 URI such as "
|
||||||
|
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||||
|
)
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode against a live connection."""
|
||||||
|
if not config.get_main_option("sqlalchemy.url"):
|
||||||
|
raise RuntimeError(
|
||||||
|
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||||
|
"psycopg3 URI such as "
|
||||||
|
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||||
|
)
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
future=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
26
application/alembic/script.py.mako
Normal file
26
application/alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
825
application/alembic/versions/0001_initial.py
Normal file
825
application/alembic/versions/0001_initial.py
Normal file
@@ -0,0 +1,825 @@
|
|||||||
|
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||||
|
|
||||||
|
Revision ID: 0001_initial
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-04-13
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "0001_initial"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Extensions
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
|
||||||
|
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Trigger functions
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION set_updated_at() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
NEW.updated_at = now();
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION ensure_user_exists() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
IF NEW.user_id IS NOT NULL THEN
|
||||||
|
INSERT INTO users (user_id) VALUES (NEW.user_id)
|
||||||
|
ON CONFLICT (user_id) DO NOTHING;
|
||||||
|
END IF;
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
UPDATE conversation_messages
|
||||||
|
SET attachments = array_remove(attachments, OLD.id)
|
||||||
|
WHERE OLD.id = ANY(attachments);
|
||||||
|
RETURN OLD;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
UPDATE agents
|
||||||
|
SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
|
||||||
|
WHERE OLD.id = ANY(extra_source_ids);
|
||||||
|
RETURN OLD;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
agent_id_text text := OLD.id::text;
|
||||||
|
BEGIN
|
||||||
|
UPDATE users
|
||||||
|
SET agent_preferences = jsonb_set(
|
||||||
|
jsonb_set(
|
||||||
|
agent_preferences,
|
||||||
|
'{pinned}',
|
||||||
|
COALESCE((
|
||||||
|
SELECT jsonb_agg(e)
|
||||||
|
FROM jsonb_array_elements(
|
||||||
|
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
|
||||||
|
) e
|
||||||
|
WHERE (e #>> '{}') <> agent_id_text
|
||||||
|
), '[]'::jsonb)
|
||||||
|
),
|
||||||
|
'{shared_with_me}',
|
||||||
|
COALESCE((
|
||||||
|
SELECT jsonb_agg(e)
|
||||||
|
FROM jsonb_array_elements(
|
||||||
|
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
|
||||||
|
) e
|
||||||
|
WHERE (e #>> '{}') <> agent_id_text
|
||||||
|
), '[]'::jsonb)
|
||||||
|
)
|
||||||
|
WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
|
||||||
|
OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
|
||||||
|
RETURN OLD;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
IF NEW.user_id IS NULL THEN
|
||||||
|
SELECT user_id INTO NEW.user_id
|
||||||
|
FROM conversations
|
||||||
|
WHERE id = NEW.conversation_id;
|
||||||
|
END IF;
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Tables
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL UNIQUE,
|
||||||
|
agent_preferences JSONB NOT NULL
|
||||||
|
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE prompts (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE user_tools (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
custom_name TEXT,
|
||||||
|
display_name TEXT,
|
||||||
|
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE token_usage (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id TEXT,
|
||||||
|
api_key TEXT,
|
||||||
|
agent_id UUID,
|
||||||
|
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||||
|
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
|
||||||
|
"CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE user_logs (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id TEXT,
|
||||||
|
endpoint TEXT,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
data JSONB
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE stack_logs (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
activity_id TEXT NOT NULL,
|
||||||
|
endpoint TEXT,
|
||||||
|
level TEXT,
|
||||||
|
user_id TEXT,
|
||||||
|
api_key TEXT,
|
||||||
|
query TEXT,
|
||||||
|
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE agent_folders (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE sources (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
type TEXT,
|
||||||
|
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE agents (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
agent_type TEXT,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
key CITEXT UNIQUE,
|
||||||
|
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||||
|
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||||
|
chunks INTEGER,
|
||||||
|
retriever TEXT,
|
||||||
|
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||||
|
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
json_schema JSONB,
|
||||||
|
models JSONB,
|
||||||
|
default_model_id TEXT,
|
||||||
|
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||||
|
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
token_limit INTEGER,
|
||||||
|
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
request_limit INTEGER,
|
||||||
|
shared BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
incoming_webhook_token CITEXT UNIQUE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
last_used_at TIMESTAMPTZ,
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
|
||||||
|
"FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE attachments (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
filename TEXT NOT NULL,
|
||||||
|
upload_path TEXT NOT NULL,
|
||||||
|
mime_type TEXT,
|
||||||
|
size BIGINT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE memories (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||||
|
path TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE todos (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
completed BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE notes (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE connector_sessions (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
provider TEXT NOT NULL,
|
||||||
|
session_data JSONB NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE conversations (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
|
||||||
|
name TEXT,
|
||||||
|
api_key TEXT,
|
||||||
|
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
shared_token TEXT,
|
||||||
|
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
|
||||||
|
compression_metadata JSONB,
|
||||||
|
legacy_mongo_id TEXT,
|
||||||
|
CONSTRAINT conversations_api_key_nonempty_chk
|
||||||
|
CHECK (api_key IS NULL OR api_key <> '')
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE conversation_messages (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||||
|
position INTEGER NOT NULL,
|
||||||
|
prompt TEXT,
|
||||||
|
response TEXT,
|
||||||
|
thought TEXT,
|
||||||
|
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
|
||||||
|
model_id TEXT,
|
||||||
|
message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
feedback JSONB,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE shared_conversations (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
is_promptable BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
uuid UUID NOT NULL,
|
||||||
|
first_n_queries INTEGER NOT NULL DEFAULT 0,
|
||||||
|
api_key TEXT,
|
||||||
|
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||||
|
chunks INTEGER,
|
||||||
|
CONSTRAINT shared_conversations_api_key_nonempty_chk
|
||||||
|
CHECK (api_key IS NULL OR api_key <> '')
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE pending_tool_state (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
messages JSONB NOT NULL,
|
||||||
|
pending_tool_calls JSONB NOT NULL,
|
||||||
|
tools_dict JSONB NOT NULL,
|
||||||
|
tool_schemas JSONB NOT NULL,
|
||||||
|
agent_config JSONB NOT NULL,
|
||||||
|
client_tools JSONB,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflows (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
current_graph_version INTEGER NOT NULL DEFAULT 1,
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflow_nodes (
|
||||||
|
id UUID DEFAULT gen_random_uuid() NOT NULL,
|
||||||
|
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||||
|
graph_version INTEGER NOT NULL,
|
||||||
|
node_type TEXT NOT NULL,
|
||||||
|
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
node_id TEXT NOT NULL,
|
||||||
|
title TEXT,
|
||||||
|
description TEXT,
|
||||||
|
position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
|
||||||
|
legacy_mongo_id TEXT,
|
||||||
|
PRIMARY KEY (id),
|
||||||
|
CONSTRAINT workflow_nodes_id_wf_ver_key
|
||||||
|
UNIQUE (id, workflow_id, graph_version)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflow_edges (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||||
|
graph_version INTEGER NOT NULL,
|
||||||
|
from_node_id UUID NOT NULL,
|
||||||
|
to_node_id UUID NOT NULL,
|
||||||
|
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
edge_id TEXT NOT NULL,
|
||||||
|
source_handle TEXT,
|
||||||
|
target_handle TEXT,
|
||||||
|
CONSTRAINT workflow_edges_from_node_fk
|
||||||
|
FOREIGN KEY (from_node_id, workflow_id, graph_version)
|
||||||
|
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
|
||||||
|
CONSTRAINT workflow_edges_to_node_fk
|
||||||
|
FOREIGN KEY (to_node_id, workflow_id, graph_version)
|
||||||
|
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflow_runs (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
ended_at TIMESTAMPTZ,
|
||||||
|
result JSONB,
|
||||||
|
inputs JSONB,
|
||||||
|
steps JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
legacy_mongo_id TEXT,
|
||||||
|
CONSTRAINT workflow_runs_status_chk
|
||||||
|
CHECK (status IN ('pending', 'running', 'completed', 'failed'))
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Indexes
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
|
||||||
|
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
|
||||||
|
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
|
||||||
|
op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
|
||||||
|
op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
|
||||||
|
op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
|
||||||
|
"ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
|
||||||
|
"ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX connector_sessions_user_provider_uidx "
|
||||||
|
"ON connector_sessions (user_id, provider);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX connector_sessions_expiry_idx "
|
||||||
|
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
|
||||||
|
"ON conversation_messages (conversation_id, position);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX conversation_messages_user_ts_idx "
|
||||||
|
"ON conversation_messages (user_id, timestamp DESC);"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
|
||||||
|
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX conversations_shared_token_uidx "
|
||||||
|
"ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX conversations_api_key_date_idx "
|
||||||
|
"ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
|
||||||
|
"ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX memories_user_tool_path_uidx "
|
||||||
|
"ON memories (user_id, tool_id, path);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
|
||||||
|
"ON memories (user_id, path) WHERE tool_id IS NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX memories_path_prefix_idx "
|
||||||
|
"ON memories (user_id, tool_id, path text_pattern_ops);"
|
||||||
|
)
|
||||||
|
op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")
|
||||||
|
|
||||||
|
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
|
||||||
|
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
|
||||||
|
"ON pending_tool_state (conversation_id, user_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
|
||||||
|
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
|
||||||
|
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
|
||||||
|
"ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
|
||||||
|
|
||||||
|
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
|
||||||
|
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
|
||||||
|
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
|
||||||
|
|
||||||
|
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
|
||||||
|
|
||||||
|
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
|
||||||
|
op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
|
||||||
|
"ON workflow_edges (workflow_id, graph_version, edge_id);"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
|
||||||
|
"ON workflow_nodes (workflow_id, graph_version, node_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
|
||||||
|
"ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
|
||||||
|
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX workflow_runs_status_started_idx "
|
||||||
|
"ON workflow_runs (status, started_at DESC);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
|
||||||
|
"ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
|
||||||
|
"ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# user_id foreign keys (deferrable so backfills can stage rows)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
user_fk_tables = (
|
||||||
|
"agent_folders",
|
||||||
|
"agents",
|
||||||
|
"attachments",
|
||||||
|
"connector_sessions",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"memories",
|
||||||
|
"notes",
|
||||||
|
"pending_tool_state",
|
||||||
|
"prompts",
|
||||||
|
"shared_conversations",
|
||||||
|
"sources",
|
||||||
|
"stack_logs",
|
||||||
|
"todos",
|
||||||
|
"token_usage",
|
||||||
|
"user_logs",
|
||||||
|
"user_tools",
|
||||||
|
"workflow_runs",
|
||||||
|
"workflows",
|
||||||
|
)
|
||||||
|
for table in user_fk_tables:
|
||||||
|
op.execute(
|
||||||
|
f"ALTER TABLE {table} "
|
||||||
|
f"ADD CONSTRAINT {table}_user_id_fk "
|
||||||
|
f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
|
||||||
|
f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Triggers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
updated_at_tables = (
|
||||||
|
"agent_folders",
|
||||||
|
"agents",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"memories",
|
||||||
|
"notes",
|
||||||
|
"prompts",
|
||||||
|
"sources",
|
||||||
|
"todos",
|
||||||
|
"user_tools",
|
||||||
|
"users",
|
||||||
|
"workflows",
|
||||||
|
)
|
||||||
|
for table in updated_at_tables:
|
||||||
|
op.execute(
|
||||||
|
f"CREATE TRIGGER {table}_set_updated_at "
|
||||||
|
f"BEFORE UPDATE ON {table} "
|
||||||
|
f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
|
||||||
|
f"EXECUTE FUNCTION set_updated_at();"
|
||||||
|
)
|
||||||
|
|
||||||
|
ensure_user_tables = (
|
||||||
|
"agent_folders",
|
||||||
|
"agents",
|
||||||
|
"attachments",
|
||||||
|
"connector_sessions",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"memories",
|
||||||
|
"notes",
|
||||||
|
"pending_tool_state",
|
||||||
|
"prompts",
|
||||||
|
"shared_conversations",
|
||||||
|
"sources",
|
||||||
|
"stack_logs",
|
||||||
|
"todos",
|
||||||
|
"token_usage",
|
||||||
|
"user_logs",
|
||||||
|
"user_tools",
|
||||||
|
"workflow_runs",
|
||||||
|
"workflows",
|
||||||
|
)
|
||||||
|
for table in ensure_user_tables:
|
||||||
|
op.execute(
|
||||||
|
f"CREATE TRIGGER {table}_ensure_user "
|
||||||
|
f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
|
||||||
|
f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER conversation_messages_fill_user "
|
||||||
|
"BEFORE INSERT ON conversation_messages "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER attachments_cleanup_message_refs "
|
||||||
|
"AFTER DELETE ON attachments "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER agents_cleanup_user_prefs "
|
||||||
|
"AFTER DELETE ON agents "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER sources_cleanup_agent_extra_refs "
|
||||||
|
"AFTER DELETE ON sources "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Seed sentinel __system__ user (system/template sources attribute here)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute(
|
||||||
|
"INSERT INTO users (user_id) VALUES ('__system__') "
|
||||||
|
"ON CONFLICT (user_id) DO NOTHING;"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Nuclear downgrade: drop everything this migration created. The
|
||||||
|
# ordering drops FK-bearing children before parents; CASCADE would
|
||||||
|
# also work but explicit ordering is easier to reason about in code
|
||||||
|
# review.
|
||||||
|
tables_in_drop_order = (
|
||||||
|
"workflow_edges",
|
||||||
|
"workflow_runs",
|
||||||
|
"workflow_nodes",
|
||||||
|
"workflows",
|
||||||
|
"pending_tool_state",
|
||||||
|
"shared_conversations",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"connector_sessions",
|
||||||
|
"notes",
|
||||||
|
"todos",
|
||||||
|
"memories",
|
||||||
|
"attachments",
|
||||||
|
"agents",
|
||||||
|
"sources",
|
||||||
|
"agent_folders",
|
||||||
|
"stack_logs",
|
||||||
|
"user_logs",
|
||||||
|
"token_usage",
|
||||||
|
"user_tools",
|
||||||
|
"prompts",
|
||||||
|
"users",
|
||||||
|
)
|
||||||
|
for table in tables_in_drop_order:
|
||||||
|
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
|
||||||
|
|
||||||
|
for fn in (
|
||||||
|
"conversation_messages_fill_user_id",
|
||||||
|
"cleanup_user_agent_prefs",
|
||||||
|
"cleanup_agent_extra_source_refs",
|
||||||
|
"cleanup_message_attachment_refs",
|
||||||
|
"ensure_user_exists",
|
||||||
|
"set_updated_at",
|
||||||
|
):
|
||||||
|
op.execute(f"DROP FUNCTION IF EXISTS {fn}();")
|
||||||
@@ -74,57 +74,76 @@ class AnswerResource(Resource, BaseAnswerResource):
|
|||||||
decoded_token = getattr(request, "decoded_token", None)
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
processor = StreamProcessor(data, decoded_token)
|
processor = StreamProcessor(data, decoded_token)
|
||||||
try:
|
try:
|
||||||
agent = processor.build_agent(data.get("question", ""))
|
# ---- Continuation mode ----
|
||||||
if not processor.decoded_token:
|
if data.get("tool_actions"):
|
||||||
return make_response({"error": "Unauthorized"}, 401)
|
(
|
||||||
|
agent,
|
||||||
|
messages,
|
||||||
|
tools_dict,
|
||||||
|
pending_tool_calls,
|
||||||
|
tool_actions,
|
||||||
|
) = processor.resume_from_tool_actions(
|
||||||
|
data["tool_actions"], data["conversation_id"]
|
||||||
|
)
|
||||||
|
if not processor.decoded_token:
|
||||||
|
return make_response({"error": "Unauthorized"}, 401)
|
||||||
|
if error := self.check_usage(processor.agent_config):
|
||||||
|
return error
|
||||||
|
stream = self.complete_stream(
|
||||||
|
question="",
|
||||||
|
agent=agent,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
agent_id=processor.agent_id,
|
||||||
|
model_id=processor.model_id,
|
||||||
|
_continuation={
|
||||||
|
"messages": messages,
|
||||||
|
"tools_dict": tools_dict,
|
||||||
|
"pending_tool_calls": pending_tool_calls,
|
||||||
|
"tool_actions": tool_actions,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# ---- Normal mode ----
|
||||||
|
agent = processor.build_agent(data.get("question", ""))
|
||||||
|
if not processor.decoded_token:
|
||||||
|
return make_response({"error": "Unauthorized"}, 401)
|
||||||
|
|
||||||
if error := self.check_usage(processor.agent_config):
|
if error := self.check_usage(processor.agent_config):
|
||||||
return error
|
return error
|
||||||
|
|
||||||
|
stream = self.complete_stream(
|
||||||
|
question=data["question"],
|
||||||
|
agent=agent,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
isNoneDoc=data.get("isNoneDoc"),
|
||||||
|
index=None,
|
||||||
|
should_save_conversation=data.get("save_conversation", True),
|
||||||
|
agent_id=processor.agent_id,
|
||||||
|
is_shared_usage=processor.is_shared_usage,
|
||||||
|
shared_token=processor.shared_token,
|
||||||
|
model_id=processor.model_id,
|
||||||
|
)
|
||||||
|
|
||||||
stream = self.complete_stream(
|
|
||||||
question=data["question"],
|
|
||||||
agent=agent,
|
|
||||||
conversation_id=processor.conversation_id,
|
|
||||||
user_api_key=processor.agent_config.get("user_api_key"),
|
|
||||||
decoded_token=processor.decoded_token,
|
|
||||||
isNoneDoc=data.get("isNoneDoc"),
|
|
||||||
index=None,
|
|
||||||
should_save_conversation=data.get("save_conversation", True),
|
|
||||||
agent_id=processor.agent_id,
|
|
||||||
is_shared_usage=processor.is_shared_usage,
|
|
||||||
shared_token=processor.shared_token,
|
|
||||||
model_id=processor.model_id,
|
|
||||||
)
|
|
||||||
stream_result = self.process_response_stream(stream)
|
stream_result = self.process_response_stream(stream)
|
||||||
|
|
||||||
if len(stream_result) == 7:
|
if stream_result["error"]:
|
||||||
(
|
return make_response({"error": stream_result["error"]}, 400)
|
||||||
conversation_id,
|
|
||||||
response,
|
|
||||||
sources,
|
|
||||||
tool_calls,
|
|
||||||
thought,
|
|
||||||
error,
|
|
||||||
structured_info,
|
|
||||||
) = stream_result
|
|
||||||
else:
|
|
||||||
conversation_id, response, sources, tool_calls, thought, error = (
|
|
||||||
stream_result
|
|
||||||
)
|
|
||||||
structured_info = None
|
|
||||||
|
|
||||||
if error:
|
|
||||||
return make_response({"error": error}, 400)
|
|
||||||
result = {
|
result = {
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": stream_result["conversation_id"],
|
||||||
"answer": response,
|
"answer": stream_result["answer"],
|
||||||
"sources": sources,
|
"sources": stream_result["sources"],
|
||||||
"tool_calls": tool_calls,
|
"tool_calls": stream_result["tool_calls"],
|
||||||
"thought": thought,
|
"thought": stream_result["thought"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if structured_info:
|
extra_info = stream_result.get("extra")
|
||||||
result.update(structured_info)
|
if extra_info:
|
||||||
|
result.update(extra_info)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
|
|||||||
from flask import jsonify, make_response, Response
|
from flask import jsonify, make_response, Response
|
||||||
from flask_restx import Namespace
|
from flask_restx import Namespace
|
||||||
|
|
||||||
|
from application.api.answer.services.continuation_service import ContinuationService
|
||||||
from application.api.answer.services.conversation_service import ConversationService
|
from application.api.answer.services.conversation_service import ConversationService
|
||||||
from application.core.model_utils import (
|
from application.core.model_utils import (
|
||||||
get_api_key_for_provider,
|
get_api_key_for_provider,
|
||||||
@@ -39,7 +40,16 @@ class BaseAnswerResource:
|
|||||||
def validate_request(
|
def validate_request(
|
||||||
self, data: Dict[str, Any], require_conversation_id: bool = False
|
self, data: Dict[str, Any], require_conversation_id: bool = False
|
||||||
) -> Optional[Response]:
|
) -> Optional[Response]:
|
||||||
"""Common request validation"""
|
"""Common request validation.
|
||||||
|
|
||||||
|
Continuation requests (``tool_actions`` present) require
|
||||||
|
``conversation_id`` but not ``question``.
|
||||||
|
"""
|
||||||
|
if data.get("tool_actions"):
|
||||||
|
# Continuation mode — question is not required
|
||||||
|
if missing := check_required_fields(data, ["conversation_id"]):
|
||||||
|
return missing
|
||||||
|
return None
|
||||||
required_fields = ["question"]
|
required_fields = ["question"]
|
||||||
if require_conversation_id:
|
if require_conversation_id:
|
||||||
required_fields.append("conversation_id")
|
required_fields.append("conversation_id")
|
||||||
@@ -177,6 +187,7 @@ class BaseAnswerResource:
|
|||||||
is_shared_usage: bool = False,
|
is_shared_usage: bool = False,
|
||||||
shared_token: Optional[str] = None,
|
shared_token: Optional[str] = None,
|
||||||
model_id: Optional[str] = None,
|
model_id: Optional[str] = None,
|
||||||
|
_continuation: Optional[Dict] = None,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
"""
|
"""
|
||||||
Generator function that streams the complete conversation response.
|
Generator function that streams the complete conversation response.
|
||||||
@@ -207,8 +218,19 @@ class BaseAnswerResource:
|
|||||||
schema_info = None
|
schema_info = None
|
||||||
structured_chunks = []
|
structured_chunks = []
|
||||||
query_metadata = {}
|
query_metadata = {}
|
||||||
|
paused = False
|
||||||
|
|
||||||
for line in agent.gen(query=question):
|
if _continuation:
|
||||||
|
gen_iter = agent.gen_continuation(
|
||||||
|
messages=_continuation["messages"],
|
||||||
|
tools_dict=_continuation["tools_dict"],
|
||||||
|
pending_tool_calls=_continuation["pending_tool_calls"],
|
||||||
|
tool_actions=_continuation["tool_actions"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
gen_iter = agent.gen(query=question)
|
||||||
|
|
||||||
|
for line in gen_iter:
|
||||||
if "metadata" in line:
|
if "metadata" in line:
|
||||||
query_metadata.update(line["metadata"])
|
query_metadata.update(line["metadata"])
|
||||||
elif "answer" in line:
|
elif "answer" in line:
|
||||||
@@ -244,15 +266,21 @@ class BaseAnswerResource:
|
|||||||
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
data = json.dumps({"type": "thought", "thought": line["thought"]})
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
elif "type" in line:
|
elif "type" in line:
|
||||||
if line.get("type") == "error":
|
if line.get("type") == "tool_calls_pending":
|
||||||
|
# Save continuation state and end the stream
|
||||||
|
paused = True
|
||||||
|
data = json.dumps(line)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
elif line.get("type") == "error":
|
||||||
sanitized_error = {
|
sanitized_error = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
"error": sanitize_api_error(line.get("error", "An error occurred"))
|
||||||
}
|
}
|
||||||
data = json.dumps(sanitized_error)
|
data = json.dumps(sanitized_error)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
else:
|
else:
|
||||||
data = json.dumps(line)
|
data = json.dumps(line)
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
if is_structured and structured_chunks:
|
if is_structured and structured_chunks:
|
||||||
structured_data = {
|
structured_data = {
|
||||||
"type": "structured_answer",
|
"type": "structured_answer",
|
||||||
@@ -262,6 +290,93 @@ class BaseAnswerResource:
|
|||||||
}
|
}
|
||||||
data = json.dumps(structured_data)
|
data = json.dumps(structured_data)
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
# ---- Paused: save continuation state and end stream early ----
|
||||||
|
if paused:
|
||||||
|
continuation = getattr(agent, "_pending_continuation", None)
|
||||||
|
if continuation:
|
||||||
|
# Ensure we have a conversation_id — create a partial
|
||||||
|
# conversation if this is the first turn.
|
||||||
|
if not conversation_id and should_save_conversation:
|
||||||
|
try:
|
||||||
|
provider = (
|
||||||
|
get_provider_from_model_id(model_id)
|
||||||
|
if model_id
|
||||||
|
else settings.LLM_PROVIDER
|
||||||
|
)
|
||||||
|
sys_api_key = get_api_key_for_provider(
|
||||||
|
provider or settings.LLM_PROVIDER
|
||||||
|
)
|
||||||
|
llm = LLMCreator.create_llm(
|
||||||
|
provider or settings.LLM_PROVIDER,
|
||||||
|
api_key=sys_api_key,
|
||||||
|
user_api_key=user_api_key,
|
||||||
|
decoded_token=decoded_token,
|
||||||
|
model_id=model_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
conversation_id = (
|
||||||
|
self.conversation_service.save_conversation(
|
||||||
|
None,
|
||||||
|
question,
|
||||||
|
response_full,
|
||||||
|
thought,
|
||||||
|
source_log_docs,
|
||||||
|
tool_calls,
|
||||||
|
llm,
|
||||||
|
model_id or self.default_model_id,
|
||||||
|
decoded_token,
|
||||||
|
api_key=user_api_key,
|
||||||
|
agent_id=agent_id,
|
||||||
|
is_shared_usage=is_shared_usage,
|
||||||
|
shared_token=shared_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to create conversation for continuation: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if conversation_id:
|
||||||
|
try:
|
||||||
|
cont_service = ContinuationService()
|
||||||
|
cont_service.save_state(
|
||||||
|
conversation_id=str(conversation_id),
|
||||||
|
user=decoded_token.get("sub", "local"),
|
||||||
|
messages=continuation["messages"],
|
||||||
|
pending_tool_calls=continuation["pending_tool_calls"],
|
||||||
|
tools_dict=continuation["tools_dict"],
|
||||||
|
tool_schemas=getattr(agent, "tools", []),
|
||||||
|
agent_config={
|
||||||
|
"model_id": model_id or self.default_model_id,
|
||||||
|
"llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
|
||||||
|
"api_key": getattr(agent, "api_key", None),
|
||||||
|
"user_api_key": user_api_key,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"agent_type": agent.__class__.__name__,
|
||||||
|
"prompt": getattr(agent, "prompt", ""),
|
||||||
|
"json_schema": getattr(agent, "json_schema", None),
|
||||||
|
"retriever_config": getattr(agent, "retriever_config", None),
|
||||||
|
},
|
||||||
|
client_tools=getattr(
|
||||||
|
agent.tool_executor, "client_tools", None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to save continuation state: {str(e)}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
id_data = {"type": "id", "id": str(conversation_id)}
|
||||||
|
data = json.dumps(id_data)
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
|
||||||
|
data = json.dumps({"type": "end"})
|
||||||
|
yield f"data: {data}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
if isNoneDoc:
|
if isNoneDoc:
|
||||||
for doc in source_log_docs:
|
for doc in source_log_docs:
|
||||||
doc["source"] = "None"
|
doc["source"] = "None"
|
||||||
@@ -354,6 +469,18 @@ class BaseAnswerResource:
|
|||||||
log_data[key] = value[:10000]
|
log_data[key] = value[:10000]
|
||||||
self.user_logs_collection.insert_one(log_data)
|
self.user_logs_collection.insert_one(log_data)
|
||||||
|
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||||
|
|
||||||
|
dual_write(
|
||||||
|
UserLogsRepository,
|
||||||
|
lambda repo, d=log_data: repo.insert(
|
||||||
|
user_id=d.get("user"),
|
||||||
|
endpoint="stream_answer",
|
||||||
|
data=d,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
data = json.dumps({"type": "end"})
|
data = json.dumps({"type": "end"})
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
except GeneratorExit:
|
except GeneratorExit:
|
||||||
@@ -425,8 +552,13 @@ class BaseAnswerResource:
|
|||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
return
|
return
|
||||||
|
|
||||||
def process_response_stream(self, stream):
|
def process_response_stream(self, stream) -> Dict[str, Any]:
|
||||||
"""Process the stream response for non-streaming endpoint"""
|
"""Process the stream response for non-streaming endpoint.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with keys: conversation_id, answer, sources, tool_calls,
|
||||||
|
thought, error, and optional extra.
|
||||||
|
"""
|
||||||
conversation_id = ""
|
conversation_id = ""
|
||||||
response_full = ""
|
response_full = ""
|
||||||
source_log_docs = []
|
source_log_docs = []
|
||||||
@@ -435,6 +567,7 @@ class BaseAnswerResource:
|
|||||||
stream_ended = False
|
stream_ended = False
|
||||||
is_structured = False
|
is_structured = False
|
||||||
schema_info = None
|
schema_info = None
|
||||||
|
pending_tool_calls = None
|
||||||
|
|
||||||
for line in stream:
|
for line in stream:
|
||||||
try:
|
try:
|
||||||
@@ -453,11 +586,22 @@ class BaseAnswerResource:
|
|||||||
source_log_docs = event["source"]
|
source_log_docs = event["source"]
|
||||||
elif event["type"] == "tool_calls":
|
elif event["type"] == "tool_calls":
|
||||||
tool_calls = event["tool_calls"]
|
tool_calls = event["tool_calls"]
|
||||||
|
elif event["type"] == "tool_calls_pending":
|
||||||
|
pending_tool_calls = event.get("data", {}).get(
|
||||||
|
"pending_tool_calls", []
|
||||||
|
)
|
||||||
elif event["type"] == "thought":
|
elif event["type"] == "thought":
|
||||||
thought = event["thought"]
|
thought = event["thought"]
|
||||||
elif event["type"] == "error":
|
elif event["type"] == "error":
|
||||||
logger.error(f"Error from stream: {event['error']}")
|
logger.error(f"Error from stream: {event['error']}")
|
||||||
return None, None, None, None, event["error"], None
|
return {
|
||||||
|
"conversation_id": None,
|
||||||
|
"answer": None,
|
||||||
|
"sources": None,
|
||||||
|
"tool_calls": None,
|
||||||
|
"thought": None,
|
||||||
|
"error": event["error"],
|
||||||
|
}
|
||||||
elif event["type"] == "end":
|
elif event["type"] == "end":
|
||||||
stream_ended = True
|
stream_ended = True
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
@@ -465,18 +609,30 @@ class BaseAnswerResource:
|
|||||||
continue
|
continue
|
||||||
if not stream_ended:
|
if not stream_ended:
|
||||||
logger.error("Stream ended unexpectedly without an 'end' event.")
|
logger.error("Stream ended unexpectedly without an 'end' event.")
|
||||||
return None, None, None, None, "Stream ended unexpectedly", None
|
return {
|
||||||
result = (
|
"conversation_id": None,
|
||||||
conversation_id,
|
"answer": None,
|
||||||
response_full,
|
"sources": None,
|
||||||
source_log_docs,
|
"tool_calls": None,
|
||||||
tool_calls,
|
"thought": None,
|
||||||
thought,
|
"error": "Stream ended unexpectedly",
|
||||||
None,
|
}
|
||||||
)
|
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"answer": response_full,
|
||||||
|
"sources": source_log_docs,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"thought": thought,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if pending_tool_calls is not None:
|
||||||
|
result["extra"] = {"pending_tool_calls": pending_tool_calls}
|
||||||
|
|
||||||
if is_structured:
|
if is_structured:
|
||||||
result = result + ({"structured": True, "schema": schema_info},)
|
result["extra"] = {"structured": True, "schema": schema_info}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def error_stream_generate(self, err_response):
|
def error_stream_generate(self, err_response):
|
||||||
|
|||||||
@@ -79,7 +79,47 @@ class StreamResource(Resource, BaseAnswerResource):
|
|||||||
return error
|
return error
|
||||||
decoded_token = getattr(request, "decoded_token", None)
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
processor = StreamProcessor(data, decoded_token)
|
processor = StreamProcessor(data, decoded_token)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# ---- Continuation mode ----
|
||||||
|
if data.get("tool_actions"):
|
||||||
|
(
|
||||||
|
agent,
|
||||||
|
messages,
|
||||||
|
tools_dict,
|
||||||
|
pending_tool_calls,
|
||||||
|
tool_actions,
|
||||||
|
) = processor.resume_from_tool_actions(
|
||||||
|
data["tool_actions"], data["conversation_id"]
|
||||||
|
)
|
||||||
|
if not processor.decoded_token:
|
||||||
|
return Response(
|
||||||
|
self.error_stream_generate("Unauthorized"),
|
||||||
|
status=401,
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
if error := self.check_usage(processor.agent_config):
|
||||||
|
return error
|
||||||
|
return Response(
|
||||||
|
self.complete_stream(
|
||||||
|
question="",
|
||||||
|
agent=agent,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
agent_id=processor.agent_id,
|
||||||
|
model_id=processor.model_id,
|
||||||
|
_continuation={
|
||||||
|
"messages": messages,
|
||||||
|
"tools_dict": tools_dict,
|
||||||
|
"pending_tool_calls": pending_tool_calls,
|
||||||
|
"tool_actions": tool_actions,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Normal mode ----
|
||||||
agent = processor.build_agent(data["question"])
|
agent = processor.build_agent(data["question"])
|
||||||
if not processor.decoded_token:
|
if not processor.decoded_token:
|
||||||
return Response(
|
return Response(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Message reconstruction utilities for compression."""
|
"""Message reconstruction utilities for compression."""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
@@ -49,28 +50,35 @@ class MessageBuilder:
|
|||||||
if include_tool_calls and "tool_calls" in query:
|
if include_tool_calls and "tool_calls" in query:
|
||||||
for tool_call in query["tool_calls"]:
|
for tool_call in query["tool_calls"]:
|
||||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||||
|
args = tool_call.get("arguments")
|
||||||
function_call_dict = {
|
args_str = (
|
||||||
"function_call": {
|
json.dumps(args)
|
||||||
"name": tool_call.get("action_name"),
|
if isinstance(args, dict)
|
||||||
"args": tool_call.get("arguments"),
|
else (args or "{}")
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
function_response_dict = {
|
|
||||||
"function_response": {
|
|
||||||
"name": tool_call.get("action_name"),
|
|
||||||
"response": {"result": tool_call.get("result")},
|
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
{"role": "assistant", "content": [function_call_dict]}
|
|
||||||
)
|
)
|
||||||
messages.append(
|
messages.append({
|
||||||
{"role": "tool", "content": [function_response_dict]}
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.get("action_name", ""),
|
||||||
|
"arguments": args_str,
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
result = tool_call.get("result")
|
||||||
|
result_str = (
|
||||||
|
json.dumps(result)
|
||||||
|
if not isinstance(result, str)
|
||||||
|
else (result or "")
|
||||||
)
|
)
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"content": result_str,
|
||||||
|
})
|
||||||
|
|
||||||
# If no recent queries (everything was compressed), add a continuation user message
|
# If no recent queries (everything was compressed), add a continuation user message
|
||||||
if len(recent_queries) == 0 and compressed_summary:
|
if len(recent_queries) == 0 and compressed_summary:
|
||||||
@@ -180,28 +188,35 @@ class MessageBuilder:
|
|||||||
if include_tool_calls and "tool_calls" in query:
|
if include_tool_calls and "tool_calls" in query:
|
||||||
for tool_call in query["tool_calls"]:
|
for tool_call in query["tool_calls"]:
|
||||||
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
call_id = tool_call.get("call_id") or str(uuid.uuid4())
|
||||||
|
args = tool_call.get("arguments")
|
||||||
function_call_dict = {
|
args_str = (
|
||||||
"function_call": {
|
json.dumps(args)
|
||||||
"name": tool_call.get("action_name"),
|
if isinstance(args, dict)
|
||||||
"args": tool_call.get("arguments"),
|
else (args or "{}")
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
function_response_dict = {
|
|
||||||
"function_response": {
|
|
||||||
"name": tool_call.get("action_name"),
|
|
||||||
"response": {"result": tool_call.get("result")},
|
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rebuilt_messages.append(
|
|
||||||
{"role": "assistant", "content": [function_call_dict]}
|
|
||||||
)
|
)
|
||||||
rebuilt_messages.append(
|
rebuilt_messages.append({
|
||||||
{"role": "tool", "content": [function_response_dict]}
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.get("action_name", ""),
|
||||||
|
"arguments": args_str,
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
result = tool_call.get("result")
|
||||||
|
result_str = (
|
||||||
|
json.dumps(result)
|
||||||
|
if not isinstance(result, str)
|
||||||
|
else (result or "")
|
||||||
)
|
)
|
||||||
|
rebuilt_messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"content": result_str,
|
||||||
|
})
|
||||||
|
|
||||||
# If no recent queries (everything was compressed), add a continuation user message
|
# If no recent queries (everything was compressed), add a continuation user message
|
||||||
if len(recent_queries) == 0 and compressed_summary:
|
if len(recent_queries) == 0 and compressed_summary:
|
||||||
|
|||||||
175
application/api/answer/services/continuation_service.py
Normal file
175
application/api/answer/services/continuation_service.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""Service for saving and restoring tool-call continuation state.
|
||||||
|
|
||||||
|
When a stream pauses (tool needs approval or client-side execution),
|
||||||
|
the full execution state is persisted to MongoDB so the client can
|
||||||
|
resume later by sending tool_actions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||||
|
from application.storage.db.repositories.pending_tool_state import (
|
||||||
|
PendingToolStateRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# TTL for pending states — auto-cleaned after this period
|
||||||
|
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||||
|
|
||||||
|
|
||||||
|
def _make_serializable(obj: Any) -> Any:
|
||||||
|
"""Recursively convert MongoDB ObjectIds and other non-JSON types."""
|
||||||
|
if isinstance(obj, ObjectId):
|
||||||
|
return str(obj)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_make_serializable(v) for v in obj]
|
||||||
|
if isinstance(obj, bytes):
|
||||||
|
return obj.decode("utf-8", errors="replace")
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuationService:
|
||||||
|
"""Manages pending tool-call state in MongoDB."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
self.collection = db["pending_tool_state"]
|
||||||
|
self._ensure_indexes()
|
||||||
|
|
||||||
|
def _ensure_indexes(self):
|
||||||
|
try:
|
||||||
|
self.collection.create_index(
|
||||||
|
"expires_at", expireAfterSeconds=0
|
||||||
|
)
|
||||||
|
self.collection.create_index(
|
||||||
|
[("conversation_id", 1), ("user", 1)], unique=True
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Indexes may already exist or mongomock doesn't support TTL
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_state(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
user: str,
|
||||||
|
messages: List[Dict],
|
||||||
|
pending_tool_calls: List[Dict],
|
||||||
|
tools_dict: Dict,
|
||||||
|
tool_schemas: List[Dict],
|
||||||
|
agent_config: Dict,
|
||||||
|
client_tools: Optional[List[Dict]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Save execution state for later continuation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The conversation this state belongs to.
|
||||||
|
user: Owner user ID.
|
||||||
|
messages: Full messages array at the pause point.
|
||||||
|
pending_tool_calls: Tool calls awaiting client action.
|
||||||
|
tools_dict: Serializable tools configuration dict.
|
||||||
|
tool_schemas: LLM-formatted tool schemas (agent.tools).
|
||||||
|
agent_config: Config needed to recreate the agent on resume.
|
||||||
|
client_tools: Client-provided tool schemas for client-side execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string ID of the saved state document.
|
||||||
|
"""
|
||||||
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
|
||||||
|
|
||||||
|
doc = {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"user": user,
|
||||||
|
"messages": _make_serializable(messages),
|
||||||
|
"pending_tool_calls": _make_serializable(pending_tool_calls),
|
||||||
|
"tools_dict": _make_serializable(tools_dict),
|
||||||
|
"tool_schemas": _make_serializable(tool_schemas),
|
||||||
|
"agent_config": _make_serializable(agent_config),
|
||||||
|
"client_tools": _make_serializable(client_tools) if client_tools else None,
|
||||||
|
"created_at": now,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Upsert — only one pending state per conversation per user
|
||||||
|
result = self.collection.replace_one(
|
||||||
|
{"conversation_id": conversation_id, "user": user},
|
||||||
|
doc,
|
||||||
|
upsert=True,
|
||||||
|
)
|
||||||
|
state_id = str(result.upserted_id) if result.upserted_id else conversation_id
|
||||||
|
logger.info(
|
||||||
|
f"Saved continuation state for conversation {conversation_id} "
|
||||||
|
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dual-write to Postgres — upsert against the same Mongo conversation
|
||||||
|
# by resolving its UUID via conversations.legacy_mongo_id.
|
||||||
|
def _pg_save(_: PendingToolStateRepository) -> None:
|
||||||
|
conn = _._conn # reuse the existing transaction
|
||||||
|
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||||
|
if conv is None:
|
||||||
|
return
|
||||||
|
_.save_state(
|
||||||
|
conv["id"],
|
||||||
|
user,
|
||||||
|
messages=_make_serializable(messages),
|
||||||
|
pending_tool_calls=_make_serializable(pending_tool_calls),
|
||||||
|
tools_dict=_make_serializable(tools_dict),
|
||||||
|
tool_schemas=_make_serializable(tool_schemas),
|
||||||
|
agent_config=_make_serializable(agent_config),
|
||||||
|
client_tools=_make_serializable(client_tools) if client_tools else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
dual_write(PendingToolStateRepository, _pg_save)
|
||||||
|
return state_id
|
||||||
|
|
||||||
|
def load_state(
|
||||||
|
self, conversation_id: str, user: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Load pending continuation state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state dict, or None if no pending state exists.
|
||||||
|
"""
|
||||||
|
doc = self.collection.find_one(
|
||||||
|
{"conversation_id": conversation_id, "user": user}
|
||||||
|
)
|
||||||
|
if not doc:
|
||||||
|
return None
|
||||||
|
doc["_id"] = str(doc["_id"])
|
||||||
|
return doc
|
||||||
|
|
||||||
|
def delete_state(self, conversation_id: str, user: str) -> bool:
|
||||||
|
"""Delete pending state after successful resumption.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a document was deleted.
|
||||||
|
"""
|
||||||
|
result = self.collection.delete_one(
|
||||||
|
{"conversation_id": conversation_id, "user": user}
|
||||||
|
)
|
||||||
|
if result.deleted_count:
|
||||||
|
logger.info(
|
||||||
|
f"Deleted continuation state for conversation {conversation_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dual-write to Postgres — delete the same row.
|
||||||
|
def _pg_delete(repo: PendingToolStateRepository) -> None:
|
||||||
|
conv = ConversationsRepository(repo._conn).get_by_legacy_id(conversation_id)
|
||||||
|
if conv is None:
|
||||||
|
return
|
||||||
|
repo.delete_state(conv["id"], user)
|
||||||
|
|
||||||
|
dual_write(PendingToolStateRepository, _pg_delete)
|
||||||
|
return result.deleted_count > 0
|
||||||
@@ -5,6 +5,8 @@ from typing import Any, Dict, List, Optional
|
|||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
|
|
||||||
@@ -113,6 +115,26 @@ class ConversationService:
|
|||||||
},
|
},
|
||||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
||||||
)
|
)
|
||||||
|
# Dual-write to Postgres: update the message at :index and
|
||||||
|
# truncate anything after it, mirroring Mongo's $set+$slice.
|
||||||
|
def _pg_update_at_index(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
|
if conv is None:
|
||||||
|
return
|
||||||
|
repo.update_message_at(conv["id"], index, {
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
"model_id": model_id,
|
||||||
|
"timestamp": current_time,
|
||||||
|
**({"metadata": metadata} if metadata else {}),
|
||||||
|
})
|
||||||
|
repo.truncate_after(conv["id"], index)
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_update_at_index)
|
||||||
return conversation_id
|
return conversation_id
|
||||||
elif conversation_id:
|
elif conversation_id:
|
||||||
# Append new message to existing conversation
|
# Append new message to existing conversation
|
||||||
@@ -138,6 +160,25 @@ class ConversationService:
|
|||||||
|
|
||||||
if result.matched_count == 0:
|
if result.matched_count == 0:
|
||||||
raise ValueError("Conversation not found or unauthorized")
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
|
|
||||||
|
# Dual-write to Postgres: append the same message.
|
||||||
|
def _pg_append(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
|
if conv is None:
|
||||||
|
return
|
||||||
|
repo.append_message(conv["id"], {
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
"model_id": model_id,
|
||||||
|
"timestamp": current_time,
|
||||||
|
"metadata": metadata or {},
|
||||||
|
})
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_append)
|
||||||
return conversation_id
|
return conversation_id
|
||||||
else:
|
else:
|
||||||
# Create new conversation
|
# Create new conversation
|
||||||
@@ -193,7 +234,34 @@ class ConversationService:
|
|||||||
if agent:
|
if agent:
|
||||||
conversation_data["api_key"] = agent["key"]
|
conversation_data["api_key"] = agent["key"]
|
||||||
result = self.conversations_collection.insert_one(conversation_data)
|
result = self.conversations_collection.insert_one(conversation_data)
|
||||||
return str(result.inserted_id)
|
inserted_id = str(result.inserted_id)
|
||||||
|
|
||||||
|
# Dual-write to Postgres: create the conversation row with
|
||||||
|
# legacy_mongo_id and append the first message.
|
||||||
|
def _pg_create(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.create(
|
||||||
|
user_id,
|
||||||
|
completion,
|
||||||
|
agent_id=conversation_data.get("agent_id"),
|
||||||
|
api_key=conversation_data.get("api_key"),
|
||||||
|
is_shared_usage=conversation_data.get("is_shared_usage", False),
|
||||||
|
shared_token=conversation_data.get("shared_token"),
|
||||||
|
legacy_mongo_id=inserted_id,
|
||||||
|
)
|
||||||
|
repo.append_message(conv["id"], {
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
"model_id": model_id,
|
||||||
|
"timestamp": current_time,
|
||||||
|
"metadata": metadata or {},
|
||||||
|
})
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_create)
|
||||||
|
return inserted_id
|
||||||
|
|
||||||
def update_compression_metadata(
|
def update_compression_metadata(
|
||||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||||
@@ -230,6 +298,24 @@ class ConversationService:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Updated compression metadata for conversation {conversation_id}"
|
f"Updated compression metadata for conversation {conversation_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Dual-write to Postgres: mirror $set + $push $slice.
|
||||||
|
def _pg_compression(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
|
if conv is None:
|
||||||
|
return
|
||||||
|
repo.set_compression_flags(
|
||||||
|
conv["id"],
|
||||||
|
is_compressed=True,
|
||||||
|
last_compression_at=compression_metadata.get("timestamp"),
|
||||||
|
)
|
||||||
|
repo.append_compression_point(
|
||||||
|
conv["id"],
|
||||||
|
compression_metadata,
|
||||||
|
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||||
|
)
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_compression)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error updating compression metadata: {str(e)}", exc_info=True
|
f"Error updating compression metadata: {str(e)}", exc_info=True
|
||||||
@@ -266,6 +352,23 @@ class ConversationService:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _pg_append_summary(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
|
if conv is None:
|
||||||
|
return
|
||||||
|
repo.append_message(conv["id"], {
|
||||||
|
"prompt": "[Context Compression Summary]",
|
||||||
|
"response": summary,
|
||||||
|
"thought": "",
|
||||||
|
"sources": [],
|
||||||
|
"tool_calls": [],
|
||||||
|
"attachments": [],
|
||||||
|
"model_id": compression_metadata.get("model_used"),
|
||||||
|
"timestamp": timestamp,
|
||||||
|
})
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_append_summary)
|
||||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ class StreamProcessor:
|
|||||||
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
|
||||||
self.compressed_summary: Optional[str] = None
|
self.compressed_summary: Optional[str] = None
|
||||||
self.compressed_summary_tokens: int = 0
|
self.compressed_summary_tokens: int = 0
|
||||||
|
self._agent_data: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
def initialize(self):
|
def initialize(self):
|
||||||
"""Initialize all required components for processing"""
|
"""Initialize all required components for processing"""
|
||||||
@@ -359,22 +360,29 @@ class StreamProcessor:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
def _configure_source(self):
|
def _configure_source(self):
|
||||||
"""Configure the source based on agent data"""
|
"""Configure the source based on agent data.
|
||||||
api_key = self.data.get("api_key") or self.agent_key
|
|
||||||
|
|
||||||
if api_key:
|
The literal string ``"default"`` is a placeholder meaning "no
|
||||||
agent_data = self._get_data_from_api_key(api_key)
|
ingested source" and is normalized to an empty source so that no
|
||||||
|
retrieval is attempted.
|
||||||
|
"""
|
||||||
|
if self._agent_data:
|
||||||
|
agent_data = self._agent_data
|
||||||
|
|
||||||
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
if agent_data.get("sources") and len(agent_data["sources"]) > 0:
|
||||||
source_ids = [
|
source_ids = [
|
||||||
source["id"] for source in agent_data["sources"] if source.get("id")
|
source["id"]
|
||||||
|
for source in agent_data["sources"]
|
||||||
|
if source.get("id") and source["id"] != "default"
|
||||||
]
|
]
|
||||||
if source_ids:
|
if source_ids:
|
||||||
self.source = {"active_docs": source_ids}
|
self.source = {"active_docs": source_ids}
|
||||||
else:
|
else:
|
||||||
self.source = {}
|
self.source = {}
|
||||||
self.all_sources = agent_data["sources"]
|
self.all_sources = [
|
||||||
elif agent_data.get("source"):
|
s for s in agent_data["sources"] if s.get("id") != "default"
|
||||||
|
]
|
||||||
|
elif agent_data.get("source") and agent_data["source"] != "default":
|
||||||
self.source = {"active_docs": agent_data["source"]}
|
self.source = {"active_docs": agent_data["source"]}
|
||||||
self.all_sources = [
|
self.all_sources = [
|
||||||
{
|
{
|
||||||
@@ -387,11 +395,24 @@ class StreamProcessor:
|
|||||||
self.all_sources = []
|
self.all_sources = []
|
||||||
return
|
return
|
||||||
if "active_docs" in self.data:
|
if "active_docs" in self.data:
|
||||||
self.source = {"active_docs": self.data["active_docs"]}
|
active_docs = self.data["active_docs"]
|
||||||
|
if active_docs and active_docs != "default":
|
||||||
|
self.source = {"active_docs": active_docs}
|
||||||
|
else:
|
||||||
|
self.source = {}
|
||||||
return
|
return
|
||||||
self.source = {}
|
self.source = {}
|
||||||
self.all_sources = []
|
self.all_sources = []
|
||||||
|
|
||||||
|
def _has_active_docs(self) -> bool:
|
||||||
|
"""Return True if a real document source is configured for retrieval."""
|
||||||
|
active_docs = self.source.get("active_docs") if self.source else None
|
||||||
|
if not active_docs:
|
||||||
|
return False
|
||||||
|
if active_docs == "default":
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _resolve_agent_id(self) -> Optional[str]:
|
def _resolve_agent_id(self) -> Optional[str]:
|
||||||
"""Resolve agent_id from request, then fall back to conversation context."""
|
"""Resolve agent_id from request, then fall back to conversation context."""
|
||||||
request_agent_id = self.data.get("agent_id")
|
request_agent_id = self.data.get("agent_id")
|
||||||
@@ -433,48 +454,39 @@ class StreamProcessor:
|
|||||||
effective_key = self.data.get("api_key") or self.agent_key
|
effective_key = self.data.get("api_key") or self.agent_key
|
||||||
|
|
||||||
if effective_key:
|
if effective_key:
|
||||||
data_key = self._get_data_from_api_key(effective_key)
|
self._agent_data = self._get_data_from_api_key(effective_key)
|
||||||
if data_key.get("_id"):
|
if self._agent_data.get("_id"):
|
||||||
self.agent_id = str(data_key.get("_id"))
|
self.agent_id = str(self._agent_data.get("_id"))
|
||||||
|
|
||||||
self.agent_config.update(
|
self.agent_config.update(
|
||||||
{
|
{
|
||||||
"prompt_id": data_key.get("prompt_id", "default"),
|
"prompt_id": self._agent_data.get("prompt_id", "default"),
|
||||||
"agent_type": data_key.get("agent_type", settings.AGENT_NAME),
|
"agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
|
||||||
"user_api_key": effective_key,
|
"user_api_key": effective_key,
|
||||||
"json_schema": data_key.get("json_schema"),
|
"json_schema": self._agent_data.get("json_schema"),
|
||||||
"default_model_id": data_key.get("default_model_id", ""),
|
"default_model_id": self._agent_data.get("default_model_id", ""),
|
||||||
"models": data_key.get("models", []),
|
"models": self._agent_data.get("models", []),
|
||||||
|
"allow_system_prompt_override": self._agent_data.get(
|
||||||
|
"allow_system_prompt_override", False
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set identity context
|
# Set identity context
|
||||||
if self.data.get("api_key"):
|
if self.data.get("api_key"):
|
||||||
# External API key: use the key owner's identity
|
# External API key: use the key owner's identity
|
||||||
self.initial_user_id = data_key.get("user")
|
self.initial_user_id = self._agent_data.get("user")
|
||||||
self.decoded_token = {"sub": data_key.get("user")}
|
self.decoded_token = {"sub": self._agent_data.get("user")}
|
||||||
elif self.is_shared_usage:
|
elif self.is_shared_usage:
|
||||||
# Shared agent: keep the caller's identity
|
# Shared agent: keep the caller's identity
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# Owner using their own agent
|
# Owner using their own agent
|
||||||
self.decoded_token = {"sub": data_key.get("user")}
|
self.decoded_token = {"sub": self._agent_data.get("user")}
|
||||||
|
|
||||||
if data_key.get("source"):
|
if self._agent_data.get("workflow"):
|
||||||
self.source = {"active_docs": data_key["source"]}
|
self.agent_config["workflow"] = self._agent_data["workflow"]
|
||||||
if data_key.get("workflow"):
|
self.agent_config["workflow_owner"] = self._agent_data.get("user")
|
||||||
self.agent_config["workflow"] = data_key["workflow"]
|
|
||||||
self.agent_config["workflow_owner"] = data_key.get("user")
|
|
||||||
if data_key.get("retriever"):
|
|
||||||
self.retriever_config["retriever_name"] = data_key["retriever"]
|
|
||||||
if data_key.get("chunks") is not None:
|
|
||||||
try:
|
|
||||||
self.retriever_config["chunks"] = int(data_key["chunks"])
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid chunks value: {data_key['chunks']}, using default value 2"
|
|
||||||
)
|
|
||||||
self.retriever_config["chunks"] = 2
|
|
||||||
else:
|
else:
|
||||||
# No API key — default/workflow configuration
|
# No API key — default/workflow configuration
|
||||||
agent_type = settings.AGENT_NAME
|
agent_type = settings.AGENT_NAME
|
||||||
@@ -497,14 +509,45 @@ class StreamProcessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _configure_retriever(self):
|
def _configure_retriever(self):
|
||||||
|
"""Assemble retriever config with precedence: request > agent > default."""
|
||||||
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
|
doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)
|
||||||
|
|
||||||
|
# Start with defaults
|
||||||
|
retriever_name = "classic"
|
||||||
|
chunks = 2
|
||||||
|
|
||||||
|
# Layer agent-level config (if present)
|
||||||
|
if self._agent_data:
|
||||||
|
if self._agent_data.get("retriever"):
|
||||||
|
retriever_name = self._agent_data["retriever"]
|
||||||
|
if self._agent_data.get("chunks") is not None:
|
||||||
|
try:
|
||||||
|
chunks = int(self._agent_data["chunks"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid agent chunks value: {self._agent_data['chunks']}, "
|
||||||
|
"using default value 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit request values win over agent config
|
||||||
|
if "retriever" in self.data:
|
||||||
|
retriever_name = self.data["retriever"]
|
||||||
|
if "chunks" in self.data:
|
||||||
|
try:
|
||||||
|
chunks = int(self.data["chunks"])
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid request chunks value: {self.data['chunks']}, "
|
||||||
|
"using default value 2"
|
||||||
|
)
|
||||||
|
|
||||||
self.retriever_config = {
|
self.retriever_config = {
|
||||||
"retriever_name": self.data.get("retriever", "classic"),
|
"retriever_name": retriever_name,
|
||||||
"chunks": int(self.data.get("chunks", 2)),
|
"chunks": chunks,
|
||||||
"doc_token_limit": doc_token_limit,
|
"doc_token_limit": doc_token_limit,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# isNoneDoc without an API key forces no retrieval
|
||||||
api_key = self.data.get("api_key") or self.agent_key
|
api_key = self.data.get("api_key") or self.agent_key
|
||||||
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
|
||||||
self.retriever_config["chunks"] = 0
|
self.retriever_config["chunks"] = 0
|
||||||
@@ -528,6 +571,9 @@ class StreamProcessor:
|
|||||||
if self.data.get("isNoneDoc", False) and not self.agent_id:
|
if self.data.get("isNoneDoc", False) and not self.agent_id:
|
||||||
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
logger.info("Pre-fetch skipped: isNoneDoc=True")
|
||||||
return None, None
|
return None, None
|
||||||
|
if not self._has_active_docs():
|
||||||
|
logger.info("Pre-fetch skipped: no active docs configured")
|
||||||
|
return None, None
|
||||||
try:
|
try:
|
||||||
retriever = self.create_retriever()
|
retriever = self.create_retriever()
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -771,6 +817,121 @@ class StreamProcessor:
|
|||||||
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
logger.warning(f"Failed to fetch memory tool data: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def resume_from_tool_actions(
|
||||||
|
self,
|
||||||
|
tool_actions: list,
|
||||||
|
conversation_id: str,
|
||||||
|
):
|
||||||
|
"""Resume a paused agent from saved continuation state.
|
||||||
|
|
||||||
|
Loads the pending state from MongoDB, recreates the agent with
|
||||||
|
the saved configuration, and returns an agent ready to call
|
||||||
|
``gen_continuation()``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_actions: Client-provided actions (approvals / results).
|
||||||
|
conversation_id: The conversation being resumed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (agent, messages, tools_dict, pending_tool_calls, tool_actions).
|
||||||
|
"""
|
||||||
|
from application.api.answer.services.continuation_service import (
|
||||||
|
ContinuationService,
|
||||||
|
)
|
||||||
|
from application.agents.agent_creator import AgentCreator
|
||||||
|
from application.agents.tool_executor import ToolExecutor
|
||||||
|
from application.llm.handlers.handler_creator import LLMHandlerCreator
|
||||||
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
|
||||||
|
cont_service = ContinuationService()
|
||||||
|
state = cont_service.load_state(conversation_id, self.initial_user_id)
|
||||||
|
if not state:
|
||||||
|
raise ValueError("No pending tool state found for this conversation")
|
||||||
|
|
||||||
|
messages = state["messages"]
|
||||||
|
pending_tool_calls = state["pending_tool_calls"]
|
||||||
|
tools_dict = state["tools_dict"]
|
||||||
|
tool_schemas = state.get("tool_schemas", [])
|
||||||
|
agent_config = state["agent_config"]
|
||||||
|
|
||||||
|
model_id = agent_config.get("model_id")
|
||||||
|
llm_name = agent_config.get("llm_name", settings.LLM_PROVIDER)
|
||||||
|
api_key = agent_config.get("api_key")
|
||||||
|
user_api_key = agent_config.get("user_api_key")
|
||||||
|
agent_id = agent_config.get("agent_id")
|
||||||
|
prompt = agent_config.get("prompt", "")
|
||||||
|
json_schema = agent_config.get("json_schema")
|
||||||
|
retriever_config = agent_config.get("retriever_config")
|
||||||
|
|
||||||
|
# Recreate dependencies
|
||||||
|
system_api_key = api_key or get_api_key_for_provider(llm_name)
|
||||||
|
llm = LLMCreator.create_llm(
|
||||||
|
llm_name,
|
||||||
|
api_key=system_api_key,
|
||||||
|
user_api_key=user_api_key,
|
||||||
|
decoded_token=self.decoded_token,
|
||||||
|
model_id=model_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
llm_handler = LLMHandlerCreator.create_handler(llm_name or "default")
|
||||||
|
tool_executor = ToolExecutor(
|
||||||
|
user_api_key=user_api_key,
|
||||||
|
user=self.initial_user_id,
|
||||||
|
decoded_token=self.decoded_token,
|
||||||
|
)
|
||||||
|
tool_executor.conversation_id = conversation_id
|
||||||
|
# Restore client tools so they stay available for subsequent LLM calls
|
||||||
|
saved_client_tools = state.get("client_tools")
|
||||||
|
if saved_client_tools:
|
||||||
|
tool_executor.client_tools = saved_client_tools
|
||||||
|
# Re-merge into tools_dict (they may have been stripped during serialization)
|
||||||
|
tool_executor.merge_client_tools(tools_dict, saved_client_tools)
|
||||||
|
|
||||||
|
agent_type = agent_config.get("agent_type", "ClassicAgent")
|
||||||
|
# Map class names back to agent creator keys
|
||||||
|
type_map = {
|
||||||
|
"ClassicAgent": "classic",
|
||||||
|
"AgenticAgent": "agentic",
|
||||||
|
"ResearchAgent": "research",
|
||||||
|
"WorkflowAgent": "workflow",
|
||||||
|
}
|
||||||
|
agent_key = type_map.get(agent_type, "classic")
|
||||||
|
|
||||||
|
agent_kwargs = {
|
||||||
|
"endpoint": "stream",
|
||||||
|
"llm_name": llm_name,
|
||||||
|
"model_id": model_id,
|
||||||
|
"api_key": system_api_key,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"user_api_key": user_api_key,
|
||||||
|
"prompt": prompt,
|
||||||
|
"chat_history": [],
|
||||||
|
"decoded_token": self.decoded_token,
|
||||||
|
"json_schema": json_schema,
|
||||||
|
"llm": llm,
|
||||||
|
"llm_handler": llm_handler,
|
||||||
|
"tool_executor": tool_executor,
|
||||||
|
}
|
||||||
|
|
||||||
|
if agent_key in ("agentic", "research") and retriever_config:
|
||||||
|
agent_kwargs["retriever_config"] = retriever_config
|
||||||
|
|
||||||
|
agent = AgentCreator.create_agent(agent_key, **agent_kwargs)
|
||||||
|
agent.conversation_id = conversation_id
|
||||||
|
agent.initial_user_id = self.initial_user_id
|
||||||
|
agent.tools = tool_schemas
|
||||||
|
|
||||||
|
# Store config for the route layer
|
||||||
|
self.model_id = model_id
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.agent_config["user_api_key"] = user_api_key
|
||||||
|
self.conversation_id = conversation_id
|
||||||
|
|
||||||
|
# Delete state so it can't be replayed
|
||||||
|
cont_service.delete_state(conversation_id, self.initial_user_id)
|
||||||
|
|
||||||
|
return agent, messages, tools_dict, pending_tool_calls, tool_actions
|
||||||
|
|
||||||
def create_agent(
|
def create_agent(
|
||||||
self,
|
self,
|
||||||
docs_together: Optional[str] = None,
|
docs_together: Optional[str] = None,
|
||||||
@@ -795,15 +956,23 @@ class StreamProcessor:
|
|||||||
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
|
raw_prompt = get_prompt(prompt_id, self.prompts_collection)
|
||||||
self._prompt_content = raw_prompt
|
self._prompt_content = raw_prompt
|
||||||
|
|
||||||
rendered_prompt = self.prompt_renderer.render_prompt(
|
# Allow API callers to override the system prompt when the agent
|
||||||
prompt_content=raw_prompt,
|
# has opted in via allow_system_prompt_override.
|
||||||
user_id=self.initial_user_id,
|
if (
|
||||||
request_id=self.data.get("request_id"),
|
self.agent_config.get("allow_system_prompt_override", False)
|
||||||
passthrough_data=self.data.get("passthrough"),
|
and self.data.get("system_prompt_override")
|
||||||
docs=docs,
|
):
|
||||||
docs_together=docs_together,
|
rendered_prompt = self.data["system_prompt_override"]
|
||||||
tools_data=tools_data,
|
else:
|
||||||
)
|
rendered_prompt = self.prompt_renderer.render_prompt(
|
||||||
|
prompt_content=raw_prompt,
|
||||||
|
user_id=self.initial_user_id,
|
||||||
|
request_id=self.data.get("request_id"),
|
||||||
|
passthrough_data=self.data.get("passthrough"),
|
||||||
|
docs=docs,
|
||||||
|
docs_together=docs_together,
|
||||||
|
tools_data=tools_data,
|
||||||
|
)
|
||||||
|
|
||||||
provider = (
|
provider = (
|
||||||
get_provider_from_model_id(self.model_id)
|
get_provider_from_model_id(self.model_id)
|
||||||
@@ -841,6 +1010,10 @@ class StreamProcessor:
|
|||||||
decoded_token=self.decoded_token,
|
decoded_token=self.decoded_token,
|
||||||
)
|
)
|
||||||
tool_executor.conversation_id = self.conversation_id
|
tool_executor.conversation_id = self.conversation_id
|
||||||
|
# Pass client-side tools so they get merged in get_tools()
|
||||||
|
client_tools = self.data.get("client_tools")
|
||||||
|
if client_tools:
|
||||||
|
tool_executor.client_tools = client_tools
|
||||||
|
|
||||||
# Base agent kwargs
|
# Base agent kwargs
|
||||||
agent_kwargs = {
|
agent_kwargs = {
|
||||||
|
|||||||
@@ -26,12 +26,20 @@ internal = Blueprint("internal", __name__)
|
|||||||
|
|
||||||
@internal.before_request
|
@internal.before_request
|
||||||
def verify_internal_key():
|
def verify_internal_key():
|
||||||
"""Verify INTERNAL_KEY for all internal endpoint requests."""
|
"""Verify INTERNAL_KEY for all internal endpoint requests.
|
||||||
if settings.INTERNAL_KEY:
|
|
||||||
internal_key = request.headers.get("X-Internal-Key")
|
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
|
||||||
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
"""
|
||||||
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
if not settings.INTERNAL_KEY:
|
||||||
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
logger.warning(
|
||||||
|
f"Internal API request rejected from {request.remote_addr}: "
|
||||||
|
"INTERNAL_KEY is not configured"
|
||||||
|
)
|
||||||
|
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
|
||||||
|
internal_key = request.headers.get("X-Internal-Key")
|
||||||
|
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
||||||
|
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
||||||
|
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
||||||
|
|
||||||
|
|
||||||
@internal.route("/api/download", methods=["get"])
|
@internal.route("/api/download", methods=["get"])
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ from application.api.user.base import (
|
|||||||
agent_folders_collection,
|
agent_folders_collection,
|
||||||
agents_collection,
|
agents_collection,
|
||||||
)
|
)
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
|
||||||
|
|
||||||
agents_folders_ns = Namespace(
|
agents_folders_ns = Namespace(
|
||||||
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||||
@@ -83,6 +85,10 @@ class AgentFolders(Resource):
|
|||||||
"updated_at": now,
|
"updated_at": now,
|
||||||
}
|
}
|
||||||
result = agent_folders_collection.insert_one(folder)
|
result = agent_folders_collection.insert_one(folder)
|
||||||
|
dual_write(
|
||||||
|
AgentFoldersRepository,
|
||||||
|
lambda repo, u=user, n=data["name"]: repo.create(u, n),
|
||||||
|
)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
|
jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
|
||||||
201,
|
201,
|
||||||
@@ -167,6 +173,10 @@ class AgentFolder(Resource):
|
|||||||
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
|
{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
|
||||||
)
|
)
|
||||||
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
|
result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
|
||||||
|
dual_write(
|
||||||
|
AgentFoldersRepository,
|
||||||
|
lambda repo, fid=folder_id, u=user: repo.delete(fid, u),
|
||||||
|
)
|
||||||
if result.deleted_count == 0:
|
if result.deleted_count == 0:
|
||||||
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ from application.api.user.base import (
|
|||||||
workflow_nodes_collection,
|
workflow_nodes_collection,
|
||||||
workflows_collection,
|
workflows_collection,
|
||||||
)
|
)
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.repositories.users import UsersRepository
|
||||||
from application.core.json_schema_utils import (
|
from application.core.json_schema_utils import (
|
||||||
JsonSchemaValidationError,
|
JsonSchemaValidationError,
|
||||||
normalize_json_schema_payload,
|
normalize_json_schema_payload,
|
||||||
@@ -73,6 +76,7 @@ AGENT_TYPE_SCHEMAS = {
|
|||||||
"token_limit",
|
"token_limit",
|
||||||
"limited_request_mode",
|
"limited_request_mode",
|
||||||
"request_limit",
|
"request_limit",
|
||||||
|
"allow_system_prompt_override",
|
||||||
"createdAt",
|
"createdAt",
|
||||||
"updatedAt",
|
"updatedAt",
|
||||||
"lastUsedAt",
|
"lastUsedAt",
|
||||||
@@ -96,6 +100,7 @@ AGENT_TYPE_SCHEMAS = {
|
|||||||
"token_limit",
|
"token_limit",
|
||||||
"limited_request_mode",
|
"limited_request_mode",
|
||||||
"request_limit",
|
"request_limit",
|
||||||
|
"allow_system_prompt_override",
|
||||||
"createdAt",
|
"createdAt",
|
||||||
"updatedAt",
|
"updatedAt",
|
||||||
"lastUsedAt",
|
"lastUsedAt",
|
||||||
@@ -109,6 +114,35 @@ AGENT_TYPE_SCHEMAS["research"] = AGENT_TYPE_SCHEMAS["classic"]
|
|||||||
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
|
AGENT_TYPE_SCHEMAS["openai"] = AGENT_TYPE_SCHEMAS["classic"]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pg_agent_fields(fields: dict) -> dict:
|
||||||
|
"""Translate Mongo-shaped agent fields into the Postgres mirror subset."""
|
||||||
|
allowed = {
|
||||||
|
"name",
|
||||||
|
"description",
|
||||||
|
"agent_type",
|
||||||
|
"status",
|
||||||
|
"key",
|
||||||
|
"chunks",
|
||||||
|
"retriever",
|
||||||
|
"tools",
|
||||||
|
"json_schema",
|
||||||
|
"models",
|
||||||
|
"default_model_id",
|
||||||
|
"limited_token_mode",
|
||||||
|
"token_limit",
|
||||||
|
"limited_request_mode",
|
||||||
|
"request_limit",
|
||||||
|
"incoming_webhook_token",
|
||||||
|
"lastUsedAt",
|
||||||
|
}
|
||||||
|
translated: dict = {}
|
||||||
|
for key, value in fields.items():
|
||||||
|
if key not in allowed:
|
||||||
|
continue
|
||||||
|
translated["last_used_at" if key == "lastUsedAt" else key] = value
|
||||||
|
return translated
|
||||||
|
|
||||||
|
|
||||||
def normalize_workflow_reference(workflow_value):
|
def normalize_workflow_reference(workflow_value):
|
||||||
"""Normalize workflow references from form/json payloads."""
|
"""Normalize workflow references from form/json payloads."""
|
||||||
if workflow_value is None:
|
if workflow_value is None:
|
||||||
@@ -220,6 +254,12 @@ def build_agent_document(
|
|||||||
base_doc["request_limit"] = int(
|
base_doc["request_limit"] = int(
|
||||||
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
data.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
||||||
)
|
)
|
||||||
|
if "allow_system_prompt_override" in allowed_fields:
|
||||||
|
base_doc["allow_system_prompt_override"] = (
|
||||||
|
data.get("allow_system_prompt_override") == "True"
|
||||||
|
if isinstance(data.get("allow_system_prompt_override"), str)
|
||||||
|
else bool(data.get("allow_system_prompt_override", False))
|
||||||
|
)
|
||||||
return {k: v for k, v in base_doc.items() if k in allowed_fields}
|
return {k: v for k, v in base_doc.items() if k in allowed_fields}
|
||||||
|
|
||||||
|
|
||||||
@@ -292,6 +332,9 @@ class GetAgent(Resource):
|
|||||||
"default_model_id": agent.get("default_model_id", ""),
|
"default_model_id": agent.get("default_model_id", ""),
|
||||||
"folder_id": agent.get("folder_id"),
|
"folder_id": agent.get("folder_id"),
|
||||||
"workflow": agent.get("workflow"),
|
"workflow": agent.get("workflow"),
|
||||||
|
"allow_system_prompt_override": agent.get(
|
||||||
|
"allow_system_prompt_override", False
|
||||||
|
),
|
||||||
}
|
}
|
||||||
return make_response(jsonify(data), 200)
|
return make_response(jsonify(data), 200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -373,6 +416,9 @@ class GetAgents(Resource):
|
|||||||
"default_model_id": agent.get("default_model_id", ""),
|
"default_model_id": agent.get("default_model_id", ""),
|
||||||
"folder_id": agent.get("folder_id"),
|
"folder_id": agent.get("folder_id"),
|
||||||
"workflow": agent.get("workflow"),
|
"workflow": agent.get("workflow"),
|
||||||
|
"allow_system_prompt_override": agent.get(
|
||||||
|
"allow_system_prompt_override", False
|
||||||
|
),
|
||||||
}
|
}
|
||||||
for agent in agents
|
for agent in agents
|
||||||
if "source" in agent
|
if "source" in agent
|
||||||
@@ -450,6 +496,10 @@ class CreateAgent(Resource):
|
|||||||
"folder_id": fields.String(
|
"folder_id": fields.String(
|
||||||
required=False, description="Folder ID to organize the agent"
|
required=False, description="Folder ID to organize the agent"
|
||||||
),
|
),
|
||||||
|
"allow_system_prompt_override": fields.Boolean(
|
||||||
|
required=False,
|
||||||
|
description="Allow API callers to override the system prompt via the v1 endpoint",
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -491,9 +541,9 @@ class CreateAgent(Resource):
|
|||||||
data["json_schema"] = normalize_json_schema_payload(
|
data["json_schema"] = normalize_json_schema_payload(
|
||||||
data.get("json_schema")
|
data.get("json_schema")
|
||||||
)
|
)
|
||||||
except JsonSchemaValidationError as exc:
|
except JsonSchemaValidationError:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
jsonify({"success": False, "message": "Invalid JSON schema"}),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
if data.get("status") not in ["draft", "published"]:
|
if data.get("status") not in ["draft", "published"]:
|
||||||
@@ -603,6 +653,18 @@ class CreateAgent(Resource):
|
|||||||
new_agent["retriever"] = "classic"
|
new_agent["retriever"] = "classic"
|
||||||
resp = agents_collection.insert_one(new_agent)
|
resp = agents_collection.insert_one(new_agent)
|
||||||
new_id = str(resp.inserted_id)
|
new_id = str(resp.inserted_id)
|
||||||
|
dual_write(
|
||||||
|
AgentsRepository,
|
||||||
|
lambda repo, u=user, a=new_agent, mid=new_id: repo.create(
|
||||||
|
u, a.get("name", ""), a.get("status", "draft"),
|
||||||
|
key=a.get("key"), description=a.get("description"),
|
||||||
|
retriever=a.get("retriever"), chunks=a.get("chunks"),
|
||||||
|
tools=a.get("tools"), models=a.get("models"),
|
||||||
|
shared=a.get("shared", False),
|
||||||
|
incoming_webhook_token=a.get("incoming_webhook_token"),
|
||||||
|
legacy_mongo_id=mid,
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
|
current_app.logger.error(f"Error creating agent: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
@@ -674,6 +736,10 @@ class UpdateAgent(Resource):
|
|||||||
"folder_id": fields.String(
|
"folder_id": fields.String(
|
||||||
required=False, description="Folder ID to organize the agent"
|
required=False, description="Folder ID to organize the agent"
|
||||||
),
|
),
|
||||||
|
"allow_system_prompt_override": fields.Boolean(
|
||||||
|
required=False,
|
||||||
|
description="Allow API callers to override the system prompt via the v1 endpoint",
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -765,6 +831,7 @@ class UpdateAgent(Resource):
|
|||||||
"default_model_id",
|
"default_model_id",
|
||||||
"folder_id",
|
"folder_id",
|
||||||
"workflow",
|
"workflow",
|
||||||
|
"allow_system_prompt_override",
|
||||||
]
|
]
|
||||||
|
|
||||||
for field in allowed_fields:
|
for field in allowed_fields:
|
||||||
@@ -872,9 +939,9 @@ class UpdateAgent(Resource):
|
|||||||
update_fields[field] = normalize_json_schema_payload(
|
update_fields[field] = normalize_json_schema_payload(
|
||||||
json_schema
|
json_schema
|
||||||
)
|
)
|
||||||
except JsonSchemaValidationError as exc:
|
except JsonSchemaValidationError:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": f"JSON schema {exc}"}),
|
jsonify({"success": False, "message": "Invalid JSON schema"}),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -983,6 +1050,13 @@ class UpdateAgent(Resource):
|
|||||||
if workflow_error:
|
if workflow_error:
|
||||||
return workflow_error
|
return workflow_error
|
||||||
update_fields[field] = workflow_id
|
update_fields[field] = workflow_id
|
||||||
|
elif field == "allow_system_prompt_override":
|
||||||
|
raw_value = data.get("allow_system_prompt_override", False)
|
||||||
|
update_fields[field] = (
|
||||||
|
raw_value == "True"
|
||||||
|
if isinstance(raw_value, str)
|
||||||
|
else bool(raw_value)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
value = data[field]
|
value = data[field]
|
||||||
if field in ["name", "description", "prompt_id", "agent_type"]:
|
if field in ["name", "description", "prompt_id", "agent_type"]:
|
||||||
@@ -1126,6 +1200,14 @@ class UpdateAgent(Resource):
|
|||||||
jsonify({"success": False, "message": "Database error during update"}),
|
jsonify({"success": False, "message": "Database error during update"}),
|
||||||
500,
|
500,
|
||||||
)
|
)
|
||||||
|
pg_update_fields = _build_pg_agent_fields(update_fields)
|
||||||
|
if pg_update_fields:
|
||||||
|
dual_write(
|
||||||
|
AgentsRepository,
|
||||||
|
lambda repo, aid=agent_id, u=user, fields=pg_update_fields: repo.update_by_legacy_id(
|
||||||
|
aid, u, fields,
|
||||||
|
),
|
||||||
|
)
|
||||||
response_data = {
|
response_data = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"id": agent_id,
|
"id": agent_id,
|
||||||
@@ -1153,6 +1235,10 @@ class DeleteAgent(Resource):
|
|||||||
deleted_agent = agents_collection.find_one_and_delete(
|
deleted_agent = agents_collection.find_one_and_delete(
|
||||||
{"_id": ObjectId(agent_id), "user": user}
|
{"_id": ObjectId(agent_id), "user": user}
|
||||||
)
|
)
|
||||||
|
dual_write(
|
||||||
|
AgentsRepository,
|
||||||
|
lambda repo, aid=agent_id, u=user: repo.delete_by_legacy_id(aid, u),
|
||||||
|
)
|
||||||
if not deleted_agent:
|
if not deleted_agent:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||||
@@ -1220,6 +1306,9 @@ class PinnedAgents(Resource):
|
|||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
|
{"$pullAll": {"agent_preferences.pinned": stale_ids}},
|
||||||
)
|
)
|
||||||
|
dual_write(UsersRepository,
|
||||||
|
lambda repo, uid=user_id, ids=stale_ids: repo.remove_pinned_bulk(uid, ids)
|
||||||
|
)
|
||||||
list_pinned_agents = [
|
list_pinned_agents = [
|
||||||
{
|
{
|
||||||
"id": str(agent["_id"]),
|
"id": str(agent["_id"]),
|
||||||
@@ -1351,12 +1440,18 @@ class PinAgent(Resource):
|
|||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
{"$pull": {"agent_preferences.pinned": agent_id}},
|
{"$pull": {"agent_preferences.pinned": agent_id}},
|
||||||
)
|
)
|
||||||
|
dual_write(UsersRepository,
|
||||||
|
lambda repo, uid=user_id, aid=agent_id: repo.remove_pinned(uid, aid)
|
||||||
|
)
|
||||||
action = "unpinned"
|
action = "unpinned"
|
||||||
else:
|
else:
|
||||||
users_collection.update_one(
|
users_collection.update_one(
|
||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
{"$addToSet": {"agent_preferences.pinned": agent_id}},
|
{"$addToSet": {"agent_preferences.pinned": agent_id}},
|
||||||
)
|
)
|
||||||
|
dual_write(UsersRepository,
|
||||||
|
lambda repo, uid=user_id, aid=agent_id: repo.add_pinned(uid, aid)
|
||||||
|
)
|
||||||
action = "pinned"
|
action = "pinned"
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
|
current_app.logger.error(f"Error pinning/unpinning agent: {err}")
|
||||||
@@ -1402,6 +1497,9 @@ class RemoveSharedAgent(Resource):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
dual_write(UsersRepository,
|
||||||
|
lambda repo, uid=user_id, aid=agent_id: repo.remove_agent_from_all(uid, aid)
|
||||||
|
)
|
||||||
|
|
||||||
return make_response(jsonify({"success": True, "action": "removed"}), 200)
|
return make_response(jsonify({"success": True, "action": "removed"}), 200)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ from application.api.user.base import (
|
|||||||
user_tools_collection,
|
user_tools_collection,
|
||||||
users_collection,
|
users_collection,
|
||||||
)
|
)
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.users import UsersRepository
|
||||||
from application.utils import generate_image_url
|
from application.utils import generate_image_url
|
||||||
|
|
||||||
agents_sharing_ns = Namespace(
|
agents_sharing_ns = Namespace(
|
||||||
@@ -105,6 +107,9 @@ class SharedAgent(Resource):
|
|||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
||||||
)
|
)
|
||||||
|
dual_write(UsersRepository,
|
||||||
|
lambda repo, uid=user_id, aid=agent_id: repo.add_shared(uid, aid)
|
||||||
|
)
|
||||||
return make_response(jsonify(data), 200)
|
return make_response(jsonify(data), 200)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||||
@@ -139,6 +144,9 @@ class SharedAgents(Resource):
|
|||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
||||||
)
|
)
|
||||||
|
dual_write(UsersRepository,
|
||||||
|
lambda repo, uid=user_id, ids=stale_ids: repo.remove_shared_bulk(uid, ids)
|
||||||
|
)
|
||||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
||||||
|
|
||||||
list_shared_agents = [
|
list_shared_agents = [
|
||||||
|
|||||||
@@ -612,6 +612,10 @@ class LiveSpeechToTextFinish(Resource):
|
|||||||
class ServeImage(Resource):
|
class ServeImage(Resource):
|
||||||
@api.doc(description="Serve an image from storage")
|
@api.doc(description="Serve an image from storage")
|
||||||
def get(self, image_path):
|
def get(self, image_path):
|
||||||
|
if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
from application.api.user.base import storage
|
from application.api.user.base import storage
|
||||||
|
|
||||||
@@ -629,6 +633,10 @@ class ServeImage(Resource):
|
|||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Image not found"}), 404
|
jsonify({"success": False, "message": "Image not found"}), 404
|
||||||
)
|
)
|
||||||
|
except ValueError:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid image path"}), 400
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error serving image: {e}")
|
current_app.logger.error(f"Error serving image: {e}")
|
||||||
return make_response(
|
return make_response(
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ from werkzeug.utils import secure_filename
|
|||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.users import UsersRepository
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
@@ -132,6 +134,9 @@ def ensure_user_doc(user_id):
|
|||||||
if updates:
|
if updates:
|
||||||
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
users_collection.update_one({"user_id": user_id}, {"$set": updates})
|
||||||
user_doc = users_collection.find_one({"user_id": user_id})
|
user_doc = users_collection.find_one({"user_id": user_id})
|
||||||
|
|
||||||
|
dual_write(UsersRepository, lambda repo: repo.upsert(user_id))
|
||||||
|
|
||||||
return user_doc
|
return user_doc
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource
|
|||||||
|
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import attachments_collection, conversations_collection
|
from application.api.user.base import attachments_collection, conversations_collection
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||||
from application.utils import check_required_fields
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
conversations_ns = Namespace(
|
conversations_ns = Namespace(
|
||||||
@@ -30,15 +32,23 @@ class DeleteConversation(Resource):
|
|||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "ID is required"}), 400
|
jsonify({"success": False, "message": "ID is required"}), 400
|
||||||
)
|
)
|
||||||
|
user_id = decoded_token["sub"]
|
||||||
try:
|
try:
|
||||||
conversations_collection.delete_one(
|
conversations_collection.delete_one(
|
||||||
{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
|
{"_id": ObjectId(conversation_id), "user": user_id}
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error deleting conversation: {err}", exc_info=True
|
f"Error deleting conversation: {err}", exc_info=True
|
||||||
)
|
)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
def _pg_delete(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
|
if conv is not None:
|
||||||
|
repo.delete(conv["id"], user_id)
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_delete)
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
@@ -59,6 +69,11 @@ class DeleteAllConversations(Resource):
|
|||||||
f"Error deleting all conversations: {err}", exc_info=True
|
f"Error deleting all conversations: {err}", exc_info=True
|
||||||
)
|
)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
dual_write(
|
||||||
|
ConversationsRepository,
|
||||||
|
lambda r, uid=user_id: r.delete_all_for_user(uid),
|
||||||
|
)
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
@@ -190,9 +205,10 @@ class UpdateConversationName(Resource):
|
|||||||
missing_fields = check_required_fields(data, required_fields)
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
return missing_fields
|
return missing_fields
|
||||||
|
user_id = decoded_token.get("sub")
|
||||||
try:
|
try:
|
||||||
conversations_collection.update_one(
|
conversations_collection.update_one(
|
||||||
{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
|
{"_id": ObjectId(data["id"]), "user": user_id},
|
||||||
{"$set": {"name": data["name"]}},
|
{"$set": {"name": data["name"]}},
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@@ -200,6 +216,13 @@ class UpdateConversationName(Resource):
|
|||||||
f"Error updating conversation name: {err}", exc_info=True
|
f"Error updating conversation name: {err}", exc_info=True
|
||||||
)
|
)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
def _pg_rename(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.get_by_legacy_id(data["id"])
|
||||||
|
if conv is not None:
|
||||||
|
repo.rename(conv["id"], user_id, data["name"])
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_rename)
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
@@ -277,4 +300,21 @@ class SubmitFeedback(Resource):
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
# Dual-write to Postgres: mirror the per-message feedback set/unset.
|
||||||
|
feedback_value = data["feedback"]
|
||||||
|
question_index = int(data["question_index"])
|
||||||
|
feedback_payload = (
|
||||||
|
None if feedback_value is None
|
||||||
|
else {"text": feedback_value, "timestamp": datetime.datetime.now(
|
||||||
|
datetime.timezone.utc
|
||||||
|
).isoformat()}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _pg_feedback(repo: ConversationsRepository) -> None:
|
||||||
|
conv = repo.get_by_legacy_id(data["conversation_id"])
|
||||||
|
if conv is not None:
|
||||||
|
repo.set_feedback(conv["id"], question_index, feedback_payload)
|
||||||
|
|
||||||
|
dual_write(ConversationsRepository, _pg_feedback)
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from flask_restx import fields, Namespace, Resource
|
|||||||
|
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import current_dir, prompts_collection
|
from application.api.user.base import current_dir, prompts_collection
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.prompts import PromptsRepository
|
||||||
from application.utils import check_required_fields
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
prompts_ns = Namespace(
|
prompts_ns = Namespace(
|
||||||
@@ -49,6 +51,12 @@ class CreatePrompt(Resource):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
new_id = str(resp.inserted_id)
|
new_id = str(resp.inserted_id)
|
||||||
|
dual_write(
|
||||||
|
PromptsRepository,
|
||||||
|
lambda repo, u=user, n=data["name"], c=data["content"], mid=new_id: repo.create(
|
||||||
|
u, n, c, legacy_mongo_id=mid,
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
@@ -149,6 +157,10 @@ class DeletePrompt(Resource):
|
|||||||
return missing_fields
|
return missing_fields
|
||||||
try:
|
try:
|
||||||
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
|
||||||
|
dual_write(
|
||||||
|
PromptsRepository,
|
||||||
|
lambda repo, pid=data["id"], u=user: repo.delete_by_legacy_id(pid, u),
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
@@ -185,6 +197,12 @@ class UpdatePrompt(Resource):
|
|||||||
{"_id": ObjectId(data["id"]), "user": user},
|
{"_id": ObjectId(data["id"]), "user": user},
|
||||||
{"$set": {"name": data["name"], "content": data["content"]}},
|
{"$set": {"name": data["name"], "content": data["content"]}},
|
||||||
)
|
)
|
||||||
|
dual_write(
|
||||||
|
PromptsRepository,
|
||||||
|
lambda repo, pid=data["id"], u=user, n=data["name"], c=data["content"]: repo.update_by_legacy_id(
|
||||||
|
pid, u, n, c,
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|||||||
@@ -15,8 +15,71 @@ from application.api.user.base import (
|
|||||||
conversations_collection,
|
conversations_collection,
|
||||||
shared_conversations_collections,
|
shared_conversations_collections,
|
||||||
)
|
)
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||||
|
from application.storage.db.repositories.shared_conversations import (
|
||||||
|
SharedConversationsRepository,
|
||||||
|
)
|
||||||
from application.utils import check_required_fields
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
|
|
||||||
|
def _dual_write_share(
|
||||||
|
mongo_conv_id: str,
|
||||||
|
share_uuid: str,
|
||||||
|
user: str,
|
||||||
|
*,
|
||||||
|
is_promptable: bool,
|
||||||
|
first_n_queries: int,
|
||||||
|
api_key: str | None,
|
||||||
|
prompt_id: str | None = None,
|
||||||
|
chunks: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Mirror a Mongo share-record insert into Postgres.
|
||||||
|
|
||||||
|
Preserves the Mongo-generated UUID so public ``/shared/{uuid}`` URLs
|
||||||
|
resolve from both stores during cutover.
|
||||||
|
"""
|
||||||
|
def _write(repo: SharedConversationsRepository) -> None:
|
||||||
|
conv = ConversationsRepository(repo._conn).get_by_legacy_id(
|
||||||
|
mongo_conv_id, user_id=user,
|
||||||
|
)
|
||||||
|
if conv is None:
|
||||||
|
return
|
||||||
|
# prompt_id / chunks are only meaningful for promptable shares;
|
||||||
|
# prompt_id is often the string "default" or an ObjectId that
|
||||||
|
# hasn't been migrated — pass as-is and let the repo drop
|
||||||
|
# non-UUID values. Scope the prompt lookup by user_id so an
|
||||||
|
# authenticated caller can't link another user's prompt into
|
||||||
|
# their share record.
|
||||||
|
resolved_prompt_id = None
|
||||||
|
if prompt_id and len(str(prompt_id)) == 24:
|
||||||
|
from sqlalchemy import text as _text
|
||||||
|
row = repo._conn.execute(
|
||||||
|
_text(
|
||||||
|
"SELECT id FROM prompts "
|
||||||
|
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"legacy_id": str(prompt_id), "user_id": user},
|
||||||
|
).fetchone()
|
||||||
|
if row:
|
||||||
|
resolved_prompt_id = str(row[0])
|
||||||
|
# get_or_create is race-free on the PG side thanks to the
|
||||||
|
# composite partial unique index on the dedup tuple
|
||||||
|
# (migration 0008). It converges concurrent share requests to
|
||||||
|
# a single row.
|
||||||
|
repo.get_or_create(
|
||||||
|
conv["id"],
|
||||||
|
user,
|
||||||
|
is_promptable=is_promptable,
|
||||||
|
first_n_queries=first_n_queries,
|
||||||
|
api_key=api_key,
|
||||||
|
prompt_id=resolved_prompt_id,
|
||||||
|
chunks=chunks,
|
||||||
|
share_uuid=share_uuid,
|
||||||
|
)
|
||||||
|
|
||||||
|
dual_write(SharedConversationsRepository, _write)
|
||||||
|
|
||||||
sharing_ns = Namespace(
|
sharing_ns = Namespace(
|
||||||
"sharing", description="Conversation sharing operations", path="/api"
|
"sharing", description="Conversation sharing operations", path="/api"
|
||||||
)
|
)
|
||||||
@@ -57,7 +120,7 @@ class ShareConversation(Resource):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
conversation = conversations_collection.find_one(
|
conversation = conversations_collection.find_one(
|
||||||
{"_id": ObjectId(conversation_id)}
|
{"_id": ObjectId(conversation_id), "user": user}
|
||||||
)
|
)
|
||||||
if conversation is None:
|
if conversation is None:
|
||||||
return make_response(
|
return make_response(
|
||||||
@@ -124,6 +187,16 @@ class ShareConversation(Resource):
|
|||||||
"api_key": api_uuid,
|
"api_key": api_uuid,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
_dual_write_share(
|
||||||
|
conversation_id,
|
||||||
|
str(explicit_binary.as_uuid()),
|
||||||
|
user,
|
||||||
|
is_promptable=is_promptable,
|
||||||
|
first_n_queries=current_n_queries,
|
||||||
|
api_key=api_uuid,
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
chunks=int(chunks) if chunks else None,
|
||||||
|
)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
@@ -155,6 +228,16 @@ class ShareConversation(Resource):
|
|||||||
"api_key": api_uuid,
|
"api_key": api_uuid,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
_dual_write_share(
|
||||||
|
conversation_id,
|
||||||
|
str(explicit_binary.as_uuid()),
|
||||||
|
user,
|
||||||
|
is_promptable=is_promptable,
|
||||||
|
first_n_queries=current_n_queries,
|
||||||
|
api_key=api_uuid,
|
||||||
|
prompt_id=prompt_id,
|
||||||
|
chunks=int(chunks) if chunks else None,
|
||||||
|
)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
@@ -192,6 +275,14 @@ class ShareConversation(Resource):
|
|||||||
"user": user,
|
"user": user,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
_dual_write_share(
|
||||||
|
conversation_id,
|
||||||
|
str(explicit_binary.as_uuid()),
|
||||||
|
user,
|
||||||
|
is_promptable=is_promptable,
|
||||||
|
first_n_queries=current_n_queries,
|
||||||
|
api_key=None,
|
||||||
|
)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
||||||
|
|||||||
@@ -463,6 +463,16 @@ class ManageSourceFiles(Resource):
|
|||||||
removed_files = []
|
removed_files = []
|
||||||
map_updated = False
|
map_updated = False
|
||||||
for file_path in file_paths:
|
for file_path in file_paths:
|
||||||
|
if ".." in str(file_path) or str(file_path).startswith("/"):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Invalid file path",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
full_path = f"{source_file_path}/{file_path}"
|
full_path = f"{source_file_path}/{file_path}"
|
||||||
|
|
||||||
# Remove from storage
|
# Remove from storage
|
||||||
|
|||||||
@@ -134,6 +134,12 @@ def setup_periodic_tasks(sender, **kwargs):
|
|||||||
timedelta(days=30),
|
timedelta(days=30),
|
||||||
schedule_syncs.s("monthly"),
|
schedule_syncs.s("monthly"),
|
||||||
)
|
)
|
||||||
|
# Replaces Mongo's TTL index on pending_tool_state.expires_at.
|
||||||
|
sender.add_periodic_task(
|
||||||
|
timedelta(seconds=60),
|
||||||
|
cleanup_pending_tool_state.s(),
|
||||||
|
name="cleanup-pending-tool-state",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@celery.task(bind=True)
|
@celery.task(bind=True)
|
||||||
@@ -146,3 +152,27 @@ def mcp_oauth_task(self, config, user):
|
|||||||
def mcp_oauth_status_task(self, task_id):
|
def mcp_oauth_status_task(self, task_id):
|
||||||
resp = mcp_oauth_status(self, task_id)
|
resp = mcp_oauth_status(self, task_id)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@celery.task(bind=True)
|
||||||
|
def cleanup_pending_tool_state(self):
|
||||||
|
"""Delete pending_tool_state rows past their TTL.
|
||||||
|
|
||||||
|
Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
|
||||||
|
no native TTL, so this task runs every 60 seconds to keep
|
||||||
|
``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
|
||||||
|
configured (keeps the task runnable in Mongo-only environments).
|
||||||
|
"""
|
||||||
|
from application.core.settings import settings
|
||||||
|
if not settings.POSTGRES_URI:
|
||||||
|
return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
|
||||||
|
|
||||||
|
from application.storage.db.engine import get_engine
|
||||||
|
from application.storage.db.repositories.pending_tool_state import (
|
||||||
|
PendingToolStateRepository,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_engine()
|
||||||
|
with engine.begin() as conn:
|
||||||
|
deleted = PendingToolStateRepository(conn).cleanup_expired()
|
||||||
|
return {"deleted": deleted}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from application.api.user.tools.routes import transform_actions
|
|||||||
from application.cache import get_redis_instance
|
from application.cache import get_redis_instance
|
||||||
from application.core.mongo_db import MongoDB
|
from application.core.mongo_db import MongoDB
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.core.url_validation import SSRFError, validate_url
|
||||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||||
from application.utils import check_required_fields
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
@@ -63,6 +64,21 @@ def _extract_auth_credentials(config):
|
|||||||
return auth_credentials
|
return auth_credentials
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_mcp_server_url(config: dict) -> None:
|
||||||
|
"""Validate the server_url in an MCP config to prevent SSRF.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the URL is missing or points to a blocked address.
|
||||||
|
"""
|
||||||
|
server_url = (config.get("server_url") or "").strip()
|
||||||
|
if not server_url:
|
||||||
|
raise ValueError("server_url is required")
|
||||||
|
try:
|
||||||
|
validate_url(server_url)
|
||||||
|
except SSRFError as exc:
|
||||||
|
raise ValueError(f"Invalid server URL: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
@tools_mcp_ns.route("/mcp_server/test")
|
@tools_mcp_ns.route("/mcp_server/test")
|
||||||
class TestMCPServerConfig(Resource):
|
class TestMCPServerConfig(Resource):
|
||||||
@api.expect(
|
@api.expect(
|
||||||
@@ -97,6 +113,8 @@ class TestMCPServerConfig(Resource):
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_validate_mcp_server_url(config)
|
||||||
|
|
||||||
auth_credentials = _extract_auth_credentials(config)
|
auth_credentials = _extract_auth_credentials(config)
|
||||||
test_config = config.copy()
|
test_config = config.copy()
|
||||||
test_config["auth_credentials"] = auth_credentials
|
test_config["auth_credentials"] = auth_credentials
|
||||||
@@ -105,15 +123,41 @@ class TestMCPServerConfig(Resource):
|
|||||||
result = mcp_tool.test_connection()
|
result = mcp_tool.test_connection()
|
||||||
|
|
||||||
if result.get("requires_oauth"):
|
if result.get("requires_oauth"):
|
||||||
return make_response(jsonify(result), 200)
|
safe_result = {
|
||||||
|
k: v
|
||||||
|
for k, v in result.items()
|
||||||
|
if k in ("success", "requires_oauth", "auth_url")
|
||||||
|
}
|
||||||
|
return make_response(jsonify(safe_result), 200)
|
||||||
|
|
||||||
if not result.get("success") and "message" in result:
|
if not result.get("success"):
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"MCP connection test failed: {result.get('message')}"
|
f"MCP connection test failed: {result.get('message')}"
|
||||||
)
|
)
|
||||||
result["message"] = "Connection test failed"
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Connection test failed",
|
||||||
|
"tools_count": 0,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
|
||||||
return make_response(jsonify(result), 200)
|
safe_result = {
|
||||||
|
"success": True,
|
||||||
|
"message": result.get("message", "Connection successful"),
|
||||||
|
"tools_count": result.get("tools_count", 0),
|
||||||
|
"tools": result.get("tools", []),
|
||||||
|
}
|
||||||
|
return make_response(jsonify(safe_result), 200)
|
||||||
|
except ValueError as e:
|
||||||
|
current_app.logger.warning(f"Invalid MCP server test request: {e}")
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
|
||||||
return make_response(
|
return make_response(
|
||||||
@@ -165,6 +209,8 @@ class MCPServerSave(Resource):
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_validate_mcp_server_url(config)
|
||||||
|
|
||||||
auth_credentials = _extract_auth_credentials(config)
|
auth_credentials = _extract_auth_credentials(config)
|
||||||
auth_type = config.get("auth_type", "none")
|
auth_type = config.get("auth_type", "none")
|
||||||
mcp_config = config.copy()
|
mcp_config = config.copy()
|
||||||
@@ -279,6 +325,12 @@ class MCPServerSave(Resource):
|
|||||||
"tools_count": len(transformed_actions),
|
"tools_count": len(transformed_actions),
|
||||||
}
|
}
|
||||||
return make_response(jsonify(response_data), 200)
|
return make_response(jsonify(response_data), 200)
|
||||||
|
except ValueError as e:
|
||||||
|
current_app.logger.warning(f"Invalid MCP server save request: {e}")
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "error": "Invalid MCP server configuration"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
|
||||||
return make_response(
|
return make_response(
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from application.agents.tools.spec_parser import parse_spec
|
|||||||
from application.agents.tools.tool_manager import ToolManager
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import user_tools_collection
|
from application.api.user.base import user_tools_collection
|
||||||
|
from application.core.url_validation import SSRFError, validate_url
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||||
from application.utils import check_required_fields, validate_function_name
|
from application.utils import check_required_fields, validate_function_name
|
||||||
|
|
||||||
@@ -130,6 +133,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
|
|||||||
class AvailableTools(Resource):
|
class AvailableTools(Resource):
|
||||||
@api.doc(description="Get available tools for a user")
|
@api.doc(description="Get available tools for a user")
|
||||||
def get(self):
|
def get(self):
|
||||||
|
if not request.decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
try:
|
try:
|
||||||
tools_metadata = []
|
tools_metadata = []
|
||||||
for tool_name, tool_instance in tool_manager.tools.items():
|
for tool_name, tool_instance in tool_manager.tools.items():
|
||||||
@@ -236,6 +241,16 @@ class CreateTool(Resource):
|
|||||||
if missing_fields:
|
if missing_fields:
|
||||||
return missing_fields
|
return missing_fields
|
||||||
try:
|
try:
|
||||||
|
if data["name"] == "mcp_tool":
|
||||||
|
server_url = (data.get("config", {}).get("server_url") or "").strip()
|
||||||
|
if server_url:
|
||||||
|
try:
|
||||||
|
validate_url(server_url)
|
||||||
|
except SSRFError:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
tool_instance = tool_manager.tools.get(data["name"])
|
tool_instance = tool_manager.tools.get(data["name"])
|
||||||
if not tool_instance:
|
if not tool_instance:
|
||||||
return make_response(
|
return make_response(
|
||||||
@@ -281,6 +296,13 @@ class CreateTool(Resource):
|
|||||||
}
|
}
|
||||||
resp = user_tools_collection.insert_one(new_tool)
|
resp = user_tools_collection.insert_one(new_tool)
|
||||||
new_id = str(resp.inserted_id)
|
new_id = str(resp.inserted_id)
|
||||||
|
dual_write(
|
||||||
|
UserToolsRepository,
|
||||||
|
lambda repo, u=user, t=new_tool: repo.create(
|
||||||
|
u, t["name"], config=t.get("config"),
|
||||||
|
custom_name=t.get("customName"), display_name=t.get("displayName"),
|
||||||
|
),
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
@@ -421,6 +443,16 @@ class UpdateToolConfig(Resource):
|
|||||||
return make_response(jsonify({"success": False}), 404)
|
return make_response(jsonify({"success": False}), 404)
|
||||||
|
|
||||||
tool_name = tool_doc.get("name")
|
tool_name = tool_doc.get("name")
|
||||||
|
if tool_name == "mcp_tool":
|
||||||
|
server_url = (data["config"].get("server_url") or "").strip()
|
||||||
|
if server_url:
|
||||||
|
try:
|
||||||
|
validate_url(server_url)
|
||||||
|
except SSRFError:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
tool_instance = tool_manager.tools.get(tool_name)
|
tool_instance = tool_manager.tools.get(tool_name)
|
||||||
config_requirements = (
|
config_requirements = (
|
||||||
tool_instance.get_config_requirements() if tool_instance else {}
|
tool_instance.get_config_requirements() if tool_instance else {}
|
||||||
@@ -558,6 +590,10 @@ class DeleteTool(Resource):
|
|||||||
result = user_tools_collection.delete_one(
|
result = user_tools_collection.delete_one(
|
||||||
{"_id": ObjectId(data["id"]), "user": user}
|
{"_id": ObjectId(data["id"]), "user": user}
|
||||||
)
|
)
|
||||||
|
dual_write(
|
||||||
|
UserToolsRepository,
|
||||||
|
lambda repo, tid=data["id"], u=user: repo.delete(tid, u),
|
||||||
|
)
|
||||||
if result.deleted_count == 0:
|
if result.deleted_count == 0:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Tool not found"}), 404
|
jsonify({"success": False, "message": "Tool not found"}), 404
|
||||||
|
|||||||
@@ -11,6 +11,10 @@ from application.api.user.base import (
|
|||||||
workflow_nodes_collection,
|
workflow_nodes_collection,
|
||||||
workflows_collection,
|
workflows_collection,
|
||||||
)
|
)
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||||
|
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||||
|
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||||
from application.core.json_schema_utils import (
|
from application.core.json_schema_utils import (
|
||||||
JsonSchemaValidationError,
|
JsonSchemaValidationError,
|
||||||
normalize_json_schema_payload,
|
normalize_json_schema_payload,
|
||||||
@@ -35,6 +39,174 @@ def _workflow_error_response(message: str, err: Exception):
|
|||||||
return error_response(message)
|
return error_response(message)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Postgres dual-write helpers
|
||||||
|
#
|
||||||
|
# Workflows are unusual relative to other Phase 3 tables: a single user
|
||||||
|
# action (create / update) writes to three collections in concert
|
||||||
|
# (workflows + workflow_nodes + workflow_edges) and the edges reference
|
||||||
|
# nodes by user-provided string ids. The Postgres mirror needs to:
|
||||||
|
#
|
||||||
|
# 1. Run all three writes inside one PG transaction (so the just-created
|
||||||
|
# nodes are visible when we resolve their UUIDs for the edge insert).
|
||||||
|
# 2. Translate edge source_id/target_id strings → workflow_nodes.id UUIDs
|
||||||
|
# after the bulk_create returns them.
|
||||||
|
#
|
||||||
|
# Each helper opens exactly one ``dual_write`` call (one PG txn) and uses
|
||||||
|
# the connection from whichever repo it was instantiated with to spin up
|
||||||
|
# any sibling repos it needs.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _dual_write_workflow_create(
|
||||||
|
mongo_workflow_id: str,
|
||||||
|
user_id: str,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
nodes_data: List[Dict],
|
||||||
|
edges_data: List[Dict],
|
||||||
|
graph_version: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""Mirror a Mongo workflow create into Postgres."""
|
||||||
|
|
||||||
|
def _do(repo: WorkflowsRepository) -> None:
|
||||||
|
conn = repo._conn
|
||||||
|
wf = repo.create(
|
||||||
|
user_id,
|
||||||
|
name,
|
||||||
|
description=description,
|
||||||
|
legacy_mongo_id=mongo_workflow_id,
|
||||||
|
)
|
||||||
|
_write_graph(conn, wf["id"], graph_version, nodes_data, edges_data)
|
||||||
|
|
||||||
|
dual_write(WorkflowsRepository, _do)
|
||||||
|
|
||||||
|
|
||||||
|
def _dual_write_workflow_update(
|
||||||
|
mongo_workflow_id: str,
|
||||||
|
user_id: str,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
nodes_data: List[Dict],
|
||||||
|
edges_data: List[Dict],
|
||||||
|
next_graph_version: int,
|
||||||
|
) -> None:
|
||||||
|
"""Mirror a Mongo workflow update into Postgres.
|
||||||
|
|
||||||
|
Mirrors the Mongo route: insert the new graph_version's nodes/edges,
|
||||||
|
bump the workflow's name/description/current_graph_version, then drop
|
||||||
|
every other graph_version's nodes/edges.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _do(repo: WorkflowsRepository) -> None:
|
||||||
|
conn = repo._conn
|
||||||
|
wf = _resolve_pg_workflow(conn, mongo_workflow_id)
|
||||||
|
if wf is None:
|
||||||
|
return
|
||||||
|
_write_graph(conn, wf["id"], next_graph_version, nodes_data, edges_data)
|
||||||
|
repo.update(wf["id"], user_id, {
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"current_graph_version": next_graph_version,
|
||||||
|
})
|
||||||
|
WorkflowNodesRepository(conn).delete_other_versions(
|
||||||
|
wf["id"], next_graph_version,
|
||||||
|
)
|
||||||
|
WorkflowEdgesRepository(conn).delete_other_versions(
|
||||||
|
wf["id"], next_graph_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
dual_write(WorkflowsRepository, _do)
|
||||||
|
|
||||||
|
|
||||||
|
def _dual_write_workflow_delete(mongo_workflow_id: str, user_id: str) -> None:
|
||||||
|
"""Mirror a Mongo workflow delete into Postgres.
|
||||||
|
|
||||||
|
The CASCADE on workflows.id → workflow_nodes/workflow_edges takes
|
||||||
|
care of the children automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _do(repo: WorkflowsRepository) -> None:
|
||||||
|
wf = _resolve_pg_workflow(repo._conn, mongo_workflow_id)
|
||||||
|
if wf is not None:
|
||||||
|
repo.delete(wf["id"], user_id)
|
||||||
|
|
||||||
|
dual_write(WorkflowsRepository, _do)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_pg_workflow(conn, mongo_workflow_id: str) -> Optional[Dict]:
|
||||||
|
"""Look up a Postgres workflow by its Mongo ObjectId string."""
|
||||||
|
from sqlalchemy import text as _text
|
||||||
|
row = conn.execute(
|
||||||
|
_text("SELECT id FROM workflows WHERE legacy_mongo_id = :legacy_id"),
|
||||||
|
{"legacy_id": mongo_workflow_id},
|
||||||
|
).fetchone()
|
||||||
|
return {"id": str(row[0])} if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def _write_graph(
|
||||||
|
conn,
|
||||||
|
pg_workflow_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
nodes_data: List[Dict],
|
||||||
|
edges_data: List[Dict],
|
||||||
|
) -> None:
|
||||||
|
"""Bulk-create nodes + edges for one graph version inside one txn.
|
||||||
|
|
||||||
|
Edges arrive with source/target as user-provided node-id strings
|
||||||
|
(the same shape the Mongo route stores). We bulk-insert nodes first,
|
||||||
|
capture their ``node_id → UUID`` map from the returned rows, then
|
||||||
|
translate edge source/target strings to those UUIDs before the edge
|
||||||
|
bulk insert. Edges referencing missing nodes are dropped (logged).
|
||||||
|
"""
|
||||||
|
nodes_repo = WorkflowNodesRepository(conn)
|
||||||
|
edges_repo = WorkflowEdgesRepository(conn)
|
||||||
|
|
||||||
|
if nodes_data:
|
||||||
|
created_nodes = nodes_repo.bulk_create(
|
||||||
|
pg_workflow_id, graph_version,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"node_id": n["id"],
|
||||||
|
"node_type": n["type"],
|
||||||
|
"title": n.get("title", ""),
|
||||||
|
"description": n.get("description", ""),
|
||||||
|
"position": n.get("position", {"x": 0, "y": 0}),
|
||||||
|
"config": n.get("data", {}),
|
||||||
|
"legacy_mongo_id": n.get("legacy_mongo_id"),
|
||||||
|
}
|
||||||
|
for n in nodes_data
|
||||||
|
],
|
||||||
|
)
|
||||||
|
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
|
||||||
|
else:
|
||||||
|
node_uuid_by_str = {}
|
||||||
|
|
||||||
|
if edges_data:
|
||||||
|
translated_edges: List[Dict] = []
|
||||||
|
for e in edges_data:
|
||||||
|
src = e.get("source")
|
||||||
|
tgt = e.get("target")
|
||||||
|
from_uuid = node_uuid_by_str.get(src)
|
||||||
|
to_uuid = node_uuid_by_str.get(tgt)
|
||||||
|
if not from_uuid or not to_uuid:
|
||||||
|
current_app.logger.warning(
|
||||||
|
"PG dual-write: dropping edge %s; node refs unresolved "
|
||||||
|
"(source=%s, target=%s)",
|
||||||
|
e.get("id"), src, tgt,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
translated_edges.append({
|
||||||
|
"edge_id": e["id"],
|
||||||
|
"from_node_id": from_uuid,
|
||||||
|
"to_node_id": to_uuid,
|
||||||
|
"source_handle": e.get("sourceHandle"),
|
||||||
|
"target_handle": e.get("targetHandle"),
|
||||||
|
})
|
||||||
|
if translated_edges:
|
||||||
|
edges_repo.bulk_create(pg_workflow_id, graph_version, translated_edges)
|
||||||
|
|
||||||
|
|
||||||
def serialize_workflow(w: Dict) -> Dict:
|
def serialize_workflow(w: Dict) -> Dict:
|
||||||
"""Serialize workflow document to API response format."""
|
"""Serialize workflow document to API response format."""
|
||||||
return {
|
return {
|
||||||
@@ -317,24 +489,28 @@ def _can_reach_end(
|
|||||||
|
|
||||||
def create_workflow_nodes(
|
def create_workflow_nodes(
|
||||||
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
workflow_id: str, nodes_data: List[Dict], graph_version: int
|
||||||
) -> None:
|
) -> List[Dict]:
|
||||||
"""Insert workflow nodes into database."""
|
"""Insert workflow nodes into Mongo and return rows with Mongo ids."""
|
||||||
if nodes_data:
|
if nodes_data:
|
||||||
workflow_nodes_collection.insert_many(
|
mongo_nodes = [
|
||||||
[
|
{
|
||||||
{
|
"id": n["id"],
|
||||||
"id": n["id"],
|
"workflow_id": workflow_id,
|
||||||
"workflow_id": workflow_id,
|
"graph_version": graph_version,
|
||||||
"graph_version": graph_version,
|
"type": n["type"],
|
||||||
"type": n["type"],
|
"title": n.get("title", ""),
|
||||||
"title": n.get("title", ""),
|
"description": n.get("description", ""),
|
||||||
"description": n.get("description", ""),
|
"position": n.get("position", {"x": 0, "y": 0}),
|
||||||
"position": n.get("position", {"x": 0, "y": 0}),
|
"config": n.get("data", {}),
|
||||||
"config": n.get("data", {}),
|
}
|
||||||
}
|
for n in nodes_data
|
||||||
for n in nodes_data
|
]
|
||||||
]
|
result = workflow_nodes_collection.insert_many(mongo_nodes)
|
||||||
)
|
return [
|
||||||
|
{**node, "legacy_mongo_id": str(inserted_id)}
|
||||||
|
for node, inserted_id in zip(nodes_data, result.inserted_ids)
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def create_workflow_edges(
|
def create_workflow_edges(
|
||||||
@@ -399,7 +575,7 @@ class WorkflowList(Resource):
|
|||||||
workflow_id = str(result.inserted_id)
|
workflow_id = str(result.inserted_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
create_workflow_nodes(workflow_id, nodes_data, 1)
|
created_nodes = create_workflow_nodes(workflow_id, nodes_data, 1)
|
||||||
create_workflow_edges(workflow_id, edges_data, 1)
|
create_workflow_edges(workflow_id, edges_data, 1)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
workflow_nodes_collection.delete_many({"workflow_id": workflow_id})
|
||||||
@@ -407,6 +583,15 @@ class WorkflowList(Resource):
|
|||||||
workflows_collection.delete_one({"_id": result.inserted_id})
|
workflows_collection.delete_one({"_id": result.inserted_id})
|
||||||
return _workflow_error_response("Failed to create workflow structure", err)
|
return _workflow_error_response("Failed to create workflow structure", err)
|
||||||
|
|
||||||
|
_dual_write_workflow_create(
|
||||||
|
workflow_id,
|
||||||
|
user_id,
|
||||||
|
name,
|
||||||
|
data.get("description", ""),
|
||||||
|
created_nodes,
|
||||||
|
edges_data,
|
||||||
|
)
|
||||||
|
|
||||||
return success_response({"id": workflow_id}, 201)
|
return success_response({"id": workflow_id}, 201)
|
||||||
|
|
||||||
|
|
||||||
@@ -473,7 +658,9 @@ class WorkflowDetail(Resource):
|
|||||||
current_graph_version = get_workflow_graph_version(workflow)
|
current_graph_version = get_workflow_graph_version(workflow)
|
||||||
next_graph_version = current_graph_version + 1
|
next_graph_version = current_graph_version + 1
|
||||||
try:
|
try:
|
||||||
create_workflow_nodes(workflow_id, nodes_data, next_graph_version)
|
created_nodes = create_workflow_nodes(
|
||||||
|
workflow_id, nodes_data, next_graph_version,
|
||||||
|
)
|
||||||
create_workflow_edges(workflow_id, edges_data, next_graph_version)
|
create_workflow_edges(workflow_id, edges_data, next_graph_version)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
workflow_nodes_collection.delete_many(
|
workflow_nodes_collection.delete_many(
|
||||||
@@ -520,6 +707,16 @@ class WorkflowDetail(Resource):
|
|||||||
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
|
f"Failed to clean old workflow graph versions for {workflow_id}: {cleanup_err}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_dual_write_workflow_update(
|
||||||
|
workflow_id,
|
||||||
|
user_id,
|
||||||
|
name,
|
||||||
|
data.get("description", ""),
|
||||||
|
created_nodes,
|
||||||
|
edges_data,
|
||||||
|
next_graph_version,
|
||||||
|
)
|
||||||
|
|
||||||
return success_response()
|
return success_response()
|
||||||
|
|
||||||
@require_auth
|
@require_auth
|
||||||
@@ -543,4 +740,6 @@ class WorkflowDetail(Resource):
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
return _workflow_error_response("Failed to delete workflow", err)
|
return _workflow_error_response("Failed to delete workflow", err)
|
||||||
|
|
||||||
|
_dual_write_workflow_delete(workflow_id, user_id)
|
||||||
|
|
||||||
return success_response()
|
return success_response()
|
||||||
|
|||||||
3
application/api/v1/__init__.py
Normal file
3
application/api/v1/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from application.api.v1.routes import v1_bp
|
||||||
|
|
||||||
|
__all__ = ["v1_bp"]
|
||||||
333
application/api/v1/routes.py
Normal file
333
application/api/v1/routes.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""Standard chat completions API routes.
|
||||||
|
|
||||||
|
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
|
||||||
|
follow the widely-adopted chat completions protocol so external tools
|
||||||
|
(opencode, continue, etc.) can connect to DocsGPT agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Any, Dict, Generator, Optional
|
||||||
|
|
||||||
|
from flask import Blueprint, jsonify, make_response, request, Response
|
||||||
|
|
||||||
|
from application.api.answer.routes.base import BaseAnswerResource
|
||||||
|
from application.api.answer.services.stream_processor import StreamProcessor
|
||||||
|
from application.api.v1.translator import (
|
||||||
|
translate_request,
|
||||||
|
translate_response,
|
||||||
|
translate_stream_event,
|
||||||
|
)
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_bearer_token() -> Optional[str]:
|
||||||
|
"""Extract API key from Authorization: Bearer header."""
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
if auth.startswith("Bearer "):
|
||||||
|
return auth[7:].strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _lookup_agent(api_key: str) -> Optional[Dict]:
|
||||||
|
"""Look up the agent document for this API key."""
|
||||||
|
try:
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[settings.MONGO_DB_NAME]
|
||||||
|
return db["agents"].find_one({"key": api_key})
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to look up agent for API key", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
|
||||||
|
"""Return agent name for display as model name."""
|
||||||
|
if agent:
|
||||||
|
return agent.get("name", api_key)
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
class _V1AnswerHelper(BaseAnswerResource):
|
||||||
|
"""Thin wrapper to access complete_stream / process_response_stream."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@v1_bp.route("/chat/completions", methods=["POST"])
|
||||||
|
def chat_completions():
|
||||||
|
"""Handle POST /v1/chat/completions."""
|
||||||
|
api_key = _extract_bearer_token()
|
||||||
|
if not api_key:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or not data.get("messages"):
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
is_stream = data.get("stream", False)
|
||||||
|
agent_doc = _lookup_agent(api_key)
|
||||||
|
model_name = _get_model_name(agent_doc, api_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
internal_data = translate_request(data, api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link decoded_token to the agent's owner so continuation state,
|
||||||
|
# logs, and tool execution use the correct user identity.
|
||||||
|
agent_user = agent_doc.get("user") if agent_doc else None
|
||||||
|
decoded_token = {"sub": agent_user or "api_key_user"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = StreamProcessor(internal_data, decoded_token)
|
||||||
|
|
||||||
|
if internal_data.get("tool_actions"):
|
||||||
|
# Continuation mode
|
||||||
|
conversation_id = internal_data.get("conversation_id")
|
||||||
|
if not conversation_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
(
|
||||||
|
agent,
|
||||||
|
messages,
|
||||||
|
tools_dict,
|
||||||
|
pending_tool_calls,
|
||||||
|
tool_actions,
|
||||||
|
) = processor.resume_from_tool_actions(
|
||||||
|
internal_data["tool_actions"], conversation_id
|
||||||
|
)
|
||||||
|
continuation = {
|
||||||
|
"messages": messages,
|
||||||
|
"tools_dict": tools_dict,
|
||||||
|
"pending_tool_calls": pending_tool_calls,
|
||||||
|
"tool_actions": tool_actions,
|
||||||
|
}
|
||||||
|
question = ""
|
||||||
|
else:
|
||||||
|
# Normal mode
|
||||||
|
question = internal_data.get("question", "")
|
||||||
|
agent = processor.build_agent(question)
|
||||||
|
continuation = None
|
||||||
|
|
||||||
|
if not processor.decoded_token:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
|
helper = _V1AnswerHelper()
|
||||||
|
usage_error = helper.check_usage(processor.agent_config)
|
||||||
|
if usage_error:
|
||||||
|
return usage_error
|
||||||
|
|
||||||
|
should_save_conversation = bool(internal_data.get("save_conversation", False))
|
||||||
|
|
||||||
|
if is_stream:
|
||||||
|
return Response(
|
||||||
|
_stream_response(
|
||||||
|
helper,
|
||||||
|
question,
|
||||||
|
agent,
|
||||||
|
processor,
|
||||||
|
model_name,
|
||||||
|
continuation,
|
||||||
|
should_save_conversation,
|
||||||
|
),
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return _non_stream_response(
|
||||||
|
helper,
|
||||||
|
question,
|
||||||
|
agent,
|
||||||
|
processor,
|
||||||
|
model_name,
|
||||||
|
continuation,
|
||||||
|
should_save_conversation,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(
|
||||||
|
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e)},
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e)},
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_response(
    helper: _V1AnswerHelper,
    question: str,
    agent: Any,
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
    should_save_conversation: bool,
) -> Generator[str, None, None]:
    """Generate translated SSE chunks for streaming response.

    Consumes the internal DocsGPT SSE stream produced by
    ``helper.complete_stream`` and re-emits each event in the standard
    chat-completions streaming format via ``translate_stream_event``.

    Args:
        helper: Helper object driving the internal answer pipeline.
        question: The user's question text.
        agent: Agent instance built by ``processor.build_agent``.
        processor: Request processor carrying conversation/agent context.
        model_name: Model/agent identifier echoed in every chunk.
        continuation: Optional tool-call continuation payload, or None.
        should_save_conversation: Whether the conversation is persisted.

    Yields:
        SSE-formatted strings ready to send to the client.
    """
    # Provisional id based on the current time; replaced once the internal
    # stream reports the real conversation id (see the "id" event below).
    completion_id = f"chatcmpl-{int(time.time())}"

    internal_stream = helper.complete_stream(
        question=question,
        agent=agent,
        conversation_id=processor.conversation_id,
        user_api_key=processor.agent_config.get("user_api_key"),
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )

    for line in internal_stream:
        if not line.strip():
            continue
        # Parse the internal SSE event ("data: {...}" line).
        event_str = line.replace("data: ", "").strip()
        try:
            event_data = json.loads(event_str)
        except (json.JSONDecodeError, TypeError):
            # Skip malformed / non-JSON lines instead of aborting the stream.
            continue

        # Update completion_id when we get the conversation id
        if event_data.get("type") == "id":
            conv_id = event_data.get("id", "")
            if conv_id:
                completion_id = f"chatcmpl-{conv_id}"

        # Translate to standard format; one internal event may map to
        # zero, one, or several outgoing chunks.
        translated = translate_stream_event(event_data, completion_id, model_name)
        for chunk in translated:
            yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
def _non_stream_response(
    helper: _V1AnswerHelper,
    question: str,
    agent: Any,
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
    should_save_conversation: bool,
) -> Response:
    """Collect full response and return as single JSON.

    Drains the same internal stream used for SSE, aggregates it via
    ``helper.process_response_stream``, and returns a single
    chat-completions JSON body — or a 500 error payload when the
    aggregated result carries an error.

    Args:
        helper: Helper object driving the internal answer pipeline.
        question: The user's question text.
        agent: Agent instance built by ``processor.build_agent``.
        processor: Request processor carrying conversation/agent context.
        model_name: Model/agent identifier echoed in the response.
        continuation: Optional tool-call continuation payload, or None.
        should_save_conversation: Whether the conversation is persisted.
    """
    stream = helper.complete_stream(
        question=question,
        agent=agent,
        conversation_id=processor.conversation_id,
        user_api_key=processor.agent_config.get("user_api_key"),
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )

    # Aggregate the whole stream into one result dict
    # (answer, sources, tool_calls, thought, error, extra, ...).
    result = helper.process_response_stream(stream)

    if result["error"]:
        return make_response(
            jsonify({"error": {"message": result["error"], "type": "server_error"}}),
            500,
        )

    # "extra" may carry pending client-side tool calls when execution paused.
    extra = result.get("extra")
    pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None

    response = translate_response(
        conversation_id=result["conversation_id"],
        answer=result["answer"] or "",
        sources=result["sources"],
        tool_calls=result["tool_calls"],
        thought=result["thought"] or "",
        model_name=model_name,
        pending_tool_calls=pending,
    )
    return make_response(jsonify(response), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@v1_bp.route("/models", methods=["GET"])
def list_models():
    """Handle GET /v1/models — return agents as models.

    Authenticates via the Bearer token (an agent API key), then lists
    every agent owned by the same user in the standard model-list shape:
    ``{"object": "list", "data": [...]}``. Returns 401 for a missing or
    unknown key and 500 on unexpected failures.
    """
    api_key = _extract_bearer_token()
    if not api_key:
        return make_response(
            jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
            401,
        )

    try:
        mongo = MongoDB.get_client()
        db = mongo[settings.MONGO_DB_NAME]
        agents_collection = db["agents"]

        # Find the agent for this api_key
        agent = agents_collection.find_one({"key": api_key})
        if not agent:
            return make_response(
                jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
                401,
            )

        user = agent.get("user")

        # Return all agents belonging to this user
        user_agents = list(agents_collection.find({"user": user}))

        models = []
        for ag in user_agents:
            created = ag.get("createdAt")
            # Fall back to "now" when the agent has no createdAt timestamp.
            created_ts = int(created.timestamp()) if created else int(time.time())
            model_id = str(ag.get("_id") or ag.get("id") or "")
            models.append({
                "id": model_id,
                "object": "model",
                "created": created_ts,
                "owned_by": "docsgpt",
                "name": ag.get("name", ""),
                "description": ag.get("description", ""),
            })

        return make_response(
            jsonify({"object": "list", "data": models}),
            200,
        )
    except Exception as e:
        logger.error(f"/v1/models error: {e}", exc_info=True)
        return make_response(
            jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
            500,
        )
|
||||||
433
application/api/v1/translator.py
Normal file
433
application/api/v1/translator.py
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
"""Translate between standard chat completions format and DocsGPT internals.
|
||||||
|
|
||||||
|
This module handles:
|
||||||
|
- Request translation (chat completions -> DocsGPT internal format)
|
||||||
|
- Response translation (DocsGPT response -> chat completions format)
|
||||||
|
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
def _get_client_tool_name(tc: Dict) -> str:
|
||||||
|
"""Return the original tool name for client-facing responses.
|
||||||
|
|
||||||
|
For client-side tools the ``tool_name`` field carries the name the
|
||||||
|
client originally registered. Fall back to ``action_name`` (which
|
||||||
|
is now the clean LLM-visible name) or ``name``.
|
||||||
|
"""
|
||||||
|
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request translation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def is_continuation(messages: List[Dict]) -> bool:
    """Check if messages represent a tool-call continuation.

    A continuation is detected when the last message(s) have ``role: "tool"``
    immediately after an assistant message with ``tool_calls``.
    """
    if not messages:
        return False
    # Find the most recent non-tool message; everything after it (if
    # anything) is a run of trailing tool-result messages.
    anchor = next(
        (j for j in range(len(messages) - 1, -1, -1) if messages[j].get("role") != "tool"),
        -1,
    )
    if anchor < 0:
        # Only tool messages — nothing to continue from.
        return False
    candidate = messages[anchor]
    return candidate.get("role") == "assistant" and bool(candidate.get("tool_calls"))
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
    """Extract tool results from trailing tool messages for continuation.

    Returns a list of ``tool_actions`` dicts with ``call_id`` and ``result``,
    in chronological order. String contents that parse as JSON are decoded.
    """
    # Locate the start of the trailing run of tool messages.
    start = len(messages)
    while start > 0 and messages[start - 1].get("role") == "tool":
        start -= 1

    collected: List[Dict] = []
    for msg in messages[start:]:
        payload = msg.get("content", "")
        if isinstance(payload, str):
            # Decode JSON payloads; leave plain strings untouched.
            try:
                payload = json.loads(payload)
            except (json.JSONDecodeError, TypeError):
                pass
        collected.append({"call_id": msg.get("tool_call_id", ""), "result": payload})
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
    """Try to extract conversation_id from the assistant message before tool results.

    The conversation_id may be stored in a custom ``docsgpt`` extension field
    on the most recent assistant message from a previous response cycle.

    Returns:
        The conversation id string, or None when no assistant message exists
        or the extension field is absent/malformed.
    """
    for msg in reversed(messages):
        if msg.get("role") == "assistant":
            ext = msg.get("docsgpt")
            # Guard against a present-but-None (or non-dict) "docsgpt" value;
            # the previous implementation raised AttributeError here.
            if isinstance(ext, dict):
                return ext.get("conversation_id")
            return None
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
    """Extract the first system message content from the messages array.

    Returns None if no system message is present.
    """
    return next(
        (msg.get("content", "") for msg in messages if msg.get("role") == "system"),
        None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def convert_history(messages: List[Dict]) -> List[Dict]:
    """Convert chat completions messages array to DocsGPT history format.

    DocsGPT history is a list of ``{prompt, response}`` dicts. A pair is
    emitted for each user message directly followed by an assistant
    message; system messages, unpaired messages, and the trailing user
    message (which becomes the ``question``) are skipped.
    """
    history: List[Dict] = []
    idx, total = 0, len(messages)
    while idx < total:
        current = messages[idx]
        paired = (
            current.get("role") == "user"
            and idx + 1 < total
            and messages[idx + 1].get("role") == "assistant"
        )
        if paired:
            # None assistant content (e.g. a pure tool_calls turn) becomes "".
            reply = messages[idx + 1].get("content") or ""
            history.append({"prompt": current.get("content", ""), "response": reply})
            idx += 2
        else:
            idx += 1
    return history
|
||||||
|
|
||||||
|
|
||||||
|
def translate_request(
    data: Dict[str, Any], api_key: str
) -> Dict[str, Any]:
    """Translate a chat completions request to DocsGPT internal format.

    Args:
        data: The incoming request body.
        api_key: Agent API key from the Authorization header.

    Returns:
        Dict suitable for passing to ``StreamProcessor``.
    """
    messages = data.get("messages", [])
    client_tools = data.get("tools")

    # Continuation: trailing tool results after an assistant tool_calls turn.
    if is_continuation(messages):
        payload: Dict[str, Any] = {
            # Prefer the id embedded in the assistant message; fall back to
            # an explicit top-level conversation_id.
            "conversation_id": extract_conversation_id(messages) or data.get("conversation_id"),
            "tool_actions": extract_tool_results(messages),
            "api_key": api_key,
        }
        # Carry tools forward for next iteration
        if client_tools:
            payload["client_tools"] = client_tools
        return payload

    # Normal request — the most recent user message becomes the question.
    question = next(
        (msg.get("content", "") for msg in reversed(messages) if msg.get("role") == "user"),
        "",
    )

    docsgpt = data.get("docsgpt", {})

    payload = {
        "question": question,
        "api_key": api_key,
        "history": json.dumps(convert_history(messages)),
        # Conversations are NOT persisted by default on the v1 endpoint.
        # Callers opt in via ``docsgpt.save_conversation: true``.
        "save_conversation": bool(docsgpt.get("save_conversation", False)),
    }

    system_prompt = extract_system_prompt(messages)
    if system_prompt is not None:
        payload["system_prompt_override"] = system_prompt

    # Client tools
    if client_tools:
        payload["client_tools"] = client_tools

    # DocsGPT extensions
    if docsgpt.get("attachments"):
        payload["attachments"] = docsgpt["attachments"]

    return payload
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Response translation (non-streaming)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def translate_response(
    conversation_id: str,
    answer: str,
    sources: Optional[List[Dict]],
    tool_calls: Optional[List[Dict]],
    thought: str,
    model_name: str,
    pending_tool_calls: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
    """Translate DocsGPT response to chat completions format.

    Args:
        conversation_id: The DocsGPT conversation ID.
        answer: The assistant's text response.
        sources: RAG retrieval sources.
        tool_calls: Completed tool call results.
        thought: Reasoning/thinking tokens.
        model_name: Model/agent identifier.
        pending_tool_calls: Pending client-side tool calls (if paused).

    Returns:
        Dict in the standard chat completions response format.
    """
    created = int(time.time())
    # Empty/missing conversation ids fall back to the creation timestamp.
    completion_id = f"chatcmpl-{conversation_id or created}"

    message: Dict[str, Any] = {"role": "assistant"}

    if pending_tool_calls:
        # Tool calls pending — return them for client execution.
        serialized = []
        for tc in pending_tool_calls:
            raw_args = tc.get("arguments")
            arguments = (
                json.dumps(raw_args)
                if isinstance(raw_args, dict)
                else tc.get("arguments", "{}")
            )
            serialized.append({
                "id": tc.get("call_id", ""),
                "type": "function",
                "function": {"name": _get_client_tool_name(tc), "arguments": arguments},
            })
        message["content"] = None
        message["tool_calls"] = serialized
        finish_reason = "tool_calls"
    else:
        message["content"] = answer
        if thought:
            message["reasoning_content"] = thought
        finish_reason = "stop"

    result: Dict[str, Any] = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model_name,
        "choices": [
            {"index": 0, "message": message, "finish_reason": finish_reason}
        ],
        # Token accounting is not tracked on this path; report zeros.
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }

    # DocsGPT extensions: attach only fields that carry data.
    extensions = {
        key: value
        for key, value in (
            ("conversation_id", conversation_id),
            ("sources", sources),
            ("tool_calls", tool_calls),
        )
        if value
    }
    if extensions:
        result["docsgpt"] = extensions

    return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Streaming event translation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunk(
|
||||||
|
completion_id: str,
|
||||||
|
model_name: str,
|
||||||
|
delta: Dict[str, Any],
|
||||||
|
finish_reason: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a single SSE chunk in the standard streaming format."""
|
||||||
|
chunk = {
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": delta,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
|
||||||
|
"""Build a DocsGPT extension SSE chunk."""
|
||||||
|
return f"data: {json.dumps({'docsgpt': data})}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def translate_stream_event(
    event_data: Dict[str, Any],
    completion_id: str,
    model_name: str,
) -> List[str]:
    """Translate a DocsGPT SSE event dict to standard streaming chunks.

    May return 0, 1, or 2 chunks per input event. For example, a completed
    tool call produces both a docsgpt extension chunk and nothing on the
    standard side (since server-side tool calls aren't surfaced in standard
    format).

    Args:
        event_data: Parsed DocsGPT event dict.
        completion_id: The completion ID for this response.
        model_name: Model/agent identifier.

    Returns:
        List of SSE-formatted strings to send to the client.
    """
    event_type = event_data.get("type")
    chunks: List[str] = []

    if event_type == "answer":
        # Plain answer text → standard content delta.
        chunks.append(
            _make_chunk(completion_id, model_name, {"content": event_data.get("answer", "")})
        )

    elif event_type == "thought":
        # Reasoning tokens → non-standard but common "reasoning_content" delta.
        chunks.append(
            _make_chunk(
                completion_id, model_name,
                {"reasoning_content": event_data.get("thought", "")},
            )
        )

    elif event_type == "source":
        # RAG sources only exist in the docsgpt extension channel.
        chunks.append(
            _make_docsgpt_chunk({
                "type": "source",
                "sources": event_data.get("source", []),
            })
        )

    elif event_type == "tool_call":
        tc_data = event_data.get("data", {})
        status = tc_data.get("status")

        if status == "requires_client_execution":
            # Standard: stream as tool_calls delta
            args = tc_data.get("arguments", {})
            args_str = json.dumps(args) if isinstance(args, dict) else str(args)
            chunks.append(
                _make_chunk(completion_id, model_name, {
                    "tool_calls": [{
                        "index": 0,
                        "id": tc_data.get("call_id", ""),
                        "type": "function",
                        "function": {
                            "name": _get_client_tool_name(tc_data),
                            "arguments": args_str,
                        },
                    }],
                })
            )
        elif status == "awaiting_approval":
            # Extension: approval needed
            chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
        elif status in ("completed", "pending", "error", "denied", "skipped"):
            # Extension: tool call progress
            chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))

    elif event_type == "tool_calls_pending":
        # Standard: finish_reason = tool_calls
        chunks.append(
            _make_chunk(completion_id, model_name, {}, finish_reason="tool_calls")
        )
        # Also emit as docsgpt extension
        chunks.append(
            _make_docsgpt_chunk({
                "type": "tool_calls_pending",
                "pending_tool_calls": event_data.get("data", {}).get("pending_tool_calls", []),
            })
        )

    elif event_type == "end":
        # Standard termination: empty delta with finish_reason, then [DONE].
        chunks.append(
            _make_chunk(completion_id, model_name, {}, finish_reason="stop")
        )
        chunks.append("data: [DONE]\n\n")

    elif event_type == "id":
        # Conversation id only exists in the docsgpt extension channel.
        chunks.append(
            _make_docsgpt_chunk({
                "type": "id",
                "conversation_id": event_data.get("id", ""),
            })
        )

    elif event_type == "error":
        # Emit as standard error (non-standard but widely supported)
        error_data = {
            "error": {
                "message": event_data.get("error", "An error occurred"),
                "type": "server_error",
            }
        }
        chunks.append(f"data: {json.dumps(error_data)}\n\n")

    elif event_type == "structured_answer":
        # Structured answers are surfaced as plain content deltas.
        chunks.append(
            _make_chunk(
                completion_id, model_name,
                {"content": event_data.get("answer", "")},
            )
        )

    # Skip: tool_calls (redundant), research_plan, research_progress

    return chunks
|
||||||
@@ -17,6 +17,7 @@ from application.api.answer import answer # noqa: E402
|
|||||||
from application.api.internal.routes import internal # noqa: E402
|
from application.api.internal.routes import internal # noqa: E402
|
||||||
from application.api.user.routes import user # noqa: E402
|
from application.api.user.routes import user # noqa: E402
|
||||||
from application.api.connector.routes import connector # noqa: E402
|
from application.api.connector.routes import connector # noqa: E402
|
||||||
|
from application.api.v1 import v1_bp # noqa: E402
|
||||||
from application.celery_init import celery # noqa: E402
|
from application.celery_init import celery # noqa: E402
|
||||||
from application.core.settings import settings # noqa: E402
|
from application.core.settings import settings # noqa: E402
|
||||||
from application.stt.upload_limits import ( # noqa: E402
|
from application.stt.upload_limits import ( # noqa: E402
|
||||||
@@ -36,6 +37,7 @@ app.register_blueprint(user)
|
|||||||
app.register_blueprint(answer)
|
app.register_blueprint(answer)
|
||||||
app.register_blueprint(internal)
|
app.register_blueprint(internal)
|
||||||
app.register_blueprint(connector)
|
app.register_blueprint(connector)
|
||||||
|
app.register_blueprint(v1_bp)
|
||||||
app.config.update(
|
app.config.update(
|
||||||
UPLOAD_FOLDER="inputs",
|
UPLOAD_FOLDER="inputs",
|
||||||
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from celery import Celery
|
from celery import Celery
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from celery.signals import setup_logging
|
from celery.signals import setup_logging, worker_process_init
|
||||||
|
|
||||||
|
|
||||||
def make_celery(app_name=__name__):
|
def make_celery(app_name=__name__):
|
||||||
@@ -20,5 +20,24 @@ def config_loggers(*args, **kwargs):
|
|||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
|
|
||||||
|
@worker_process_init.connect
def _dispose_db_engine_on_fork(*args, **kwargs):
    """Dispose the SQLAlchemy engine pool in each forked Celery worker.

    SQLAlchemy connection pools are not fork-safe: file descriptors shared
    between the parent and a forked worker will corrupt the pool. Disposing
    on ``worker_process_init`` gives every worker its own fresh pool on
    first use.

    Imported lazily so Celery workers that don't touch Postgres (or where
    ``POSTGRES_URI`` is unset) don't fail at startup.
    """
    try:
        from application.storage.db.engine import dispose_engine
    except Exception:
        # Best-effort: the engine module may be absent or unconfigured;
        # in that case there is no pool to dispose.
        return
    dispose_engine()
|
||||||
|
|
||||||
|
|
||||||
celery = make_celery()
|
celery = make_celery()
|
||||||
celery.config_from_object("application.celeryconfig")
|
celery.config_from_object("application.celeryconfig")
|
||||||
|
|||||||
89
application/core/db_uri.py
Normal file
89
application/core/db_uri.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""Normalize user-supplied Postgres URIs for different drivers.
|
||||||
|
|
||||||
|
DocsGPT has two Postgres connection strings pointing at potentially
|
||||||
|
different databases:
|
||||||
|
|
||||||
|
* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
|
||||||
|
``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
|
||||||
|
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
|
||||||
|
(via libpq) in ``application/vectorstore/pgvector.py``. libpq only
|
||||||
|
understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
|
||||||
|
dialect prefix is an invalid URI from its point of view.
|
||||||
|
|
||||||
|
The two fields therefore need opposite normalization so operators don't
|
||||||
|
have to know which driver a given field feeds. Each normalizer also
|
||||||
|
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
|
||||||
|
psycopg2 is no longer in the project.
|
||||||
|
|
||||||
|
This module is deliberately separate from ``application/core/settings.py``
|
||||||
|
so the Settings class stays focused on field declarations, and the
|
||||||
|
URI-rewriting logic can be unit-tested without triggering ``.env``
|
||||||
|
file loading from importing Settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def _rewrite_uri_prefixes(v, rewrites):
    """Shared URI prefix rewriter used by both normalizers below.

    Strips whitespace, returns ``None`` for empty / ``"none"`` values,
    applies the first matching rewrite, and passes unrecognised input
    through so downstream consumers (SQLAlchemy, libpq) can produce
    their own error messages rather than us silently eating a
    misconfiguration.
    """
    if not isinstance(v, str):
        # Non-strings (including None) pass through untouched.
        return v
    candidate = v.strip()
    if not candidate or candidate.lower() == "none":
        return None
    for prefix, replacement in rewrites:
        if candidate.startswith(prefix):
            return replacement + candidate[len(prefix):]
    return candidate


# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
# dialect prefix to select the psycopg v3 driver. Normalize the
# operator-friendly forms TOWARD that dialect.
_POSTGRES_URI_REWRITES = (
    ("postgresql+psycopg2://", "postgresql+psycopg://"),
    ("postgresql://", "postgresql+psycopg://"),
    ("postgres://", "postgresql+psycopg://"),
)


# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
# dialect prefix is an invalid URI from libpq's point of view. Strip it
# if the operator accidentally copied their POSTGRES_URI value here.
_PGVECTOR_CONNECTION_STRING_REWRITES = (
    ("postgresql+psycopg2://", "postgresql://"),
    ("postgresql+psycopg://", "postgresql://"),
)


def normalize_postgres_uri(v):
    """Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.

    Accepts the forms operators naturally write (``postgres://``,
    ``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
    Unknown schemes pass through unchanged so SQLAlchemy can produce its
    own dialect-not-found error.
    """
    return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)


def normalize_pgvector_connection_string(v):
    """Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.

    Strips the SQLAlchemy dialect prefix if the operator accidentally
    copied their POSTGRES_URI value here — libpq can't parse it.
    User-friendly forms (``postgres://``, ``postgresql://``) pass
    through unchanged since libpq accepts them natively.
    """
    return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)
|
||||||
@@ -8,6 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
|
||||||
|
from application.core.db_uri import ( # noqa: E402
|
||||||
|
normalize_pgvector_connection_string,
|
||||||
|
normalize_postgres_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
model_config = SettingsConfigDict(extra="ignore")
|
model_config = SettingsConfigDict(extra="ignore")
|
||||||
|
|
||||||
@@ -22,6 +28,11 @@ class Settings(BaseSettings):
|
|||||||
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
|
||||||
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
|
||||||
MONGO_DB_NAME: str = "docsgpt"
|
MONGO_DB_NAME: str = "docsgpt"
|
||||||
|
# User-data Postgres DB.
|
||||||
|
POSTGRES_URI: Optional[str] = None
|
||||||
|
|
||||||
|
# MongoDB→Postgres migration: dual-write to Postgres (Mongo stays source of truth)
|
||||||
|
USE_POSTGRES: bool = False
|
||||||
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
|
||||||
DEFAULT_MAX_HISTORY: int = 150
|
DEFAULT_MAX_HISTORY: int = 150
|
||||||
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
|
DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
|
||||||
@@ -59,6 +70,10 @@ class Settings(BaseSettings):
|
|||||||
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
|
||||||
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
|
||||||
|
|
||||||
|
# Confluence Cloud integration
|
||||||
|
CONFLUENCE_CLIENT_ID: Optional[str] = None
|
||||||
|
CONFLUENCE_CLIENT_SECRET: Optional[str] = None
|
||||||
|
|
||||||
# GitHub source
|
# GitHub source
|
||||||
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access
|
||||||
|
|
||||||
@@ -117,7 +132,10 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_PATH: Optional[str] = None
|
QDRANT_PATH: Optional[str] = None
|
||||||
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
QDRANT_DISTANCE_FUNC: str = "Cosine"
|
||||||
|
|
||||||
# PGVector vectorstore config
|
# PGVector vectorstore config. Write the URI in whichever form you
|
||||||
|
# prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy
|
||||||
|
# dialect form (``postgresql+psycopg://``) are all accepted and
|
||||||
|
# normalized internally for ``psycopg.connect()``.
|
||||||
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
PGVECTOR_CONNECTION_STRING: Optional[str] = None
|
||||||
# Milvus vectorstore config
|
# Milvus vectorstore config
|
||||||
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
|
||||||
@@ -156,6 +174,16 @@ class Settings(BaseSettings):
|
|||||||
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
|
||||||
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat
|
||||||
|
|
||||||
|
@field_validator("POSTGRES_URI", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_postgres_uri_validator(cls, v):
|
||||||
|
return normalize_postgres_uri(v)
|
||||||
|
|
||||||
|
@field_validator("PGVECTOR_CONNECTION_STRING", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_pgvector_connection_string_validator(cls, v):
|
||||||
|
return normalize_pgvector_connection_string(v)
|
||||||
|
|
||||||
@field_validator(
|
@field_validator(
|
||||||
"API_KEY",
|
"API_KEY",
|
||||||
"OPENAI_API_KEY",
|
"OPENAI_API_KEY",
|
||||||
|
|||||||
@@ -167,6 +167,8 @@ class GoogleLLM(BaseLLM):
|
|||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
import json as _json
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
role = message.get("role")
|
role = message.get("role")
|
||||||
content = message.get("content")
|
content = message.get("content")
|
||||||
@@ -180,9 +182,66 @@ class GoogleLLM(BaseLLM):
|
|||||||
|
|
||||||
if role == "assistant":
|
if role == "assistant":
|
||||||
role = "model"
|
role = "model"
|
||||||
elif role == "tool":
|
|
||||||
role = "model"
|
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
|
# Standard format: assistant message with tool_calls array
|
||||||
|
msg_tool_calls = message.get("tool_calls")
|
||||||
|
if msg_tool_calls and role == "model":
|
||||||
|
for tc in msg_tool_calls:
|
||||||
|
func = tc.get("function", {})
|
||||||
|
args = func.get("arguments", "{}")
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
args = _json.loads(args)
|
||||||
|
except (_json.JSONDecodeError, TypeError):
|
||||||
|
args = {}
|
||||||
|
cleaned_args = self._remove_null_values(args)
|
||||||
|
thought_sig = tc.get("thought_signature")
|
||||||
|
if thought_sig:
|
||||||
|
parts.append(
|
||||||
|
types.Part(
|
||||||
|
functionCall=types.FunctionCall(
|
||||||
|
name=func.get("name", ""),
|
||||||
|
args=cleaned_args,
|
||||||
|
),
|
||||||
|
thoughtSignature=thought_sig,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(
|
||||||
|
types.Part.from_function_call(
|
||||||
|
name=func.get("name", ""),
|
||||||
|
args=cleaned_args,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if parts:
|
||||||
|
cleaned_messages.append(types.Content(role=role, parts=parts))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Standard format: tool message with tool_call_id
|
||||||
|
tool_call_id = message.get("tool_call_id")
|
||||||
|
if role == "tool" and tool_call_id is not None:
|
||||||
|
result_content = content
|
||||||
|
if isinstance(result_content, str):
|
||||||
|
try:
|
||||||
|
result_content = _json.loads(result_content)
|
||||||
|
except (_json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
# Google expects function_response name — extract from tool_call_id context
|
||||||
|
# We use a placeholder name since Google API doesn't require exact match
|
||||||
|
parts.append(
|
||||||
|
types.Part.from_function_response(
|
||||||
|
name="tool_result",
|
||||||
|
response={"result": result_content},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cleaned_messages.append(types.Content(role="model", parts=parts))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if role == "tool":
|
||||||
|
role = "model"
|
||||||
|
|
||||||
if role and content is not None:
|
if role and content is not None:
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
parts = [types.Part.from_text(text=content)]
|
parts = [types.Part.from_text(text=content)]
|
||||||
@@ -191,15 +250,11 @@ class GoogleLLM(BaseLLM):
|
|||||||
if "text" in item:
|
if "text" in item:
|
||||||
parts.append(types.Part.from_text(text=item["text"]))
|
parts.append(types.Part.from_text(text=item["text"]))
|
||||||
elif "function_call" in item:
|
elif "function_call" in item:
|
||||||
# Remove null values from args to avoid API errors
|
# Legacy format support
|
||||||
|
|
||||||
cleaned_args = self._remove_null_values(
|
cleaned_args = self._remove_null_values(
|
||||||
item["function_call"]["args"]
|
item["function_call"]["args"]
|
||||||
)
|
)
|
||||||
# Create function call part with thought_signature if present
|
|
||||||
# For Gemini 3 models, we need to include thought_signature
|
|
||||||
if "thought_signature" in item:
|
if "thought_signature" in item:
|
||||||
# Use Part constructor with functionCall and thoughtSignature
|
|
||||||
parts.append(
|
parts.append(
|
||||||
types.Part(
|
types.Part(
|
||||||
functionCall=types.FunctionCall(
|
functionCall=types.FunctionCall(
|
||||||
@@ -210,7 +265,6 @@ class GoogleLLM(BaseLLM):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use helper method when no thought_signature
|
|
||||||
parts.append(
|
parts.append(
|
||||||
types.Part.from_function_call(
|
types.Part.from_function_call(
|
||||||
name=item["function_call"]["name"],
|
name=item["function_call"]["name"],
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -315,10 +316,34 @@ class LLMHandler(ABC):
|
|||||||
current_prompt = self._extract_text_from_content(content)
|
current_prompt = self._extract_text_from_content(content)
|
||||||
|
|
||||||
elif role in {"assistant", "model"}:
|
elif role in {"assistant", "model"}:
|
||||||
# If this assistant turn contains tool calls, collect them; otherwise commit a response.
|
# Standard format: tool_calls array on assistant message
|
||||||
|
msg_tool_calls = message.get("tool_calls")
|
||||||
|
if msg_tool_calls:
|
||||||
|
for tc in msg_tool_calls:
|
||||||
|
call_id = tc.get("id") or str(uuid.uuid4())
|
||||||
|
func = tc.get("function", {})
|
||||||
|
args = func.get("arguments")
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
args = json.loads(args)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
current_tool_calls[call_id] = {
|
||||||
|
"tool_name": "unknown_tool",
|
||||||
|
"action_name": func.get("name"),
|
||||||
|
"arguments": args,
|
||||||
|
"result": None,
|
||||||
|
"status": "called",
|
||||||
|
"call_id": call_id,
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Legacy format: function_call/function_response in content list
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
|
has_fc = False
|
||||||
for item in content:
|
for item in content:
|
||||||
if "function_call" in item:
|
if "function_call" in item:
|
||||||
|
has_fc = True
|
||||||
fc = item["function_call"]
|
fc = item["function_call"]
|
||||||
call_id = fc.get("call_id") or str(uuid.uuid4())
|
call_id = fc.get("call_id") or str(uuid.uuid4())
|
||||||
current_tool_calls[call_id] = {
|
current_tool_calls[call_id] = {
|
||||||
@@ -329,37 +354,30 @@ class LLMHandler(ABC):
|
|||||||
"status": "called",
|
"status": "called",
|
||||||
"call_id": call_id,
|
"call_id": call_id,
|
||||||
}
|
}
|
||||||
elif "function_response" in item:
|
if has_fc:
|
||||||
fr = item["function_response"]
|
continue
|
||||||
call_id = fr.get("call_id") or str(uuid.uuid4())
|
|
||||||
current_tool_calls[call_id] = {
|
|
||||||
"tool_name": "unknown_tool",
|
|
||||||
"action_name": fr.get("name"),
|
|
||||||
"arguments": None,
|
|
||||||
"result": fr.get("response", {}).get("result"),
|
|
||||||
"status": "completed",
|
|
||||||
"call_id": call_id,
|
|
||||||
}
|
|
||||||
# No direct assistant text here; continue to next message
|
|
||||||
continue
|
|
||||||
|
|
||||||
response_text = self._extract_text_from_content(content)
|
response_text = self._extract_text_from_content(content)
|
||||||
_commit_query(response_text)
|
_commit_query(response_text)
|
||||||
|
|
||||||
elif role == "tool":
|
elif role == "tool":
|
||||||
# Attach tool outputs to the latest pending tool call if possible
|
# Standard format: tool_call_id on tool message
|
||||||
|
call_id = message.get("tool_call_id")
|
||||||
tool_text = self._extract_text_from_content(content)
|
tool_text = self._extract_text_from_content(content)
|
||||||
# Attempt to parse function_response style
|
|
||||||
call_id = None
|
|
||||||
if isinstance(content, list):
|
|
||||||
for item in content:
|
|
||||||
if "function_response" in item and item["function_response"].get("call_id"):
|
|
||||||
call_id = item["function_response"]["call_id"]
|
|
||||||
break
|
|
||||||
if call_id and call_id in current_tool_calls:
|
if call_id and call_id in current_tool_calls:
|
||||||
current_tool_calls[call_id]["result"] = tool_text
|
current_tool_calls[call_id]["result"] = tool_text
|
||||||
current_tool_calls[call_id]["status"] = "completed"
|
current_tool_calls[call_id]["status"] = "completed"
|
||||||
elif queries:
|
# Legacy: function_response in content list
|
||||||
|
elif isinstance(content, list):
|
||||||
|
for item in content:
|
||||||
|
if "function_response" in item:
|
||||||
|
legacy_id = item["function_response"].get("call_id")
|
||||||
|
if legacy_id and legacy_id in current_tool_calls:
|
||||||
|
current_tool_calls[legacy_id]["result"] = tool_text
|
||||||
|
current_tool_calls[legacy_id]["status"] = "completed"
|
||||||
|
break
|
||||||
|
elif call_id is None and queries:
|
||||||
queries[-1].setdefault("tool_calls", []).append(
|
queries[-1].setdefault("tool_calls", []).append(
|
||||||
{
|
{
|
||||||
"tool_name": "unknown_tool",
|
"tool_name": "unknown_tool",
|
||||||
@@ -648,6 +666,13 @@ class LLMHandler(ABC):
|
|||||||
"""
|
"""
|
||||||
Execute tool calls and update conversation history.
|
Execute tool calls and update conversation history.
|
||||||
|
|
||||||
|
When a tool requires approval or client-side execution, it is
|
||||||
|
collected as a pending action instead of being executed. The
|
||||||
|
generator returns ``(updated_messages, pending_actions)`` where
|
||||||
|
*pending_actions* is ``None`` when every tool was executed
|
||||||
|
normally, or a list of dicts describing actions the client must
|
||||||
|
resolve before the LLM loop can continue.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent: The agent instance
|
agent: The agent instance
|
||||||
tool_calls: List of tool calls to execute
|
tool_calls: List of tool calls to execute
|
||||||
@@ -655,9 +680,11 @@ class LLMHandler(ABC):
|
|||||||
messages: Current conversation history
|
messages: Current conversation history
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated messages list
|
Tuple of (updated_messages, pending_actions).
|
||||||
|
pending_actions is None if all tools executed, otherwise a list.
|
||||||
"""
|
"""
|
||||||
updated_messages = messages.copy()
|
updated_messages = messages.copy()
|
||||||
|
pending_actions: List[Dict] = []
|
||||||
|
|
||||||
for i, call in enumerate(tool_calls):
|
for i, call in enumerate(tool_calls):
|
||||||
# Check context limit before executing tool call
|
# Check context limit before executing tool call
|
||||||
@@ -763,6 +790,29 @@ class LLMHandler(ABC):
|
|||||||
# Set flag on agent
|
# Set flag on agent
|
||||||
agent.context_limit_reached = True
|
agent.context_limit_reached = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# ---- Pause check: approval / client-side execution ----
|
||||||
|
llm_class = agent.llm.__class__.__name__
|
||||||
|
pause_info = agent.tool_executor.check_pause(
|
||||||
|
tools_dict, call, llm_class
|
||||||
|
)
|
||||||
|
if pause_info:
|
||||||
|
# Yield pause event so the client knows this tool is waiting
|
||||||
|
yield {
|
||||||
|
"type": "tool_call",
|
||||||
|
"data": {
|
||||||
|
"tool_name": pause_info["tool_name"],
|
||||||
|
"call_id": pause_info["call_id"],
|
||||||
|
"action_name": pause_info.get("llm_name", pause_info["name"]),
|
||||||
|
"arguments": pause_info["arguments"],
|
||||||
|
"status": pause_info["pause_type"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
pending_actions.append(pause_info)
|
||||||
|
# Do NOT add messages for pending tools here.
|
||||||
|
# They will be added on resume to keep call/result pairs together.
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.tool_calls.append(call)
|
self.tool_calls.append(call)
|
||||||
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
tool_executor_gen = agent._execute_tool_action(tools_dict, call)
|
||||||
@@ -773,24 +823,29 @@ class LLMHandler(ABC):
|
|||||||
tool_response, call_id = e.value
|
tool_response, call_id = e.value
|
||||||
break
|
break
|
||||||
|
|
||||||
function_call_content = {
|
# Standard internal format: assistant message with tool_calls array
|
||||||
"function_call": {
|
args_str = (
|
||||||
"name": call.name,
|
json.dumps(call.arguments)
|
||||||
"args": call.arguments,
|
if isinstance(call.arguments, dict)
|
||||||
"call_id": call_id,
|
else call.arguments
|
||||||
}
|
|
||||||
}
|
|
||||||
# Include thought_signature for Google Gemini 3 models
|
|
||||||
# It should be at the same level as function_call, not inside it
|
|
||||||
if call.thought_signature:
|
|
||||||
function_call_content["thought_signature"] = call.thought_signature
|
|
||||||
updated_messages.append(
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [function_call_content],
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
tool_call_obj = {
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": call.name,
|
||||||
|
"arguments": args_str,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
# Preserve thought_signature for Google Gemini 3 models
|
||||||
|
if call.thought_signature:
|
||||||
|
tool_call_obj["thought_signature"] = call.thought_signature
|
||||||
|
|
||||||
|
updated_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [tool_call_obj],
|
||||||
|
})
|
||||||
|
|
||||||
updated_messages.append(self.create_tool_message(call, tool_response))
|
updated_messages.append(self.create_tool_message(call, tool_response))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -802,16 +857,15 @@ class LLMHandler(ABC):
|
|||||||
error_message = self.create_tool_message(error_call, error_response)
|
error_message = self.create_tool_message(error_call, error_response)
|
||||||
updated_messages.append(error_message)
|
updated_messages.append(error_message)
|
||||||
|
|
||||||
call_parts = call.name.split("_")
|
mapping = agent.tool_executor._name_to_tool
|
||||||
if len(call_parts) >= 2:
|
if call.name in mapping:
|
||||||
tool_id = call_parts[-1] # Last part is tool ID (e.g., "1")
|
resolved_tool_id, _ = mapping[call.name]
|
||||||
action_name = "_".join(call_parts[:-1])
|
tool_name = tools_dict.get(resolved_tool_id, {}).get(
|
||||||
tool_name = tools_dict.get(tool_id, {}).get("name", "unknown_tool")
|
"name", "unknown_tool"
|
||||||
full_action_name = f"{action_name}_{tool_id}"
|
)
|
||||||
else:
|
else:
|
||||||
tool_name = "unknown_tool"
|
tool_name = "unknown_tool"
|
||||||
action_name = call.name
|
full_action_name = call.name
|
||||||
full_action_name = call.name
|
|
||||||
yield {
|
yield {
|
||||||
"type": "tool_call",
|
"type": "tool_call",
|
||||||
"data": {
|
"data": {
|
||||||
@@ -823,7 +877,7 @@ class LLMHandler(ABC):
|
|||||||
"status": "error",
|
"status": "error",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return updated_messages
|
return updated_messages, pending_actions if pending_actions else None
|
||||||
|
|
||||||
def handle_non_streaming(
|
def handle_non_streaming(
|
||||||
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
self, agent, response: Any, tools_dict: Dict, messages: List[Dict]
|
||||||
@@ -851,8 +905,22 @@ class LLMHandler(ABC):
|
|||||||
try:
|
try:
|
||||||
yield next(tool_handler_gen)
|
yield next(tool_handler_gen)
|
||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
messages = e.value
|
messages, pending_actions = e.value
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# If tools need approval or client execution, pause the loop
|
||||||
|
if pending_actions:
|
||||||
|
agent._pending_continuation = {
|
||||||
|
"messages": messages,
|
||||||
|
"pending_tool_calls": pending_actions,
|
||||||
|
"tools_dict": tools_dict,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
"type": "tool_calls_pending",
|
||||||
|
"data": {"pending_tool_calls": pending_actions},
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
|
||||||
response = agent.llm.gen(
|
response = agent.llm.gen(
|
||||||
model=agent.model_id, messages=messages, tools=agent.tools
|
model=agent.model_id, messages=messages, tools=agent.tools
|
||||||
)
|
)
|
||||||
@@ -913,10 +981,23 @@ class LLMHandler(ABC):
|
|||||||
try:
|
try:
|
||||||
yield next(tool_handler_gen)
|
yield next(tool_handler_gen)
|
||||||
except StopIteration as e:
|
except StopIteration as e:
|
||||||
messages = e.value
|
messages, pending_actions = e.value
|
||||||
break
|
break
|
||||||
tool_calls = {}
|
tool_calls = {}
|
||||||
|
|
||||||
|
# If tools need approval or client execution, pause the loop
|
||||||
|
if pending_actions:
|
||||||
|
agent._pending_continuation = {
|
||||||
|
"messages": messages,
|
||||||
|
"pending_tool_calls": pending_actions,
|
||||||
|
"tools_dict": tools_dict,
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
"type": "tool_calls_pending",
|
||||||
|
"data": {"pending_tool_calls": pending_actions},
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
# Check if context limit was reached during tool execution
|
# Check if context limit was reached during tool execution
|
||||||
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
if hasattr(agent, 'context_limit_reached') and agent.context_limit_reached:
|
||||||
# Add system message warning about context limit
|
# Add system message warning about context limit
|
||||||
|
|||||||
@@ -67,18 +67,18 @@ class GoogleLLMHandler(LLMHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
"""Create Google-style tool message."""
|
"""Create a tool result message in the standard internal format."""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
content = (
|
||||||
|
_json.dumps(result)
|
||||||
|
if not isinstance(result, str)
|
||||||
|
else result
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"role": "model",
|
"role": "tool",
|
||||||
"content": [
|
"tool_call_id": tool_call.id,
|
||||||
{
|
"content": content,
|
||||||
"function_response": {
|
|
||||||
"name": tool_call.name,
|
|
||||||
"response": {"result": result},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _iterate_stream(self, response: Any) -> Generator:
|
def _iterate_stream(self, response: Any) -> Generator:
|
||||||
|
|||||||
@@ -37,18 +37,18 @@ class OpenAILLMHandler(LLMHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
def create_tool_message(self, tool_call: ToolCall, result: Any) -> Dict:
|
||||||
"""Create OpenAI-style tool message."""
|
"""Create a tool result message in the standard internal format."""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
content = (
|
||||||
|
_json.dumps(result)
|
||||||
|
if not isinstance(result, str)
|
||||||
|
else result
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"content": [
|
"tool_call_id": tool_call.id,
|
||||||
{
|
"content": content,
|
||||||
"function_response": {
|
|
||||||
"name": tool_call.name,
|
|
||||||
"response": {"result": result},
|
|
||||||
"call_id": tool_call.id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _iterate_stream(self, response: Any) -> Generator:
|
def _iterate_stream(self, response: Any) -> Generator:
|
||||||
|
|||||||
@@ -91,16 +91,52 @@ class OpenAILLM(BaseLLM):
|
|||||||
|
|
||||||
if role == "model":
|
if role == "model":
|
||||||
role = "assistant"
|
role = "assistant"
|
||||||
|
|
||||||
|
# Standard format: assistant message with tool_calls (passthrough)
|
||||||
|
tool_calls = message.get("tool_calls")
|
||||||
|
if tool_calls and role == "assistant":
|
||||||
|
cleaned_tcs = []
|
||||||
|
for tc in tool_calls:
|
||||||
|
func = tc.get("function", {})
|
||||||
|
args = func.get("arguments", "{}")
|
||||||
|
if isinstance(args, dict):
|
||||||
|
args = json.dumps(self._remove_null_values(args))
|
||||||
|
elif isinstance(args, str):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(args)
|
||||||
|
args = json.dumps(self._remove_null_values(parsed))
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
cleaned_tcs.append({
|
||||||
|
"id": tc.get("id", ""),
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": func.get("name", ""), "arguments": args},
|
||||||
|
})
|
||||||
|
cleaned_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": cleaned_tcs,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Standard format: tool message with tool_call_id (passthrough)
|
||||||
|
tool_call_id = message.get("tool_call_id")
|
||||||
|
if role == "tool" and tool_call_id is not None:
|
||||||
|
cleaned_messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call_id,
|
||||||
|
"content": content if isinstance(content, str) else json.dumps(content),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
if role and content is not None:
|
if role and content is not None:
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
cleaned_messages.append({"role": role, "content": content})
|
cleaned_messages.append({"role": role, "content": content})
|
||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
# Collect all content parts into a single message
|
|
||||||
content_parts = []
|
content_parts = []
|
||||||
|
|
||||||
for item in content:
|
for item in content:
|
||||||
|
# Legacy format support: function_call / function_response
|
||||||
if "function_call" in item:
|
if "function_call" in item:
|
||||||
# Function calls need their own message
|
|
||||||
args = item["function_call"]["args"]
|
args = item["function_call"]["args"]
|
||||||
if isinstance(args, str):
|
if isinstance(args, str):
|
||||||
try:
|
try:
|
||||||
@@ -116,28 +152,20 @@ class OpenAILLM(BaseLLM):
|
|||||||
"arguments": json.dumps(cleaned_args),
|
"arguments": json.dumps(cleaned_args),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cleaned_messages.append(
|
cleaned_messages.append({
|
||||||
{
|
"role": "assistant",
|
||||||
"role": "assistant",
|
"content": None,
|
||||||
"content": None,
|
"tool_calls": [tool_call],
|
||||||
"tool_calls": [tool_call],
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
elif "function_response" in item:
|
elif "function_response" in item:
|
||||||
# Function responses need their own message
|
cleaned_messages.append({
|
||||||
cleaned_messages.append(
|
"role": "tool",
|
||||||
{
|
"tool_call_id": item["function_response"]["call_id"],
|
||||||
"role": "tool",
|
"content": json.dumps(
|
||||||
"tool_call_id": item["function_response"][
|
item["function_response"]["response"]["result"]
|
||||||
"call_id"
|
),
|
||||||
],
|
})
|
||||||
"content": json.dumps(
|
|
||||||
item["function_response"]["response"]["result"]
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif isinstance(item, dict):
|
elif isinstance(item, dict):
|
||||||
# Collect content parts (text, images, files) into a single message
|
|
||||||
if "type" in item and item["type"] == "text" and "text" in item:
|
if "type" in item and item["type"] == "text" and "text" in item:
|
||||||
content_parts.append(item)
|
content_parts.append(item)
|
||||||
elif "type" in item and item["type"] == "file" and "file" in item:
|
elif "type" in item and item["type"] == "file" and "file" in item:
|
||||||
@@ -145,10 +173,7 @@ class OpenAILLM(BaseLLM):
|
|||||||
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
|
elif "type" in item and item["type"] == "image_url" and "image_url" in item:
|
||||||
content_parts.append(item)
|
content_parts.append(item)
|
||||||
elif "text" in item and "type" not in item:
|
elif "text" in item and "type" not in item:
|
||||||
# Legacy format: {"text": "..."} without type
|
|
||||||
content_parts.append({"type": "text", "text": item["text"]})
|
content_parts.append({"type": "text", "text": item["text"]})
|
||||||
|
|
||||||
# Add the collected content parts as a single message
|
|
||||||
if content_parts:
|
if content_parts:
|
||||||
cleaned_messages.append({"role": role, "content": content_parts})
|
cleaned_messages.append({"role": role, "content": content_parts})
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -157,5 +157,21 @@ def _log_to_mongodb(
|
|||||||
user_logs_collection.insert_one(log_entry)
|
user_logs_collection.insert_one(log_entry)
|
||||||
logging.debug(f"Logged activity to MongoDB: {activity_id}")
|
logging.debug(f"Logged activity to MongoDB: {activity_id}")
|
||||||
|
|
||||||
|
from application.storage.db.dual_write import dual_write
|
||||||
|
from application.storage.db.repositories.stack_logs import StackLogsRepository
|
||||||
|
|
||||||
|
dual_write(
|
||||||
|
StackLogsRepository,
|
||||||
|
lambda repo, e=log_entry: repo.insert(
|
||||||
|
activity_id=e["id"],
|
||||||
|
endpoint=e.get("endpoint"),
|
||||||
|
level=e.get("level"),
|
||||||
|
user_id=e.get("user"),
|
||||||
|
api_key=e.get("api_key"),
|
||||||
|
query=e.get("query"),
|
||||||
|
stacks=e.get("stacks"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)
|
logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)
|
||||||
|
|||||||
4
application/parser/connectors/confluence/__init__.py
Normal file
4
application/parser/connectors/confluence/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .auth import ConfluenceAuth
|
||||||
|
from .loader import ConfluenceLoader
|
||||||
|
|
||||||
|
__all__ = ["ConfluenceAuth", "ConfluenceLoader"]
|
||||||
216
application/parser/connectors/confluence/auth.py
Normal file
216
application/parser/connectors/confluence/auth.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.parser.connectors.base import BaseConnectorAuth
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfluenceAuth(BaseConnectorAuth):
|
||||||
|
|
||||||
|
SCOPES = [
|
||||||
|
"read:page:confluence",
|
||||||
|
"read:space:confluence",
|
||||||
|
"read:attachment:confluence",
|
||||||
|
"read:me",
|
||||||
|
"offline_access",
|
||||||
|
]
|
||||||
|
|
||||||
|
AUTH_URL = "https://auth.atlassian.com/authorize"
|
||||||
|
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||||
|
RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-resources"
|
||||||
|
ME_URL = "https://api.atlassian.com/me"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client_id = settings.CONFLUENCE_CLIENT_ID
|
||||||
|
self.client_secret = settings.CONFLUENCE_CLIENT_SECRET
|
||||||
|
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
|
||||||
|
|
||||||
|
if not self.client_id or not self.client_secret:
|
||||||
|
raise ValueError(
|
||||||
|
"Confluence OAuth credentials not configured. "
|
||||||
|
"Please set CONFLUENCE_CLIENT_ID and CONFLUENCE_CLIENT_SECRET in settings."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||||
|
params = {
|
||||||
|
"audience": "api.atlassian.com",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"scope": " ".join(self.SCOPES),
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"state": state,
|
||||||
|
"response_type": "code",
|
||||||
|
"prompt": "consent",
|
||||||
|
}
|
||||||
|
return f"{self.AUTH_URL}?{urlencode(params)}"
|
||||||
|
|
||||||
|
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||||
|
if not authorization_code:
|
||||||
|
raise ValueError("Authorization code is required")
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.TOKEN_URL,
|
||||||
|
json={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"code": authorization_code,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
token_data = response.json()
|
||||||
|
|
||||||
|
access_token = token_data.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise ValueError("OAuth flow did not return an access token")
|
||||||
|
|
||||||
|
refresh_token = token_data.get("refresh_token")
|
||||||
|
if not refresh_token:
|
||||||
|
raise ValueError("OAuth flow did not return a refresh token")
|
||||||
|
|
||||||
|
expires_in = token_data.get("expires_in", 3600)
|
||||||
|
expiry = (
|
||||||
|
datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
+ datetime.timedelta(seconds=expires_in)
|
||||||
|
).isoformat()
|
||||||
|
|
||||||
|
cloud_id = self._fetch_cloud_id(access_token)
|
||||||
|
user_info = self._fetch_user_info(access_token)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"token_uri": self.TOKEN_URL,
|
||||||
|
"scopes": self.SCOPES,
|
||||||
|
"expiry": expiry,
|
||||||
|
"cloud_id": cloud_id,
|
||||||
|
"user_info": {
|
||||||
|
"name": user_info.get("display_name", ""),
|
||||||
|
"email": user_info.get("email", ""),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||||
|
if not refresh_token:
|
||||||
|
raise ValueError("Refresh token is required")
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.TOKEN_URL,
|
||||||
|
json={
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
token_data = response.json()
|
||||||
|
|
||||||
|
access_token = token_data.get("access_token")
|
||||||
|
new_refresh_token = token_data.get("refresh_token", refresh_token)
|
||||||
|
|
||||||
|
expires_in = token_data.get("expires_in", 3600)
|
||||||
|
expiry = (
|
||||||
|
datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
+ datetime.timedelta(seconds=expires_in)
|
||||||
|
).isoformat()
|
||||||
|
|
||||||
|
cloud_id = self._fetch_cloud_id(access_token)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": new_refresh_token,
|
||||||
|
"token_uri": self.TOKEN_URL,
|
||||||
|
"scopes": self.SCOPES,
|
||||||
|
"expiry": expiry,
|
||||||
|
"cloud_id": cloud_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||||
|
if not token_info:
|
||||||
|
return True
|
||||||
|
|
||||||
|
expiry = token_info.get("expiry")
|
||||||
|
if not expiry:
|
||||||
|
return bool(token_info.get("access_token"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
expiry_dt = datetime.datetime.fromisoformat(expiry)
|
||||||
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
return now >= expiry_dt - datetime.timedelta(seconds=60)
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings as app_settings
|
||||||
|
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[app_settings.MONGO_DB_NAME]
|
||||||
|
|
||||||
|
session = db["connector_sessions"].find_one({"session_token": session_token})
|
||||||
|
if not session:
|
||||||
|
raise ValueError(f"Invalid session token: {session_token}")
|
||||||
|
|
||||||
|
token_info = session.get("token_info")
|
||||||
|
if not token_info:
|
||||||
|
raise ValueError("Session missing token information")
|
||||||
|
|
||||||
|
required = ["access_token", "refresh_token", "cloud_id"]
|
||||||
|
missing = [f for f in required if not token_info.get(f)]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing required token fields: {missing}")
|
||||||
|
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
def sanitize_token_info(
|
||||||
|
self, token_info: Dict[str, Any], **extra_fields
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return super().sanitize_token_info(
|
||||||
|
token_info,
|
||||||
|
cloud_id=token_info.get("cloud_id"),
|
||||||
|
**extra_fields,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fetch_cloud_id(self, access_token: str) -> str:
|
||||||
|
response = requests.get(
|
||||||
|
self.RESOURCES_URL,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
resources = response.json()
|
||||||
|
|
||||||
|
if not resources:
|
||||||
|
raise ValueError("No accessible Confluence sites found for this account")
|
||||||
|
|
||||||
|
return resources[0]["id"]
|
||||||
|
|
||||||
|
def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
self.ME_URL,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not fetch user info: %s", e)
|
||||||
|
return {}
|
||||||
416
application/parser/connectors/confluence/loader.py
Normal file
416
application/parser/connectors/confluence/loader.py
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from application.parser.connectors.base import BaseConnectorLoader
|
||||||
|
from application.parser.connectors.confluence.auth import ConfluenceAuth
|
||||||
|
from application.parser.schema.base import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
API_V2 = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki/api/v2"
|
||||||
|
DOWNLOAD_BASE = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki"
|
||||||
|
|
||||||
|
SUPPORTED_ATTACHMENT_TYPES = {
|
||||||
|
"application/pdf": ".pdf",
|
||||||
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||||
|
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||||
|
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||||
|
"application/msword": ".doc",
|
||||||
|
"application/vnd.ms-powerpoint": ".ppt",
|
||||||
|
"application/vnd.ms-excel": ".xls",
|
||||||
|
"text/plain": ".txt",
|
||||||
|
"text/csv": ".csv",
|
||||||
|
"text/html": ".html",
|
||||||
|
"text/markdown": ".md",
|
||||||
|
"application/json": ".json",
|
||||||
|
"application/epub+zip": ".epub",
|
||||||
|
"image/jpeg": ".jpg",
|
||||||
|
"image/png": ".png",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_on_auth_failure(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
if e.response is not None and e.response.status_code in (401, 403):
|
||||||
|
logger.info(
|
||||||
|
"Auth failure in %s, refreshing token and retrying", func.__name__
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||||
|
self.access_token = new_token_info["access_token"]
|
||||||
|
self.refresh_token = new_token_info.get(
|
||||||
|
"refresh_token", self.refresh_token
|
||||||
|
)
|
||||||
|
self._persist_refreshed_tokens(new_token_info)
|
||||||
|
except Exception as refresh_err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Authentication failed and could not be refreshed: {refresh_err}"
|
||||||
|
) from e
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class ConfluenceLoader(BaseConnectorLoader):
|
||||||
|
|
||||||
|
def __init__(self, session_token: str):
|
||||||
|
self.auth = ConfluenceAuth()
|
||||||
|
self.session_token = session_token
|
||||||
|
|
||||||
|
token_info = self.auth.get_token_info_from_session(session_token)
|
||||||
|
self.access_token = token_info["access_token"]
|
||||||
|
self.refresh_token = token_info["refresh_token"]
|
||||||
|
self.cloud_id = token_info["cloud_id"]
|
||||||
|
|
||||||
|
self.base_url = API_V2.format(cloud_id=self.cloud_id)
|
||||||
|
self.download_base = DOWNLOAD_BASE.format(cloud_id=self.cloud_id)
|
||||||
|
self.next_page_token = None
|
||||||
|
|
||||||
|
def _headers(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {self.access_token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _persist_refreshed_tokens(self, token_info: Dict[str, Any]) -> None:
|
||||||
|
try:
|
||||||
|
from application.core.mongo_db import MongoDB
|
||||||
|
from application.core.settings import settings as app_settings
|
||||||
|
|
||||||
|
sanitized = self.auth.sanitize_token_info(token_info)
|
||||||
|
mongo = MongoDB.get_client()
|
||||||
|
db = mongo[app_settings.MONGO_DB_NAME]
|
||||||
|
db["connector_sessions"].update_one(
|
||||||
|
{"session_token": self.session_token},
|
||||||
|
{"$set": {"token_info": sanitized}},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to persist refreshed tokens: %s", e)
|
||||||
|
|
||||||
|
@_retry_on_auth_failure
|
||||||
|
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
folder_id = inputs.get("folder_id")
|
||||||
|
file_ids = inputs.get("file_ids", [])
|
||||||
|
limit = inputs.get("limit", 100)
|
||||||
|
list_only = inputs.get("list_only", False)
|
||||||
|
page_token = inputs.get("page_token")
|
||||||
|
search_query = inputs.get("search_query")
|
||||||
|
self.next_page_token = None
|
||||||
|
|
||||||
|
if file_ids:
|
||||||
|
return self._load_pages_by_ids(file_ids, list_only, search_query)
|
||||||
|
|
||||||
|
if folder_id:
|
||||||
|
return self._list_pages_in_space(
|
||||||
|
folder_id, limit, list_only, page_token, search_query
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._list_spaces(limit, page_token, search_query)
|
||||||
|
|
||||||
|
@_retry_on_auth_failure
|
||||||
|
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
|
||||||
|
config = source_config or getattr(self, "config", {})
|
||||||
|
file_ids = config.get("file_ids", [])
|
||||||
|
folder_ids = config.get("folder_ids", [])
|
||||||
|
files_downloaded = 0
|
||||||
|
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if isinstance(file_ids, str):
|
||||||
|
file_ids = [file_ids]
|
||||||
|
if isinstance(folder_ids, str):
|
||||||
|
folder_ids = [folder_ids]
|
||||||
|
|
||||||
|
for page_id in file_ids:
|
||||||
|
if self._download_page(page_id, local_dir):
|
||||||
|
files_downloaded += 1
|
||||||
|
files_downloaded += self._download_page_attachments(page_id, local_dir)
|
||||||
|
|
||||||
|
for space_id in folder_ids:
|
||||||
|
files_downloaded += self._download_space(space_id, local_dir)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"files_downloaded": files_downloaded,
|
||||||
|
"directory_path": local_dir,
|
||||||
|
"empty_result": files_downloaded == 0,
|
||||||
|
"source_type": "confluence",
|
||||||
|
"config_used": config,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _list_spaces(
|
||||||
|
self, limit: int, cursor: Optional[str], search_query: Optional[str]
|
||||||
|
) -> List[Document]:
|
||||||
|
documents: List[Document] = []
|
||||||
|
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/spaces",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for space in data.get("results", []):
|
||||||
|
name = space.get("name", "")
|
||||||
|
if search_query and search_query.lower() not in name.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
documents.append(
|
||||||
|
Document(
|
||||||
|
text="",
|
||||||
|
doc_id=space["id"],
|
||||||
|
extra_info={
|
||||||
|
"file_name": name,
|
||||||
|
"mime_type": "folder",
|
||||||
|
"size": None,
|
||||||
|
"created_time": space.get("createdAt"),
|
||||||
|
"modified_time": None,
|
||||||
|
"source": "confluence",
|
||||||
|
"is_folder": True,
|
||||||
|
"space_key": space.get("key"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
self.next_page_token = self._extract_cursor(next_link)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def _list_pages_in_space(
|
||||||
|
self,
|
||||||
|
space_id: str,
|
||||||
|
limit: int,
|
||||||
|
list_only: bool,
|
||||||
|
cursor: Optional[str],
|
||||||
|
search_query: Optional[str],
|
||||||
|
) -> List[Document]:
|
||||||
|
documents: List[Document] = []
|
||||||
|
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/spaces/{space_id}/pages",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for page in data.get("results", []):
|
||||||
|
title = page.get("title", "")
|
||||||
|
if search_query and search_query.lower() not in title.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc = self._page_to_document(
|
||||||
|
page, load_content=not list_only, space_id=space_id
|
||||||
|
)
|
||||||
|
if doc:
|
||||||
|
documents.append(doc)
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
self.next_page_token = self._extract_cursor(next_link)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def _load_pages_by_ids(
|
||||||
|
self, page_ids: List[str], list_only: bool, search_query: Optional[str]
|
||||||
|
) -> List[Document]:
|
||||||
|
documents: List[Document] = []
|
||||||
|
for page_id in page_ids:
|
||||||
|
try:
|
||||||
|
params: Dict[str, str] = {}
|
||||||
|
if not list_only:
|
||||||
|
params["body-format"] = "storage"
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/pages/{page_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
page = response.json()
|
||||||
|
|
||||||
|
title = page.get("title", "")
|
||||||
|
if search_query and search_query.lower() not in title.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc = self._page_to_document(page, load_content=not list_only)
|
||||||
|
if doc:
|
||||||
|
documents.append(doc)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error loading page %s: %s", page_id, e)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def _page_to_document(
|
||||||
|
self,
|
||||||
|
page: Dict[str, Any],
|
||||||
|
load_content: bool = False,
|
||||||
|
space_id: Optional[str] = None,
|
||||||
|
) -> Optional[Document]:
|
||||||
|
page_id = page.get("id")
|
||||||
|
title = page.get("title", "Unknown")
|
||||||
|
version = page.get("version", {})
|
||||||
|
modified_time = version.get("createdAt") if isinstance(version, dict) else None
|
||||||
|
created_time = page.get("createdAt")
|
||||||
|
resolved_space_id = space_id or page.get("spaceId")
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
if load_content:
|
||||||
|
body = page.get("body", {})
|
||||||
|
storage = body.get("storage", {}) if isinstance(body, dict) else {}
|
||||||
|
text = storage.get("value", "") if isinstance(storage, dict) else ""
|
||||||
|
|
||||||
|
return Document(
|
||||||
|
text=text,
|
||||||
|
doc_id=str(page_id),
|
||||||
|
extra_info={
|
||||||
|
"file_name": title,
|
||||||
|
"mime_type": "text/html",
|
||||||
|
"size": len(text) if text else None,
|
||||||
|
"created_time": created_time,
|
||||||
|
"modified_time": modified_time,
|
||||||
|
"source": "confluence",
|
||||||
|
"is_folder": False,
|
||||||
|
"page_id": str(page_id),
|
||||||
|
"space_id": resolved_space_id,
|
||||||
|
"cloud_id": self.cloud_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _download_page(self, page_id: str, local_dir: str) -> bool:
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/pages/{page_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
params={"body-format": "storage"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
page = response.json()
|
||||||
|
|
||||||
|
title = page.get("title", page_id)
|
||||||
|
safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in title)
|
||||||
|
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||||
|
|
||||||
|
file_path = os.path.join(local_dir, f"{safe_name}.html")
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(body)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error downloading page %s: %s", page_id, e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _download_page_attachments(self, page_id: str, local_dir: str) -> int:
|
||||||
|
downloaded = 0
|
||||||
|
try:
|
||||||
|
cursor = None
|
||||||
|
while True:
|
||||||
|
params: Dict[str, Any] = {"limit": 100}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/pages/{page_id}/attachments",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for att in data.get("results", []):
|
||||||
|
media_type = att.get("mediaType", "")
|
||||||
|
if media_type not in SUPPORTED_ATTACHMENT_TYPES:
|
||||||
|
continue
|
||||||
|
|
||||||
|
download_link = att.get("_links", {}).get("download")
|
||||||
|
if not download_link:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_name = att.get("title", att.get("id", "attachment"))
|
||||||
|
file_name = "".join(
|
||||||
|
c if c.isalnum() or c in " -_." else "_"
|
||||||
|
for c in os.path.basename(raw_name)
|
||||||
|
) or "attachment"
|
||||||
|
file_path = os.path.join(local_dir, file_name)
|
||||||
|
|
||||||
|
url = f"{self.download_base}{download_link}"
|
||||||
|
file_resp = requests.get(
|
||||||
|
url, headers=self._headers(), timeout=60, stream=True
|
||||||
|
)
|
||||||
|
file_resp.raise_for_status()
|
||||||
|
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
for chunk in file_resp.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
downloaded += 1
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
cursor = self._extract_cursor(next_link)
|
||||||
|
if not cursor:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error downloading attachments for page %s: %s", page_id, e)
|
||||||
|
return downloaded
|
||||||
|
|
||||||
|
def _download_space(self, space_id: str, local_dir: str) -> int:
|
||||||
|
downloaded = 0
|
||||||
|
cursor = None
|
||||||
|
while True:
|
||||||
|
params: Dict[str, Any] = {"limit": 250}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/spaces/{space_id}/pages",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error listing pages in space %s: %s", space_id, e)
|
||||||
|
break
|
||||||
|
|
||||||
|
for page in data.get("results", []):
|
||||||
|
page_id = page.get("id")
|
||||||
|
if self._download_page(str(page_id), local_dir):
|
||||||
|
downloaded += 1
|
||||||
|
downloaded += self._download_page_attachments(str(page_id), local_dir)
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
cursor = self._extract_cursor(next_link)
|
||||||
|
if not cursor:
|
||||||
|
break
|
||||||
|
|
||||||
|
return downloaded
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_cursor(next_link: Optional[str]) -> Optional[str]:
|
||||||
|
if not next_link:
|
||||||
|
return None
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
parsed = urlparse(next_link)
|
||||||
|
cursors = parse_qs(parsed.query).get("cursor")
|
||||||
|
return cursors[0] if cursors else None
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
from application.parser.connectors.confluence.auth import ConfluenceAuth
|
||||||
|
from application.parser.connectors.confluence.loader import ConfluenceLoader
|
||||||
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
from application.parser.connectors.google_drive.auth import GoogleDriveAuth
|
||||||
|
from application.parser.connectors.google_drive.loader import GoogleDriveLoader
|
||||||
from application.parser.connectors.share_point.auth import SharePointAuth
|
from application.parser.connectors.share_point.auth import SharePointAuth
|
||||||
from application.parser.connectors.share_point.loader import SharePointLoader
|
from application.parser.connectors.share_point.loader import SharePointLoader
|
||||||
|
|
||||||
@@ -13,11 +15,13 @@ class ConnectorCreator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
connectors = {
|
connectors = {
|
||||||
|
"confluence": ConfluenceLoader,
|
||||||
"google_drive": GoogleDriveLoader,
|
"google_drive": GoogleDriveLoader,
|
||||||
"share_point": SharePointLoader,
|
"share_point": SharePointLoader,
|
||||||
}
|
}
|
||||||
|
|
||||||
auth_providers = {
|
auth_providers = {
|
||||||
|
"confluence": ConfluenceAuth,
|
||||||
"google_drive": GoogleDriveAuth,
|
"google_drive": GoogleDriveAuth,
|
||||||
"share_point": SharePointAuth,
|
"share_point": SharePointAuth,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
try:
|
try:
|
||||||
url = self._get_item_url(file_id)
|
url = self._get_item_url(file_id)
|
||||||
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
|
params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
|
||||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
file_metadata = response.json()
|
file_metadata = response.json()
|
||||||
@@ -236,9 +236,9 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
|
search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
|
||||||
else:
|
else:
|
||||||
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
|
search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
|
||||||
response = requests.get(search_url, headers=self._get_headers(), params=params)
|
response = requests.get(search_url, headers=self._get_headers(), params=params, timeout=100)
|
||||||
else:
|
else:
|
||||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
@@ -307,7 +307,8 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{self.GRAPH_API_BASE}/me/drive",
|
f"{self.GRAPH_API_BASE}/me/drive",
|
||||||
headers=self._get_headers(),
|
headers=self._get_headers(),
|
||||||
params={'$select': 'webUrl'}
|
params={'$select': 'webUrl'},
|
||||||
|
timeout=100,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json().get('webUrl')
|
return response.json().get('webUrl')
|
||||||
@@ -352,7 +353,7 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
|
|
||||||
headers = self._get_headers()
|
headers = self._get_headers()
|
||||||
headers["Content-Type"] = "application/json"
|
headers["Content-Type"] = "application/json"
|
||||||
response = requests.post(url, headers=headers, json=body)
|
response = requests.post(url, headers=headers, json=body, timeout=100)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
results = response.json()
|
results = response.json()
|
||||||
|
|
||||||
@@ -472,7 +473,7 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
url = f"{self._get_item_url(file_id)}/content"
|
url = f"{self._get_item_url(file_id)}/content"
|
||||||
response = requests.get(url, headers=self._get_headers())
|
response = requests.get(url, headers=self._get_headers(), timeout=100)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -491,7 +492,7 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
try:
|
try:
|
||||||
url = self._get_item_url(file_id)
|
url = self._get_item_url(file_id)
|
||||||
params = {'$select': 'id,name,file'}
|
params = {'$select': 'id,name,file'}
|
||||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
metadata = response.json()
|
metadata = response.json()
|
||||||
@@ -507,7 +508,7 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
full_path = os.path.join(local_dir, file_name)
|
full_path = os.path.join(local_dir, file_name)
|
||||||
|
|
||||||
download_url = f"{self._get_item_url(file_id)}/content"
|
download_url = f"{self._get_item_url(file_id)}/content"
|
||||||
download_response = requests.get(download_url, headers=self._get_headers())
|
download_response = requests.get(download_url, headers=self._get_headers(), timeout=100)
|
||||||
download_response.raise_for_status()
|
download_response.raise_for_status()
|
||||||
|
|
||||||
with open(full_path, 'wb') as f:
|
with open(full_path, 'wb') as f:
|
||||||
@@ -527,7 +528,7 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
params = {'$top': 1000}
|
params = {'$top': 1000}
|
||||||
|
|
||||||
while url:
|
while url:
|
||||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
results = response.json()
|
results = response.json()
|
||||||
@@ -609,7 +610,7 @@ class SharePointLoader(BaseConnectorLoader):
|
|||||||
try:
|
try:
|
||||||
url = self._get_item_url(folder_id)
|
url = self._get_item_url(folder_id)
|
||||||
params = {'$select': 'id,name'}
|
params = {'$select': 'id,name'}
|
||||||
response = requests.get(url, headers=self._get_headers(), params=params)
|
response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
folder_metadata = response.json()
|
folder_metadata = response.json()
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class PDFParser(BaseParser):
|
|||||||
# alternatively you can use local vision capable LLM
|
# alternatively you can use local vision capable LLM
|
||||||
with open(file, "rb") as file_loaded:
|
with open(file, "rb") as file_loaded:
|
||||||
files = {'file': file_loaded}
|
files = {'file': file_loaded}
|
||||||
response = requests.post(doc2md_service, files=files)
|
response = requests.post(doc2md_service, files=files, timeout=100)
|
||||||
data = response.json()["markdown"]
|
data = response.json()["markdown"]
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
|
|||||||
def parse_file(self, file: Path, errors: str = "ignore") -> str:
|
def parse_file(self, file: Path, errors: str = "ignore") -> str:
|
||||||
"""Parse file."""
|
"""Parse file."""
|
||||||
try:
|
try:
|
||||||
import ebooklib
|
from fast_ebook import epub
|
||||||
from ebooklib import epub
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("`EbookLib` is required to read Epub files.")
|
raise ValueError("`fast-ebook` is required to read Epub files.")
|
||||||
try:
|
|
||||||
import html2text
|
|
||||||
except ImportError:
|
|
||||||
raise ValueError("`html2text` is required to parse Epub files.")
|
|
||||||
|
|
||||||
text_list = []
|
book = epub.read_epub(file)
|
||||||
book = epub.read_epub(file, options={"ignore_ncx": True})
|
text = book.to_markdown()
|
||||||
|
|
||||||
# Iterate through all chapters.
|
|
||||||
for item in book.get_items():
|
|
||||||
# Chapters are typically located in epub documents items.
|
|
||||||
if item.get_type() == ebooklib.ITEM_DOCUMENT:
|
|
||||||
text_list.append(
|
|
||||||
html2text.html2text(item.get_content().decode("utf-8"))
|
|
||||||
)
|
|
||||||
|
|
||||||
text = "\n".join(text_list)
|
|
||||||
return text
|
return text
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ class ImageParser(BaseParser):
|
|||||||
# alternatively you can use local vision capable LLM
|
# alternatively you can use local vision capable LLM
|
||||||
with open(file, "rb") as file_loaded:
|
with open(file, "rb") as file_loaded:
|
||||||
files = {'file': file_loaded}
|
files = {'file': file_loaded}
|
||||||
response = requests.post(doc2md_service, files=files)
|
response = requests.post(doc2md_service, files=files, timeout=100)
|
||||||
data = response.json()["markdown"]
|
data = response.json()["markdown"]
|
||||||
else:
|
else:
|
||||||
data = ""
|
data = ""
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ class GitHubLoader(BaseRemote):
|
|||||||
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
|
def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
|
||||||
"""Make a request with retry logic for rate limiting"""
|
"""Make a request with retry logic for rate limiting"""
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
response = requests.get(url, headers=self.headers)
|
response = requests.get(url, headers=self.headers, timeout=100)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
anthropic==0.75.0
|
alembic>=1.13,<2
|
||||||
boto3==1.42.17
|
anthropic==0.88.0
|
||||||
|
boto3==1.42.83
|
||||||
beautifulsoup4==4.14.3
|
beautifulsoup4==4.14.3
|
||||||
cel-python==0.5.0
|
cel-python==0.5.0
|
||||||
celery==5.6.0
|
celery==5.6.3
|
||||||
cryptography==46.0.3
|
cryptography==46.0.7
|
||||||
dataclasses-json==0.6.7
|
dataclasses-json==0.6.7
|
||||||
defusedxml==0.7.1
|
defusedxml==0.7.1
|
||||||
docling>=2.16.0
|
docling>=2.16.0
|
||||||
@@ -11,89 +12,84 @@ rapidocr>=1.4.0
|
|||||||
onnxruntime>=1.19.0
|
onnxruntime>=1.19.0
|
||||||
docx2txt==0.9
|
docx2txt==0.9
|
||||||
ddgs>=8.0.0
|
ddgs>=8.0.0
|
||||||
ebooklib==0.20
|
fast-ebook
|
||||||
escodegen==1.0.11
|
elevenlabs==2.41.0
|
||||||
esprima==4.0.1
|
Flask==3.1.3
|
||||||
esutils==1.0.1
|
|
||||||
elevenlabs==2.27.0
|
|
||||||
Flask==3.1.2
|
|
||||||
faiss-cpu==1.13.2
|
faiss-cpu==1.13.2
|
||||||
fastmcp==2.14.1
|
fastmcp==3.2.0
|
||||||
flask-restx==1.3.2
|
flask-restx==1.3.2
|
||||||
google-genai==1.54.0
|
google-genai==1.69.0
|
||||||
google-api-python-client==2.187.0
|
google-api-python-client==2.193.0
|
||||||
google-auth-httplib2==0.3.0
|
google-auth-httplib2==0.3.1
|
||||||
google-auth-oauthlib==1.2.3
|
google-auth-oauthlib==1.3.1
|
||||||
gTTS==2.5.4
|
gTTS==2.5.4
|
||||||
gunicorn==23.0.0
|
gunicorn==25.3.0
|
||||||
html2text==2025.4.15
|
|
||||||
javalang==0.13.0
|
|
||||||
jinja2==3.1.6
|
jinja2==3.1.6
|
||||||
jiter==0.12.0
|
jiter==0.13.0
|
||||||
jmespath==1.0.1
|
jmespath==1.1.0
|
||||||
joblib==1.5.3
|
joblib==1.5.3
|
||||||
jsonpatch==1.33
|
jsonpatch==1.33
|
||||||
jsonpointer==3.0.0
|
jsonpointer==3.1.1
|
||||||
kombu==5.6.1
|
kombu==5.6.2
|
||||||
langchain==1.2.0
|
langchain==1.2.3
|
||||||
langchain-community==0.4.1
|
langchain-community==0.4.1
|
||||||
langchain-core==1.2.5
|
langchain-core==1.2.29
|
||||||
langchain-openai==1.1.6
|
langchain-openai==1.1.12
|
||||||
langchain-text-splitters==1.1.0
|
langchain-text-splitters==1.1.1
|
||||||
langsmith==0.5.1
|
langsmith==0.7.31
|
||||||
lazy-object-proxy==1.12.0
|
lazy-object-proxy==1.12.0
|
||||||
lxml==6.0.2
|
lxml==6.0.2
|
||||||
markupsafe==3.0.3
|
markupsafe==3.0.3
|
||||||
marshmallow>=3.18.0,<5.0.0
|
marshmallow>=3.18.0,<5.0.0
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
multidict==6.7.0
|
multidict==6.7.1
|
||||||
msal==1.34.0
|
msal==1.35.1
|
||||||
mypy-extensions==1.1.0
|
mypy-extensions==1.1.0
|
||||||
networkx==3.6.1
|
networkx==3.6.1
|
||||||
numpy==2.4.0
|
numpy==2.4.4
|
||||||
openai==2.14.0
|
openai==2.30.0
|
||||||
openapi3-parser==1.1.22
|
openapi3-parser==1.1.22
|
||||||
orjson==3.11.5
|
orjson==3.11.7
|
||||||
packaging==24.2
|
packaging==26.0
|
||||||
pandas==2.3.3
|
pandas==3.0.2
|
||||||
openpyxl==3.1.5
|
openpyxl==3.1.5
|
||||||
pathable==0.4.4
|
pathable==0.5.0
|
||||||
pdf2image>=1.17.0
|
pdf2image>=1.17.0
|
||||||
pillow
|
pillow
|
||||||
portalocker>=2.7.0,<3.0.0
|
portalocker>=2.7.0,<4.0.0
|
||||||
prance==25.4.8.0
|
|
||||||
prompt-toolkit==3.0.52
|
prompt-toolkit==3.0.52
|
||||||
protobuf==6.33.2
|
protobuf==7.34.1
|
||||||
psycopg2-binary==2.9.11
|
psycopg[binary,pool]>=3.1,<4
|
||||||
py==1.11.0
|
py==1.11.0
|
||||||
pydantic
|
pydantic
|
||||||
pydantic-core
|
pydantic-core
|
||||||
pydantic-settings
|
pydantic-settings
|
||||||
pymongo==4.15.5
|
pymongo==4.16.0
|
||||||
pypdf==6.5.0
|
pypdf==6.9.2
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
python-dotenv
|
python-dotenv
|
||||||
python-jose==3.5.0
|
python-jose==3.5.0
|
||||||
python-pptx==1.0.2
|
python-pptx==1.0.2
|
||||||
redis==7.1.0
|
redis==7.4.0
|
||||||
referencing>=0.28.0,<0.38.0
|
referencing>=0.28.0,<0.38.0
|
||||||
regex==2025.11.3
|
regex==2026.4.4
|
||||||
requests==2.32.5
|
requests==2.33.1
|
||||||
retry==0.9.2
|
retry==0.9.2
|
||||||
sentence-transformers==5.2.0
|
sentence-transformers==5.3.0
|
||||||
|
sqlalchemy>=2.0,<3
|
||||||
tiktoken==0.12.0
|
tiktoken==0.12.0
|
||||||
tokenizers==0.22.1
|
tokenizers==0.22.2
|
||||||
torch==2.9.1
|
torch==2.11.0
|
||||||
tqdm==4.67.1
|
tqdm==4.67.3
|
||||||
transformers==4.57.3
|
transformers==5.4.0
|
||||||
typing-extensions==4.15.0
|
typing-extensions==4.15.0
|
||||||
typing-inspect==0.9.0
|
typing-inspect==0.9.0
|
||||||
tzdata==2025.3
|
tzdata==2026.1
|
||||||
urllib3==2.6.3
|
urllib3==2.6.3
|
||||||
vine==5.1.0
|
vine==5.1.0
|
||||||
wcwidth==0.2.14
|
wcwidth==0.6.0
|
||||||
werkzeug>=3.1.0
|
werkzeug>=3.1.0
|
||||||
yarl==1.22.0
|
yarl==1.23.0
|
||||||
markdownify==1.2.2
|
markdownify==1.2.2
|
||||||
tldextract==5.3.0
|
tldextract==5.3.1
|
||||||
websockets==15.0.1
|
websockets==16.0
|
||||||
10
application/storage/db/__init__.py
Normal file
10
application/storage/db/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""PostgreSQL storage layer for user-level data.
|
||||||
|
|
||||||
|
This package holds the SQLAlchemy Core engine, metadata, repositories, and
|
||||||
|
migration infrastructure for the user-data Postgres database. It is separate
|
||||||
|
from ``application/vectorstore/pgvector.py`` — the two may point at the same
|
||||||
|
cluster or at different clusters depending on operator configuration.
|
||||||
|
|
||||||
|
Repository modules are added in later phases
|
||||||
|
as individual collections are ported.
|
||||||
|
"""
|
||||||
39
application/storage/db/base_repository.py
Normal file
39
application/storage/db/base_repository.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Common helpers shared by all repositories.
|
||||||
|
|
||||||
|
Repositories are thin wrappers around SQLAlchemy Core query construction.
|
||||||
|
They take a ``Connection`` on call and return plain ``dict`` rows during the
|
||||||
|
Mongo→Postgres cutover so that call sites don't have to change shape. Once
|
||||||
|
cutover is complete, a follow-up phase may migrate repo return types to
|
||||||
|
Pydantic DTOs (tracked in the migration plan as a post-migration item).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Mapping
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
|
||||||
|
def row_to_dict(row: Any) -> dict:
|
||||||
|
"""Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.
|
||||||
|
|
||||||
|
During the migration window, API responses and downstream code still
|
||||||
|
expect a string ``_id`` field (matching the Mongo shape). This helper
|
||||||
|
normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
|
||||||
|
existing serializers keep working unchanged.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
row: A SQLAlchemy ``Row`` object, or ``None``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A plain dict, or an empty dict if ``row`` is ``None``.
|
||||||
|
"""
|
||||||
|
if row is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Row has a ``._mapping`` attribute exposing a MappingProxy view.
|
||||||
|
mapping: Mapping[str, Any] = row._mapping # type: ignore[attr-defined]
|
||||||
|
out = dict(mapping)
|
||||||
|
|
||||||
|
if "id" in out and out["id"] is not None:
|
||||||
|
out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
|
||||||
|
out["_id"] = out["id"]
|
||||||
|
|
||||||
|
return out
|
||||||
67
application/storage/db/dual_write.py
Normal file
67
application/storage/db/dual_write.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""Best-effort Postgres dual-write helper used during the MongoDB→Postgres
|
||||||
|
migration.
|
||||||
|
|
||||||
|
The helper:
|
||||||
|
|
||||||
|
* Returns immediately if ``settings.USE_POSTGRES`` is off, so default-off
|
||||||
|
call sites add literally zero work.
|
||||||
|
* Opens a transactional connection from the user-data SQLAlchemy engine.
|
||||||
|
* Instantiates the caller's repository class on that connection.
|
||||||
|
* Runs the caller's operation.
|
||||||
|
* Swallows and logs any exception. **Mongo remains the source of truth
|
||||||
|
during the dual-write window** — a Postgres-side failure must never
|
||||||
|
break a user-facing request. Drift that builds up from swallowed
|
||||||
|
failures is caught separately by re-running the backfill script.
|
||||||
|
|
||||||
|
Call sites look like::
|
||||||
|
|
||||||
|
users_collection.update_one(..., {"$addToSet": {...}}) # Mongo write, unchanged
|
||||||
|
dual_write(UsersRepository, lambda r: r.add_pinned(uid, aid)) # Postgres mirror
|
||||||
|
|
||||||
|
A single parameterised helper rather than one function per collection
|
||||||
|
means a new collection just needs its repository class — no new helper
|
||||||
|
function, no new feature flag. The whole helper is deleted at Phase 5
|
||||||
|
when the migration is complete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_Repo = TypeVar("_Repo")
|
||||||
|
|
||||||
|
|
||||||
|
def dual_write(repo_cls: type[_Repo], fn: Callable[[_Repo], None]) -> None:
|
||||||
|
"""Mirror a Mongo write into Postgres via ``repo_cls``, best-effort.
|
||||||
|
|
||||||
|
No-op when ``settings.USE_POSTGRES`` is false. Any exception
|
||||||
|
(connection pool exhaustion, migration drift, SQL error) is logged
|
||||||
|
and swallowed so the caller's primary Mongo write remains the source
|
||||||
|
of truth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_cls: The repository class to instantiate (e.g. ``UsersRepository``).
|
||||||
|
fn: A callable that takes the instantiated repository and performs
|
||||||
|
the desired write.
|
||||||
|
"""
|
||||||
|
if not settings.USE_POSTGRES:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Lazy import so modules that import dual_write don't pay the
|
||||||
|
# SQLAlchemy import cost when the flag is off.
|
||||||
|
from application.storage.db.engine import get_engine
|
||||||
|
|
||||||
|
with get_engine().begin() as conn:
|
||||||
|
fn(repo_cls(conn))
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Postgres dual-write failed for %s — Mongo write already committed",
|
||||||
|
repo_cls.__name__,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
73
application/storage/db/engine.py
Normal file
73
application/storage/db/engine.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""SQLAlchemy Core engine factory for the user-data Postgres database.
|
||||||
|
|
||||||
|
The engine is lazily constructed on first use and cached as a module-level
|
||||||
|
singleton. Repositories and the Alembic env module both obtain connections
|
||||||
|
through this factory, so pool tuning lives in one place.
|
||||||
|
|
||||||
|
``POSTGRES_URI`` can be written in any of the common Postgres URI forms::
|
||||||
|
|
||||||
|
postgres://user:pass@host:5432/docsgpt
|
||||||
|
postgresql://user:pass@host:5432/docsgpt
|
||||||
|
|
||||||
|
Both are accepted and normalized internally to the psycopg3 dialect
|
||||||
|
(``postgresql+psycopg://``) by ``application.core.settings``. Operators
|
||||||
|
don't need to know about SQLAlchemy dialect prefixes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Engine, create_engine
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
|
||||||
|
_engine: Optional[Engine] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_uri() -> str:
|
||||||
|
"""Return the Postgres URI for user-data tables.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If ``settings.POSTGRES_URI`` is unset. Callers that
|
||||||
|
reach this path without a configured URI have a setup bug — the
|
||||||
|
error message points them at the right setting.
|
||||||
|
"""
|
||||||
|
if not settings.POSTGRES_URI:
|
||||||
|
raise RuntimeError(
|
||||||
|
"POSTGRES_URI is not configured. Set it in your .env to a "
|
||||||
|
"psycopg3 URI such as "
|
||||||
|
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||||
|
)
|
||||||
|
return settings.POSTGRES_URI
|
||||||
|
|
||||||
|
|
||||||
|
def get_engine() -> Engine:
|
||||||
|
"""Return the process-wide SQLAlchemy Engine, creating it if needed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A SQLAlchemy ``Engine`` configured with a pooled connection to
|
||||||
|
Postgres via psycopg3.
|
||||||
|
"""
|
||||||
|
global _engine
|
||||||
|
if _engine is None:
|
||||||
|
_engine = create_engine(
|
||||||
|
_resolve_uri(),
|
||||||
|
pool_size=10,
|
||||||
|
max_overflow=20,
|
||||||
|
pool_pre_ping=True, # survive PgBouncer / idle-disconnect recycles
|
||||||
|
pool_recycle=1800,
|
||||||
|
future=True,
|
||||||
|
)
|
||||||
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
|
def dispose_engine() -> None:
|
||||||
|
"""Dispose the pooled connections and reset the singleton.
|
||||||
|
|
||||||
|
Called from the Celery ``worker_process_init`` signal so each forked
|
||||||
|
worker gets a fresh pool instead of sharing file descriptors with the
|
||||||
|
parent process (which corrupts the pool on fork).
|
||||||
|
"""
|
||||||
|
global _engine
|
||||||
|
if _engine is not None:
|
||||||
|
_engine.dispose()
|
||||||
|
_engine = None
|
||||||
396
application/storage/db/models.py
Normal file
396
application/storage/db/models.py
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
"""SQLAlchemy Core metadata for the user-data Postgres database.
|
||||||
|
|
||||||
|
Tables are added here one at a time as repositories are built during the
|
||||||
|
MongoDB→Postgres migration. The baseline schema in the Alembic migration
|
||||||
|
(``application/alembic/versions/0001_initial.py``) is the source of truth
|
||||||
|
for DDL; the ``Table`` definitions below must match it column-for-column.
|
||||||
|
If the two drift, migrations win — update this file to match.
|
||||||
|
|
||||||
|
Cross-table invariant not expressed in the Core ``Table`` definitions
|
||||||
|
below: every ``user_id`` column is FK-enforced against
|
||||||
|
``users(user_id)`` with ``ON DELETE RESTRICT``, and a
|
||||||
|
``BEFORE INSERT OR UPDATE OF user_id`` trigger on each child table
|
||||||
|
auto-creates the ``users`` row if it does not yet exist. See migration
|
||||||
|
``0015_user_id_fk``. The FKs are intentionally omitted from the Core
|
||||||
|
declarations to keep this file readable; the DB is the authority.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
ForeignKey,
|
||||||
|
ForeignKeyConstraint,
|
||||||
|
Integer,
|
||||||
|
MetaData,
|
||||||
|
UniqueConstraint,
|
||||||
|
Table,
|
||||||
|
Text,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID
|
||||||
|
|
||||||
|
metadata = MetaData()
|
||||||
|
|
||||||
|
|
||||||
|
# --- Phase 1, Tier 1 --------------------------------------------------------
|
||||||
|
|
||||||
|
users_table = Table(
|
||||||
|
"users",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False, unique=True),
|
||||||
|
Column(
|
||||||
|
"agent_preferences",
|
||||||
|
JSONB,
|
||||||
|
nullable=False,
|
||||||
|
server_default='{"pinned": [], "shared_with_me": []}',
|
||||||
|
),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts_table = Table(
|
||||||
|
"prompts",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("name", Text, nullable=False),
|
||||||
|
Column("content", Text, nullable=False),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("legacy_mongo_id", Text),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_tools_table = Table(
|
||||||
|
"user_tools",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("name", Text, nullable=False),
|
||||||
|
Column("custom_name", Text),
|
||||||
|
Column("display_name", Text),
|
||||||
|
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
token_usage_table = Table(
|
||||||
|
"token_usage",
|
||||||
|
metadata,
|
||||||
|
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||||
|
Column("user_id", Text),
|
||||||
|
Column("api_key", Text),
|
||||||
|
Column("agent_id", UUID(as_uuid=True)),
|
||||||
|
Column("prompt_tokens", Integer, nullable=False, server_default="0"),
|
||||||
|
Column("generated_tokens", Integer, nullable=False, server_default="0"),
|
||||||
|
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
user_logs_table = Table(
|
||||||
|
"user_logs",
|
||||||
|
metadata,
|
||||||
|
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||||
|
Column("user_id", Text),
|
||||||
|
Column("endpoint", Text),
|
||||||
|
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("data", JSONB),
|
||||||
|
)
|
||||||
|
|
||||||
|
stack_logs_table = Table(
|
||||||
|
"stack_logs",
|
||||||
|
metadata,
|
||||||
|
Column("id", BigInteger, primary_key=True, autoincrement=True),
|
||||||
|
Column("activity_id", Text, nullable=False),
|
||||||
|
Column("endpoint", Text),
|
||||||
|
Column("level", Text),
|
||||||
|
Column("user_id", Text),
|
||||||
|
Column("api_key", Text),
|
||||||
|
Column("query", Text),
|
||||||
|
Column("stacks", JSONB, nullable=False, server_default="[]"),
|
||||||
|
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Phase 2, Tier 2 --------------------------------------------------------
|
||||||
|
|
||||||
|
agent_folders_table = Table(
|
||||||
|
"agent_folders",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("name", Text, nullable=False),
|
||||||
|
Column("description", Text),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
sources_table = Table(
|
||||||
|
"sources",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("name", Text, nullable=False),
|
||||||
|
Column("type", Text),
|
||||||
|
Column("metadata", JSONB, nullable=False, server_default="{}"),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
agents_table = Table(
|
||||||
|
"agents",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("name", Text, nullable=False),
|
||||||
|
Column("description", Text),
|
||||||
|
Column("agent_type", Text),
|
||||||
|
Column("status", Text, nullable=False),
|
||||||
|
Column("key", CITEXT, unique=True),
|
||||||
|
Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
|
||||||
|
Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
|
||||||
|
Column("chunks", Integer),
|
||||||
|
Column("retriever", Text),
|
||||||
|
Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
|
||||||
|
Column("tools", JSONB, nullable=False, server_default="[]"),
|
||||||
|
Column("json_schema", JSONB),
|
||||||
|
Column("models", JSONB),
|
||||||
|
Column("default_model_id", Text),
|
||||||
|
Column("folder_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
|
||||||
|
Column("limited_token_mode", Boolean, nullable=False, server_default="false"),
|
||||||
|
Column("token_limit", Integer),
|
||||||
|
Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
|
||||||
|
Column("request_limit", Integer),
|
||||||
|
Column("shared", Boolean, nullable=False, server_default="false"),
|
||||||
|
Column("incoming_webhook_token", CITEXT, unique=True),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("last_used_at", DateTime(timezone=True)),
|
||||||
|
Column("legacy_mongo_id", Text),
|
||||||
|
)
|
||||||
|
|
||||||
|
attachments_table = Table(
|
||||||
|
"attachments",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("filename", Text, nullable=False),
|
||||||
|
Column("upload_path", Text, nullable=False),
|
||||||
|
Column("mime_type", Text),
|
||||||
|
Column("size", BigInteger),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("legacy_mongo_id", Text),
|
||||||
|
)
|
||||||
|
|
||||||
|
memories_table = Table(
|
||||||
|
"memories",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||||
|
Column("path", Text, nullable=False),
|
||||||
|
Column("content", Text, nullable=False),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
UniqueConstraint("user_id", "tool_id", "path", name="memories_user_tool_path_uidx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
todos_table = Table(
|
||||||
|
"todos",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||||
|
Column("title", Text, nullable=False),
|
||||||
|
Column("completed", Boolean, nullable=False, server_default="false"),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
notes_table = Table(
|
||||||
|
"notes",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
|
||||||
|
Column("title", Text, nullable=False),
|
||||||
|
Column("content", Text, nullable=False),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
UniqueConstraint("user_id", "tool_id", name="notes_user_tool_uidx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
connector_sessions_table = Table(
|
||||||
|
"connector_sessions",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("provider", Text, nullable=False),
|
||||||
|
Column("session_data", JSONB, nullable=False),
|
||||||
|
Column("expires_at", DateTime(timezone=True)),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
UniqueConstraint("user_id", "provider", name="connector_sessions_user_provider_uidx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Phase 3, Tier 3 --------------------------------------------------------
|
||||||
|
|
||||||
|
conversations_table = Table(
|
||||||
|
"conversations",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("agent_id", UUID(as_uuid=True), ForeignKey("agents.id", ondelete="SET NULL")),
|
||||||
|
Column("name", Text),
|
||||||
|
Column("api_key", Text),
|
||||||
|
Column("is_shared_usage", Boolean, nullable=False, server_default="false"),
|
||||||
|
Column("shared_token", Text),
|
||||||
|
Column("shared_with", ARRAY(Text), nullable=False, server_default="{}"),
|
||||||
|
Column("compression_metadata", JSONB),
|
||||||
|
Column("date", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("legacy_mongo_id", Text),
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_messages_table = Table(
|
||||||
|
"conversation_messages",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
# Denormalised from conversations.user_id. Auto-filled on insert by a
|
||||||
|
# BEFORE INSERT trigger when the caller omits it. See migration 0020.
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("position", Integer, nullable=False),
|
||||||
|
Column("prompt", Text),
|
||||||
|
Column("response", Text),
|
||||||
|
Column("thought", Text),
|
||||||
|
Column("sources", JSONB, nullable=False, server_default="[]"),
|
||||||
|
Column("tool_calls", JSONB, nullable=False, server_default="[]"),
|
||||||
|
# Postgres cannot FK-enforce array elements, so the referential
|
||||||
|
# invariant is kept by an AFTER DELETE trigger on ``attachments``
|
||||||
|
# that array_removes the id from every row that references it.
|
||||||
|
# See migration 0017_cleanup_dangling_refs.
|
||||||
|
Column("attachments", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
|
||||||
|
Column("model_id", Text),
|
||||||
|
# Renamed from ``metadata`` in migration 0016 to avoid SQLAlchemy's
|
||||||
|
# reserved attribute collision on declarative models. The repository
|
||||||
|
# translates this ↔ API dict key ``metadata`` so external callers
|
||||||
|
# still see ``metadata``.
|
||||||
|
Column("message_metadata", JSONB, nullable=False, server_default="{}"),
|
||||||
|
Column("feedback", JSONB),
|
||||||
|
Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
shared_conversations_table = Table(
|
||||||
|
"shared_conversations",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("uuid", UUID(as_uuid=True), nullable=False, unique=True),
|
||||||
|
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
|
||||||
|
Column("chunks", Integer),
|
||||||
|
Column("is_promptable", Boolean, nullable=False, server_default="false"),
|
||||||
|
Column("first_n_queries", Integer, nullable=False, server_default="0"),
|
||||||
|
Column("api_key", Text),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
pending_tool_state_table = Table(
|
||||||
|
"pending_tool_state",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("messages", JSONB, nullable=False),
|
||||||
|
Column("pending_tool_calls", JSONB, nullable=False),
|
||||||
|
Column("tools_dict", JSONB, nullable=False),
|
||||||
|
Column("tool_schemas", JSONB, nullable=False),
|
||||||
|
Column("agent_config", JSONB, nullable=False),
|
||||||
|
Column("client_tools", JSONB),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("expires_at", DateTime(timezone=True), nullable=False),
|
||||||
|
UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
workflows_table = Table(
|
||||||
|
"workflows",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("name", Text, nullable=False),
|
||||||
|
Column("description", Text),
|
||||||
|
Column("current_graph_version", Integer, nullable=False, server_default="1"),
|
||||||
|
Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("legacy_mongo_id", Text),
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_nodes_table = Table(
|
||||||
|
"workflow_nodes",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
Column("graph_version", Integer, nullable=False),
|
||||||
|
Column("node_id", Text, nullable=False),
|
||||||
|
Column("node_type", Text, nullable=False),
|
||||||
|
Column("title", Text),
|
||||||
|
Column("description", Text),
|
||||||
|
Column("position", JSONB, nullable=False, server_default='{"x": 0, "y": 0}'),
|
||||||
|
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||||
|
Column("legacy_mongo_id", Text),
|
||||||
|
# Composite UNIQUE so workflow_edges can use a composite FK that
|
||||||
|
# enforces endpoint nodes belong to the same (workflow, version) as
|
||||||
|
# the edge itself. See migration 0008.
|
||||||
|
UniqueConstraint(
|
||||||
|
"id", "workflow_id", "graph_version",
|
||||||
|
name="workflow_nodes_id_wf_ver_key",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_edges_table = Table(
|
||||||
|
"workflow_edges",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
Column("graph_version", Integer, nullable=False),
|
||||||
|
Column("edge_id", Text, nullable=False),
|
||||||
|
Column("from_node_id", UUID(as_uuid=True), nullable=False),
|
||||||
|
Column("to_node_id", UUID(as_uuid=True), nullable=False),
|
||||||
|
Column("source_handle", Text),
|
||||||
|
Column("target_handle", Text),
|
||||||
|
Column("config", JSONB, nullable=False, server_default="{}"),
|
||||||
|
# Composite FKs: endpoints must belong to the same (workflow, version)
|
||||||
|
# as the edge. Prevents cross-workflow / cross-version edges that the
|
||||||
|
# single-column FKs couldn't catch. See migration 0008.
|
||||||
|
ForeignKeyConstraint(
|
||||||
|
["from_node_id", "workflow_id", "graph_version"],
|
||||||
|
["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
name="workflow_edges_from_node_fk",
|
||||||
|
),
|
||||||
|
ForeignKeyConstraint(
|
||||||
|
["to_node_id", "workflow_id", "graph_version"],
|
||||||
|
["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
name="workflow_edges_to_node_fk",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_runs_table = Table(
|
||||||
|
"workflow_runs",
|
||||||
|
metadata,
|
||||||
|
Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
|
||||||
|
Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
Column("user_id", Text, nullable=False),
|
||||||
|
Column("status", Text, nullable=False),
|
||||||
|
Column("inputs", JSONB),
|
||||||
|
Column("result", JSONB),
|
||||||
|
Column("steps", JSONB, nullable=False, server_default="[]"),
|
||||||
|
Column("started_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
|
||||||
|
Column("ended_at", DateTime(timezone=True)),
|
||||||
|
Column("legacy_mongo_id", Text),
|
||||||
|
)
|
||||||
11
application/storage/db/repositories/__init__.py
Normal file
11
application/storage/db/repositories/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Repositories for the user-data Postgres database.
|
||||||
|
|
||||||
|
Each module in this package exposes exactly one repository class. Repository
|
||||||
|
methods take a ``Connection`` (either as a constructor argument or as a
|
||||||
|
method argument) and return plain ``dict`` rows via
|
||||||
|
``application.storage.db.base_repository.row_to_dict`` during the
|
||||||
|
MongoDB→Postgres cutover, so call sites don't have to change shape.
|
||||||
|
|
||||||
|
Repositories are added one collection at a time, matching the phased
|
||||||
|
rollout in ``migration-postgres.md``.
|
||||||
|
"""
|
||||||
88
application/storage/db/repositories/agent_folders.py
Normal file
88
application/storage/db/repositories/agent_folders.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""Repository for the ``agent_folders`` table."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFoldersRepository:
    """Owner-scoped CRUD access to the ``agent_folders`` table.

    Every method takes a ``user_id`` and filters on it, so a caller can
    never read or mutate another user's folders.
    """

    def __init__(self, conn: Connection) -> None:
        # Transaction boundaries are owned by the caller.
        self._conn = conn

    def create(self, user_id: str, name: str, *, description: Optional[str] = None) -> dict:
        """Insert a folder and return the inserted row as a dict."""
        result = self._conn.execute(
            text(
                """
                INSERT INTO agent_folders (user_id, name, description)
                VALUES (:user_id, :name, :description)
                RETURNING *
                """
            ),
            {"user_id": user_id, "name": name, "description": description},
        )
        return row_to_dict(result.fetchone())

    def get(self, folder_id: str, user_id: str) -> Optional[dict]:
        """Fetch one folder owned by ``user_id``; ``None`` when absent."""
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return all of a user's folders, oldest first."""
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, folder_id: str, user_id: str, fields: dict) -> bool:
        """Update ``name`` and/or ``description`` on an owned folder.

        Non-whitelisted keys in ``fields`` are ignored. Returns ``False``
        when nothing updatable was supplied, otherwise ``True`` iff a row
        matched and was updated (``updated_at`` is bumped as well).
        """
        allowed = ("name", "description")
        filtered = {k: fields[k] for k in allowed if k in fields}
        if not filtered:
            return False
        # Build the SET clause dynamically instead of enumerating every
        # column combination in separate branches. Column names come from
        # the whitelist above — never from caller input — so f-string
        # interpolation is safe; values still go through bind parameters.
        set_clause = ", ".join(f"{col} = :{col}" for col in filtered)
        params: dict = {"id": folder_id, "user_id": user_id, **filtered}
        result = self._conn.execute(
            text(
                f"UPDATE agent_folders SET {set_clause}, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            params,
        )
        return result.rowcount > 0

    def delete(self, folder_id: str, user_id: str) -> bool:
        """Delete an owned folder; ``True`` when a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        return result.rowcount > 0
|
||||||
195
application/storage/db/repositories/agents.py
Normal file
195
application/storage/db/repositories/agents.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Repository for the ``agents`` table.
|
||||||
|
|
||||||
|
This is the most complex Phase 2 repository. Covers every write operation
|
||||||
|
the legacy Mongo code performs on ``agents_collection``:
|
||||||
|
|
||||||
|
- create, update, delete
|
||||||
|
- find by key (API key lookup)
|
||||||
|
- find by webhook token
|
||||||
|
- list for user, list templates
|
||||||
|
- folder assignment
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, func, text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import agents_table
|
||||||
|
|
||||||
|
|
||||||
|
class AgentsRepository:
    """Owner-scoped access to the ``agents`` table.

    Ports every write the legacy Mongo code performed on
    ``agents_collection`` (create/update/delete, key and webhook-token
    lookups, template listing, folder assignment). Rows created during
    the dual-write migration also carry ``legacy_mongo_id`` so they stay
    addressable by the original Mongo ``_id`` string.
    """

    def __init__(self, conn: Connection) -> None:
        # Transaction boundaries are owned by the caller.
        self._conn = conn

    @staticmethod
    def _normalize_unique_text(col: str, val):
        """Coerce blank strings for nullable unique text columns to NULL."""
        # Only ``key`` needs this today: an empty string would collide on
        # the unique index, while NULLs do not.
        if col == "key" and val == "":
            return None
        return val

    def create(self, user_id: str, name: str, status: str, **kwargs) -> dict:
        """Insert an agent and return the inserted row as a dict.

        Optional columns arrive via ``**kwargs``; anything outside the
        whitelist below, or with a ``None`` value, is silently dropped.
        Values are coerced per column type (JSONB passthrough, int, bool,
        stringified ids) before the insert.
        """
        values: dict = {"user_id": user_id, "name": name, "status": status}

        # Whitelist of optional columns the caller may set at creation.
        _ALLOWED = {
            "description", "agent_type", "key", "retriever",
            "default_model_id", "incoming_webhook_token",
            "source_id", "prompt_id", "folder_id",
            "chunks", "token_limit", "request_limit",
            "limited_token_mode", "limited_request_mode", "shared",
            "tools", "json_schema", "models", "legacy_mongo_id",
        }

        for col, val in kwargs.items():
            if col not in _ALLOWED or val is None:
                continue
            if col in ("tools", "json_schema", "models"):
                # JSONB columns: pass the Python object directly. SQLAlchemy
                # Core's JSONB type processor json.dumps it once during
                # bind; pre-serialising would double-encode and the value
                # would round-trip as a JSON string instead of the dict.
                values[col] = val
            elif col in ("chunks", "token_limit", "request_limit"):
                values[col] = int(val)
            elif col in ("limited_token_mode", "limited_request_mode", "shared"):
                values[col] = bool(val)
            elif col in ("source_id", "prompt_id", "folder_id"):
                # Ids may arrive as UUID/ObjectId objects; store the string.
                values[col] = str(val)
            else:
                values[col] = self._normalize_unique_text(col, val)

        stmt = pg_insert(agents_table).values(**values).returning(agents_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get(self, agent_id: str, user_id: str) -> Optional[dict]:
        """Fetch one agent owned by ``user_id``; ``None`` when absent."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch an agent by the original Mongo ObjectId string.

        When ``user_id`` is given, the lookup is additionally scoped to
        that owner.
        """
        sql = "SELECT * FROM agents WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_key(self, key: str) -> Optional[dict]:
        """Look an agent up by its API key (not scoped to a user)."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE key = :key"),
            {"key": key},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_webhook_token(self, token: str) -> Optional[dict]:
        """Look an agent up by its incoming-webhook token (not user-scoped)."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE incoming_webhook_token = :token"),
            {"token": token},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """All of a user's agents, newest first."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def list_templates(self) -> list[dict]:
        """Template agents, i.e. rows owned by the reserved 'system' user."""
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = 'system' ORDER BY name"),
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, agent_id: str, user_id: str, fields: dict) -> bool:
        """Update whitelisted columns on an owned agent.

        Returns ``False`` when ``fields`` contains nothing updatable,
        otherwise ``True`` iff a row matched. ``updated_at`` is bumped.
        """
        allowed = {
            "name", "description", "agent_type", "status", "key", "source_id",
            "chunks", "retriever", "prompt_id", "tools", "json_schema", "models",
            "default_model_id", "folder_id", "limited_token_mode", "token_limit",
            "limited_request_mode", "request_limit", "shared",
            "incoming_webhook_token", "last_used_at",
        }
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return False

        values: dict = {}
        for col, val in filtered.items():
            if col in ("tools", "json_schema", "models"):
                # See note in create(): JSONB columns receive Python
                # objects, the type processor handles serialisation.
                values[col] = val
            elif col in ("source_id", "prompt_id", "folder_id"):
                # Falsy id (None / "") clears the reference.
                values[col] = str(val) if val else None
            else:
                values[col] = self._normalize_unique_text(col, val)
        values["updated_at"] = func.now()

        t = agents_table
        stmt = (
            t.update()
            .where(t.c.id == agent_id)
            .where(t.c.user_id == user_id)
            .values(**values)
        )
        result = self._conn.execute(stmt)
        return result.rowcount > 0

    def update_by_legacy_id(self, legacy_mongo_id: str, user_id: str, fields: dict) -> bool:
        """Update an agent addressed by the Mongo ObjectId string.

        Resolves the legacy id to the Postgres row first, then delegates
        to :meth:`update`. Returns ``False`` when no such agent exists.
        """
        agent = self.get_by_legacy_id(legacy_mongo_id, user_id)
        if agent is None:
            return False
        return self.update(agent["id"], user_id, fields)

    def delete(self, agent_id: str, user_id: str) -> bool:
        """Delete an owned agent; ``True`` when a row was removed."""
        result = self._conn.execute(
            text("DELETE FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
        """Delete an agent addressed by the Mongo ObjectId string."""
        result = self._conn.execute(
            text(
                "DELETE FROM agents "
                "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
            ),
            {"legacy_id": legacy_mongo_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def set_folder(self, agent_id: str, user_id: str, folder_id: Optional[str]) -> None:
        """Assign (or clear, with ``None``) an agent's folder."""
        self._conn.execute(
            text(
                """
                UPDATE agents SET folder_id = CAST(:folder_id AS uuid), updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            {"id": agent_id, "user_id": user_id, "folder_id": folder_id},
        )

    def clear_folder_for_all(self, folder_id: str, user_id: str) -> None:
        """Remove folder assignment from all agents in a folder (used on folder delete)."""
        self._conn.execute(
            text(
                "UPDATE agents SET folder_id = NULL, updated_at = now() "
                "WHERE folder_id = CAST(:folder_id AS uuid) AND user_id = :user_id"
            ),
            {"folder_id": folder_id, "user_id": user_id},
        )
|
||||||
66
application/storage/db/repositories/attachments.py
Normal file
66
application/storage/db/repositories/attachments.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""Repository for the ``attachments`` table."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentsRepository:
    """Owner-scoped reads and inserts for the ``attachments`` table."""

    def __init__(self, conn: Connection) -> None:
        # Transaction boundaries are owned by the caller.
        self._conn = conn

    def create(self, user_id: str, filename: str, upload_path: str, *,
               mime_type: Optional[str] = None, size: Optional[int] = None,
               legacy_mongo_id: Optional[str] = None) -> dict:
        """Insert an attachment row and return it as a dict.

        ``legacy_mongo_id`` keeps the row addressable by the original
        Mongo ``_id`` during the migration.
        """
        bind = dict(
            user_id=user_id,
            filename=filename,
            upload_path=upload_path,
            mime_type=mime_type,
            size=size,
            legacy_mongo_id=legacy_mongo_id,
        )
        inserted = self._conn.execute(
            text(
                """
                INSERT INTO attachments
                    (user_id, filename, upload_path, mime_type, size, legacy_mongo_id)
                VALUES
                    (:user_id, :filename, :upload_path, :mime_type, :size, :legacy_mongo_id)
                RETURNING *
                """
            ),
            bind,
        )
        return row_to_dict(inserted.fetchone())

    def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
        """One attachment owned by the user, or ``None``."""
        hit = self._conn.execute(
            text(
                "SELECT * FROM attachments WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": attachment_id, "user_id": user_id},
        ).fetchone()
        return None if hit is None else row_to_dict(hit)

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch an attachment by the original Mongo ObjectId string."""
        query = "SELECT * FROM attachments WHERE legacy_mongo_id = :legacy_id"
        bind: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            query += " AND user_id = :user_id"
            bind["user_id"] = user_id
        hit = self._conn.execute(text(query), bind).fetchone()
        return None if hit is None else row_to_dict(hit)

    def list_for_user(self, user_id: str) -> list[dict]:
        """All of a user's attachments, newest first."""
        rows = self._conn.execute(
            text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]
|
||||||
65
application/storage/db/repositories/connector_sessions.py
Normal file
65
application/storage/db/repositories/connector_sessions.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""Repository for the ``connector_sessions`` table.
|
||||||
|
|
||||||
|
Covers operations across connector routes and tools:
|
||||||
|
- upsert session data
|
||||||
|
- find session by user + provider
|
||||||
|
- find session by token
|
||||||
|
- delete session
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorSessionsRepository:
    """Per-user, per-provider session storage for connector integrations.

    One row per ``(user_id, provider)`` pair: upserted wholesale, fetched
    by the same pair, listable per user, and deletable.
    """

    def __init__(self, conn: Connection) -> None:
        # Transaction boundaries are owned by the caller.
        self._conn = conn

    def upsert(self, user_id: str, provider: str, session_data: dict) -> dict:
        """Insert or replace the session blob; returns the stored row."""
        bind = {
            "user_id": user_id,
            "provider": provider,
            # The jsonb bind goes through an explicit CAST, so the dict is
            # serialised here rather than by a column type processor.
            "session_data": json.dumps(session_data),
        }
        stored = self._conn.execute(
            text(
                """
                INSERT INTO connector_sessions (user_id, provider, session_data)
                VALUES (:user_id, :provider, CAST(:session_data AS jsonb))
                ON CONFLICT (user_id, provider)
                DO UPDATE SET session_data = EXCLUDED.session_data
                RETURNING *
                """
            ),
            bind,
        )
        return row_to_dict(stored.fetchone())

    def get_by_user_provider(self, user_id: str, provider: str) -> Optional[dict]:
        """The session for one ``(user, provider)`` pair, or ``None``."""
        hit = self._conn.execute(
            text(
                "SELECT * FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"
            ),
            {"user_id": user_id, "provider": provider},
        ).fetchone()
        return None if hit is None else row_to_dict(hit)

    def list_for_user(self, user_id: str) -> list[dict]:
        """Every provider session the user has."""
        rows = self._conn.execute(
            text("SELECT * FROM connector_sessions WHERE user_id = :user_id"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]

    def delete(self, user_id: str, provider: str) -> bool:
        """Drop one provider session; ``True`` when a row existed."""
        outcome = self._conn.execute(
            text("DELETE FROM connector_sessions WHERE user_id = :user_id AND provider = :provider"),
            {"user_id": user_id, "provider": provider},
        )
        return outcome.rowcount > 0
|
||||||
476
application/storage/db/repositories/conversations.py
Normal file
476
application/storage/db/repositories/conversations.py
Normal file
@@ -0,0 +1,476 @@
|
|||||||
|
"""Repository for the ``conversations`` and ``conversation_messages`` tables.
|
||||||
|
|
||||||
|
Covers every operation the legacy Mongo code performs on
|
||||||
|
``conversations_collection``:
|
||||||
|
|
||||||
|
- create / get / list / delete conversations
|
||||||
|
- append message (transactional position allocation)
|
||||||
|
- update message at index (overwrite + optional truncation)
|
||||||
|
- set / unset feedback on a message
|
||||||
|
- rename conversation
|
||||||
|
- update compression metadata
|
||||||
|
- shared_with access checks
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import conversations_table, conversation_messages_table
|
||||||
|
|
||||||
|
|
||||||
|
def _message_row_to_dict(row) -> dict:
    """Convert a ``conversation_messages`` row to a dict, restoring the
    public key name.

    The DB column is ``message_metadata`` (see migration 0016 for the
    rename rationale) while callers expect the Mongo-era key
    ``metadata``; the key is mapped back here.
    """
    data = row_to_dict(row)
    try:
        data["metadata"] = data.pop("message_metadata")
    except KeyError:
        # Row had no message_metadata column — leave the dict untouched.
        pass
    return data
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationsRepository:
|
||||||
|
    def __init__(self, conn: Connection) -> None:
        # The repository never manages the connection lifecycle itself;
        # the caller owns transaction boundaries (append_message's
        # FOR UPDATE locking relies on being inside one).
        self._conn = conn
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Conversation CRUD
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
    def create(
        self,
        user_id: str,
        name: str | None = None,
        *,
        agent_id: str | None = None,
        api_key: str | None = None,
        is_shared_usage: bool = False,
        shared_token: str | None = None,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Create a new conversation and return the inserted row as a dict.

        ``legacy_mongo_id`` is used by the dual-write shim so that a
        Postgres row inserted *after* a successful Mongo insert carries
        the Mongo ``_id`` as a lookup key. Subsequent appends/updates
        can then resolve the PG row by that id via
        :meth:`get_by_legacy_id`.
        """
        values: dict = {
            "user_id": user_id,
            "name": name,
        }
        # Optional columns are only included when truthy so the table
        # defaults apply otherwise.
        if agent_id:
            values["agent_id"] = agent_id
        if api_key:
            values["api_key"] = api_key
        if is_shared_usage:
            values["is_shared_usage"] = True
        if shared_token:
            values["shared_token"] = shared_token
        if legacy_mongo_id:
            values["legacy_mongo_id"] = legacy_mongo_id

        stmt = pg_insert(conversations_table).values(**values).returning(conversations_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())
|
||||||
|
|
||||||
|
def get_by_legacy_id(
|
||||||
|
self, legacy_mongo_id: str, user_id: str | None = None,
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Look up a conversation by the original Mongo ObjectId string.
|
||||||
|
|
||||||
|
Used by the dual-write helpers to translate a Mongo ``_id`` into
|
||||||
|
the Postgres UUID for follow-up writes. When ``user_id`` is
|
||||||
|
provided, the lookup is scoped to rows owned by that user so
|
||||||
|
callers can't accidentally resolve another user's conversation.
|
||||||
|
"""
|
||||||
|
if user_id is not None:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM conversations "
|
||||||
|
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"legacy_id": legacy_mongo_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM conversations WHERE legacy_mongo_id = :legacy_id"
|
||||||
|
),
|
||||||
|
{"legacy_id": legacy_mongo_id},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
def get(self, conversation_id: str, user_id: str) -> Optional[dict]:
|
||||||
|
"""Fetch a conversation the user owns or has shared access to."""
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM conversations "
|
||||||
|
"WHERE id = CAST(:id AS uuid) "
|
||||||
|
"AND (user_id = :user_id OR :user_id = ANY(shared_with))"
|
||||||
|
),
|
||||||
|
{"id": conversation_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
def get_owned(self, conversation_id: str, user_id: str) -> Optional[dict]:
|
||||||
|
"""Fetch a conversation owned by the user (no shared access)."""
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM conversations "
|
||||||
|
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"id": conversation_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
def list_for_user(self, user_id: str, limit: int = 30) -> list[dict]:
|
||||||
|
"""List conversations for a user, most recent first.
|
||||||
|
|
||||||
|
Mirrors the Mongo query: either no api_key or agent_id exists.
|
||||||
|
"""
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM conversations "
|
||||||
|
"WHERE user_id = :user_id "
|
||||||
|
"AND (api_key IS NULL OR agent_id IS NOT NULL) "
|
||||||
|
"ORDER BY date DESC LIMIT :limit"
|
||||||
|
),
|
||||||
|
{"user_id": user_id, "limit": limit},
|
||||||
|
)
|
||||||
|
return [row_to_dict(r) for r in result.fetchall()]
|
||||||
|
|
||||||
|
def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"UPDATE conversations SET name = :name, updated_at = now() "
|
||||||
|
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"id": conversation_id, "user_id": user_id, "name": name},
|
||||||
|
)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
def set_shared_token(self, conversation_id: str, user_id: str, token: str) -> bool:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"UPDATE conversations SET shared_token = :token, updated_at = now() "
|
||||||
|
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"id": conversation_id, "user_id": user_id, "token": token},
|
||||||
|
)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
    def update_compression_metadata(
        self, conversation_id: str, user_id: str, metadata: dict,
    ) -> bool:
        """Replace the entire ``compression_metadata`` JSONB blob.

        Prefer :meth:`append_compression_point` + :meth:`set_compression_flags`
        to match the Mongo service semantics exactly (those two mirror
        ``$set`` + ``$push $slice``). This method is retained for callers
        that already compute the full merged blob client-side.

        Returns ``True`` when an owned row matched.
        """
        result = self._conn.execute(
            text(
                "UPDATE conversations "
                "SET compression_metadata = CAST(:meta AS jsonb), updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            # Serialised here because the bind goes through an explicit CAST.
            {"id": conversation_id, "user_id": user_id, "meta": json.dumps(metadata)},
        )
        return result.rowcount > 0
|
||||||
|
|
||||||
|
    def set_compression_flags(
        self,
        conversation_id: str,
        *,
        is_compressed: bool,
        last_compression_at,
    ) -> bool:
        """Update ``compression_metadata.is_compressed`` and
        ``compression_metadata.last_compression_at`` without touching
        ``compression_points``.

        Mirrors the Mongo ``$set`` on those two subfields in
        ``ConversationService.update_compression_metadata``. Initialises
        the surrounding object when the row has no ``compression_metadata``
        yet.

        Note: unlike most methods here, there is no ``user_id`` filter —
        ownership must be resolved by the caller. Returns ``True`` when a
        row matched.
        """
        # Two nested jsonb_set calls: the inner one writes is_compressed
        # into the (possibly empty, via COALESCE) metadata object, the
        # outer one writes last_compression_at into that result. The
        # trailing ``true`` arguments create the keys when missing.
        result = self._conn.execute(
            text(
                """
                UPDATE conversations SET
                    compression_metadata = jsonb_set(
                        jsonb_set(
                            COALESCE(compression_metadata, '{}'::jsonb),
                            '{is_compressed}',
                            to_jsonb(CAST(:is_compressed AS boolean)), true
                        ),
                        '{last_compression_at}',
                        to_jsonb(CAST(:last_compression_at AS text)), true
                    ),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid)
                """
            ),
            {
                "id": conversation_id,
                "is_compressed": bool(is_compressed),
                # Timestamps are stored as their string form inside the
                # JSONB blob; None stays NULL.
                "last_compression_at": (
                    str(last_compression_at) if last_compression_at is not None else None
                ),
            },
        )
        return result.rowcount > 0
|
||||||
|
|
||||||
|
    def append_compression_point(
        self,
        conversation_id: str,
        point: dict,
        *,
        max_points: int,
    ) -> bool:
        """Append one compression point, keeping at most ``max_points``.

        Mirrors Mongo's ``$push {"$each": [point], "$slice": -max_points}``
        on ``compression_metadata.compression_points``. Preserves the
        other top-level keys in ``compression_metadata``.

        No ``user_id`` filter — ownership must be resolved by the caller.
        Returns ``True`` when a row matched.
        """
        # The subquery concatenates the existing points array with the new
        # point, numbers the elements with row_number()/count() window
        # functions, and keeps only the trailing ``max_points`` elements
        # (rn > cnt - max_points) — i.e. Mongo's negative $slice. The
        # outer COALESCE covers the jsonb_agg-on-empty case.
        result = self._conn.execute(
            text(
                """
                UPDATE conversations SET
                    compression_metadata = jsonb_set(
                        COALESCE(compression_metadata, '{}'::jsonb),
                        '{compression_points}',
                        COALESCE(
                            (
                                SELECT jsonb_agg(elem ORDER BY rn)
                                FROM (
                                    SELECT
                                        elem,
                                        row_number() OVER () AS rn,
                                        count(*) OVER () AS cnt
                                    FROM jsonb_array_elements(
                                        COALESCE(
                                            compression_metadata -> 'compression_points',
                                            '[]'::jsonb
                                        ) || jsonb_build_array(CAST(:point AS jsonb))
                                    ) AS elem
                                ) ranked
                                WHERE rn > cnt - :max_points
                            ),
                            '[]'::jsonb
                        ),
                        true
                    ),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid)
                """
            ),
            {
                "id": conversation_id,
                # default=str so datetimes inside the point serialise.
                "point": json.dumps(point, default=str),
                "max_points": int(max_points),
            },
        )
        return result.rowcount > 0
|
||||||
|
|
||||||
|
def delete(self, conversation_id: str, user_id: str) -> bool:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"DELETE FROM conversations "
|
||||||
|
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"id": conversation_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
def delete_all_for_user(self, user_id: str) -> int:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text("DELETE FROM conversations WHERE user_id = :user_id"),
|
||||||
|
{"user_id": user_id},
|
||||||
|
)
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Messages
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_messages(self, conversation_id: str) -> list[dict]:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM conversation_messages "
|
||||||
|
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||||
|
"ORDER BY position ASC"
|
||||||
|
),
|
||||||
|
{"conv_id": conversation_id},
|
||||||
|
)
|
||||||
|
return [_message_row_to_dict(r) for r in result.fetchall()]
|
||||||
|
|
||||||
|
def get_message_at(self, conversation_id: str, position: int) -> Optional[dict]:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM conversation_messages "
|
||||||
|
"WHERE conversation_id = CAST(:conv_id AS uuid) "
|
||||||
|
"AND position = :pos"
|
||||||
|
),
|
||||||
|
{"conv_id": conversation_id, "pos": position},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return _message_row_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
    def append_message(self, conversation_id: str, message: dict) -> dict:
        """Append a message to a conversation and return the stored row.

        Uses ``SELECT ... FOR UPDATE`` to allocate the next position
        atomically. The caller must be inside a transaction.

        Mirrors Mongo's ``$push`` on the ``queries`` array.
        """
        # Lock the parent conversation row to serialize concurrent appends.
        self._conn.execute(
            text(
                "SELECT id FROM conversations "
                "WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
            ),
            {"conv_id": conversation_id},
        )
        # With the lock held, the max position cannot change under us;
        # COALESCE(..., -1) + 1 yields 0 for the first message.
        next_pos_result = self._conn.execute(
            text(
                "SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
                "FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid)"
            ),
            {"conv_id": conversation_id},
        )
        next_pos = next_pos_result.scalar()

        # Missing list-like fields default to empty containers; note the
        # public key ``metadata`` maps to the DB column message_metadata.
        values = {
            "conversation_id": conversation_id,
            "position": next_pos,
            "prompt": message.get("prompt"),
            "response": message.get("response"),
            "thought": message.get("thought"),
            "sources": message.get("sources") or [],
            "tool_calls": message.get("tool_calls") or [],
            "model_id": message.get("model_id"),
            "message_metadata": message.get("metadata") or {},
        }
        # Only override the timestamp when the caller supplied one, so the
        # column default applies otherwise.
        if message.get("timestamp") is not None:
            values["timestamp"] = message["timestamp"]

        attachments = message.get("attachments")
        if attachments:
            # Attachment ids may be UUID/ObjectId-like; store strings.
            values["attachments"] = [str(a) for a in attachments]

        stmt = (
            pg_insert(conversation_messages_table)
            .values(**values)
            .returning(conversation_messages_table)
        )
        result = self._conn.execute(stmt)
        # Touch the parent conversation's updated_at.
        self._conn.execute(
            text(
                "UPDATE conversations SET updated_at = now() "
                "WHERE id = CAST(:id AS uuid)"
            ),
            {"id": conversation_id},
        )
        return _message_row_to_dict(result.fetchone())
|
||||||
|
|
||||||
|
def update_message_at(
    self, conversation_id: str, position: int, fields: dict,
) -> bool:
    """Update specific fields on a message at a given position.

    Mirrors Mongo's ``$set`` on ``queries.{index}.*``. Returns True
    when a row at that position was actually updated.
    """
    updatable = {
        "prompt", "response", "thought", "sources", "tool_calls",
        "attachments", "model_id", "metadata", "timestamp",
    }
    to_set = {k: v for k, v in fields.items() if k in updatable}
    if not to_set:
        return False

    # The public API says ``metadata``; the DB column is ``message_metadata``.
    column_for = {"metadata": "message_metadata"}

    assignments: list = []
    params: dict = {"conv_id": conversation_id, "pos": position}
    for field_name, value in to_set.items():
        column = column_for.get(field_name, field_name)
        # Column names come only from the closed ``updatable`` set above,
        # so interpolating them into the SQL string is safe.
        if field_name in ("sources", "tool_calls", "metadata"):
            assignments.append(f"{column} = CAST(:{column} AS jsonb)")
            params[column] = value if isinstance(value, str) else json.dumps(value)
        elif field_name == "attachments":
            assignments.append(f"{column} = CAST(:{column} AS uuid[])")
            params[column] = [str(item) for item in value] if value else []
        else:
            assignments.append(f"{column} = :{column}")
            params[column] = value

    # Unless the caller pinned an explicit timestamp, stamp the row now.
    if "timestamp" not in to_set:
        assignments.append("timestamp = now()")
    sql = (
        f"UPDATE conversation_messages SET {', '.join(assignments)} "
        "WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
    )
    result = self._conn.execute(text(sql), params)
    return result.rowcount > 0
|
||||||
|
|
||||||
|
def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
    """Delete every message whose position exceeds ``keep_up_to``.

    Mirrors Mongo's ``$push`` + ``$slice`` that trims queries after an
    index-based update. Returns the number of rows removed.
    """
    stmt = text(
        "DELETE FROM conversation_messages "
        "WHERE conversation_id = CAST(:conv_id AS uuid) "
        "AND position > :pos"
    )
    outcome = self._conn.execute(
        stmt, {"conv_id": conversation_id, "pos": keep_up_to},
    )
    return outcome.rowcount
|
||||||
|
|
||||||
|
def set_feedback(
    self, conversation_id: str, position: int, feedback: dict | None,
) -> bool:
    """Set or clear the feedback JSONB on one message.

    ``feedback`` is e.g. ``{"text": "thumbs_up", "timestamp": "..."}``;
    passing ``None`` stores SQL NULL (unset). Returns True when the
    targeted row exists.
    """
    payload = None if feedback is None else json.dumps(feedback)
    stmt = text(
        "UPDATE conversation_messages "
        "SET feedback = CAST(:fb AS jsonb) "
        "WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
    )
    outcome = self._conn.execute(
        stmt, {"conv_id": conversation_id, "pos": position, "fb": payload},
    )
    return outcome.rowcount > 0
|
||||||
|
|
||||||
|
def message_count(self, conversation_id: str) -> int:
    """Return how many messages the conversation currently holds."""
    stmt = text(
        "SELECT COUNT(*) FROM conversation_messages "
        "WHERE conversation_id = CAST(:conv_id AS uuid)"
    )
    total = self._conn.execute(stmt, {"conv_id": conversation_id}).scalar()
    return total or 0
|
||||||
97
application/storage/db/repositories/memories.py
Normal file
97
application/storage/db/repositories/memories.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Repository for the ``memories`` table.
|
||||||
|
|
||||||
|
Covers the operations in ``application/agents/tools/memory.py``:
|
||||||
|
- upsert (create/overwrite file)
|
||||||
|
- find by path (view file)
|
||||||
|
- find by path prefix (view directory, regex scan)
|
||||||
|
- delete by path / path prefix
|
||||||
|
- rename (update path)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class MemoriesRepository:
    """Postgres-backed store for the agent memory tool's path→content rows.

    Covers the operations in ``application/agents/tools/memory.py``:
    upsert (create/overwrite file), find by path (view file), prefix scan
    (view directory), delete by path/prefix/all, and rename.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape ``\\``, ``%`` and ``_`` so ``value`` matches literally in LIKE.

        Without this, a prefix containing a LIKE metacharacter would act
        as a wildcard in the prefix-scan methods and could match — or
        delete — unrelated paths.
        """
        return (
            value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )

    def upsert(self, user_id: str, tool_id: str, path: str, content: str) -> dict:
        """Create the memory file at ``path`` or overwrite its content."""
        result = self._conn.execute(
            text(
                """
                INSERT INTO memories (user_id, tool_id, path, content)
                VALUES (:user_id, CAST(:tool_id AS uuid), :path, :content)
                ON CONFLICT (user_id, tool_id, path)
                DO UPDATE SET content = EXCLUDED.content, updated_at = now()
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path, "content": content},
        )
        return row_to_dict(result.fetchone())

    def get_by_path(self, user_id: str, tool_id: str, path: str) -> Optional[dict]:
        """Return the row at ``path`` or None when absent (view file)."""
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> list[dict]:
        """Return all rows whose path starts with ``prefix`` (view directory).

        ``prefix`` is treated literally: LIKE metacharacters are escaped
        (see :meth:`_escape_like`) and matched with ``ESCAPE '\\'``.
        """
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) "
                "AND path LIKE :prefix ESCAPE '\\'"
            ),
            {
                "user_id": user_id,
                "tool_id": tool_id,
                "prefix": self._escape_like(prefix) + "%",
            },
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def delete_by_path(self, user_id: str, tool_id: str, path: str) -> int:
        """Delete the row at exactly ``path``; returns rows removed (0 or 1)."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        return result.rowcount

    def delete_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> int:
        """Delete all rows whose path starts with the literal ``prefix``.

        Returns the number of rows removed. ``prefix`` is escaped so a
        ``%``/``_`` in a path component cannot widen the deletion.
        """
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) "
                "AND path LIKE :prefix ESCAPE '\\'"
            ),
            {
                "user_id": user_id,
                "tool_id": tool_id,
                "prefix": self._escape_like(prefix) + "%",
            },
        )
        return result.rowcount

    def delete_all(self, user_id: str, tool_id: str) -> int:
        """Delete every memory row for this (user, tool); returns rows removed."""
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return result.rowcount

    def update_path(self, user_id: str, tool_id: str, old_path: str, new_path: str) -> bool:
        """Rename a memory file; True when a row at ``old_path`` existed."""
        result = self._conn.execute(
            text(
                "UPDATE memories SET path = :new_path, updated_at = now() "
                "WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid) AND path = :old_path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "old_path": old_path, "new_path": new_path},
        )
        return result.rowcount > 0
|
||||||
62
application/storage/db/repositories/notes.py
Normal file
62
application/storage/db/repositories/notes.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Repository for the ``notes`` table.
|
||||||
|
|
||||||
|
Covers the operations in ``application/agents/tools/notes.py``.
|
||||||
|
Note: the Mongo schema stores a single ``note`` text field per (user_id, tool_id),
|
||||||
|
while the Postgres schema has ``title`` + ``content``. During dual-write,
|
||||||
|
title is set to a default and content holds the note text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class NotesRepository:
    """Postgres-backed replacement for the Mongo notes collection.

    The Mongo schema stores a single ``note`` text per (user_id, tool_id);
    Postgres splits it into ``title`` + ``content``. During dual-write the
    title carries a default and ``content`` holds the note text.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def upsert(self, user_id: str, tool_id: str, title: str, content: str) -> dict:
        """Insert the note for (user, tool) or overwrite its title/content."""
        stmt = text(
            """
            INSERT INTO notes (user_id, tool_id, title, content)
            VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content)
            ON CONFLICT (user_id, tool_id)
            DO UPDATE SET content = EXCLUDED.content, title = EXCLUDED.title, updated_at = now()
            RETURNING *
            """
        )
        params = {"user_id": user_id, "tool_id": tool_id, "title": title, "content": content}
        return row_to_dict(self._conn.execute(stmt, params).fetchone())

    def get_for_user_tool(self, user_id: str, tool_id: str) -> Optional[dict]:
        """Return the single note for (user, tool), or None when absent."""
        row = self._conn.execute(
            text(
                "SELECT * FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def get(self, note_id: str, user_id: str) -> Optional[dict]:
        """Fetch one note by primary key, scoped to its owner."""
        row = self._conn.execute(
            text("SELECT * FROM notes WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": note_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def delete(self, user_id: str, tool_id: str) -> bool:
        """Remove the note for (user, tool); True when one existed."""
        outcome = self._conn.execute(
            text(
                "DELETE FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return outcome.rowcount > 0
|
||||||
128
application/storage/db/repositories/pending_tool_state.py
Normal file
128
application/storage/db/repositories/pending_tool_state.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Repository for the ``pending_tool_state`` table.
|
||||||
|
|
||||||
|
Mirrors the continuation service's three operations on
|
||||||
|
``pending_tool_state`` in Mongo:
|
||||||
|
|
||||||
|
- save_state → upsert (INSERT ... ON CONFLICT DO UPDATE)
|
||||||
|
- load_state → find_one by (conversation_id, user_id)
|
||||||
|
- delete_state → delete_one by (conversation_id, user_id)
|
||||||
|
|
||||||
|
Plus a cleanup method for the Celery beat task that replaces Mongo's
|
||||||
|
TTL index.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
PENDING_STATE_TTL_SECONDS = 30 * 60 # 1800 seconds
|
||||||
|
|
||||||
|
|
||||||
|
class PendingToolStateRepository:
    """Postgres-backed store for paused tool-call continuations.

    Replaces the Mongo ``pending_tool_state`` collection: upsert on save,
    point lookup on load, point delete on finish, plus an explicit expiry
    sweep that stands in for Mongo's TTL index.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def save_state(
        self,
        conversation_id: str,
        user_id: str,
        *,
        messages: list,
        pending_tool_calls: list,
        tools_dict: dict,
        tool_schemas: list,
        agent_config: dict,
        client_tools: list | None = None,
        ttl_seconds: int = PENDING_STATE_TTL_SECONDS,
    ) -> dict:
        """Upsert pending tool state.

        Mirrors Mongo's ``replace_one(..., upsert=True)``: every column is
        overwritten on conflict, including the TTL window.
        """
        created_at = datetime.now(timezone.utc)
        expires_at = datetime.fromtimestamp(
            created_at.timestamp() + ttl_seconds, tz=timezone.utc,
        )

        params = {
            "conv_id": conversation_id,
            "user_id": user_id,
            "messages": json.dumps(messages),
            "pending": json.dumps(pending_tool_calls),
            "tools_dict": json.dumps(tools_dict),
            "schemas": json.dumps(tool_schemas),
            "agent_config": json.dumps(agent_config),
            # SQL NULL (not the JSON string "null") when no client tools given.
            "client_tools": None if client_tools is None else json.dumps(client_tools),
            "created_at": created_at,
            "expires_at": expires_at,
        }
        stmt = text(
            """
            INSERT INTO pending_tool_state
                (conversation_id, user_id, messages, pending_tool_calls,
                 tools_dict, tool_schemas, agent_config, client_tools,
                 created_at, expires_at)
            VALUES
                (CAST(:conv_id AS uuid), :user_id,
                 CAST(:messages AS jsonb), CAST(:pending AS jsonb),
                 CAST(:tools_dict AS jsonb), CAST(:schemas AS jsonb),
                 CAST(:agent_config AS jsonb), CAST(:client_tools AS jsonb),
                 :created_at, :expires_at)
            ON CONFLICT (conversation_id, user_id) DO UPDATE SET
                messages = EXCLUDED.messages,
                pending_tool_calls = EXCLUDED.pending_tool_calls,
                tools_dict = EXCLUDED.tools_dict,
                tool_schemas = EXCLUDED.tool_schemas,
                agent_config = EXCLUDED.agent_config,
                client_tools = EXCLUDED.client_tools,
                created_at = EXCLUDED.created_at,
                expires_at = EXCLUDED.expires_at
            RETURNING *
            """
        )
        return row_to_dict(self._conn.execute(stmt, params).fetchone())

    def load_state(self, conversation_id: str, user_id: str) -> Optional[dict]:
        """Return the saved state for (conversation, user), or None."""
        row = self._conn.execute(
            text(
                "SELECT * FROM pending_tool_state "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id"
            ),
            {"conv_id": conversation_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def delete_state(self, conversation_id: str, user_id: str) -> bool:
        """Drop the saved state; True when a row was removed."""
        outcome = self._conn.execute(
            text(
                "DELETE FROM pending_tool_state "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id"
            ),
            {"conv_id": conversation_id, "user_id": user_id},
        )
        return outcome.rowcount > 0

    def cleanup_expired(self) -> int:
        """Delete rows whose ``expires_at`` has passed; returns rows removed.

        Replaces Mongo's ``expireAfterSeconds=0`` TTL index; intended to run
        from a Celery beat task every ~60 seconds. Uses clock_timestamp()
        rather than now() because now() is frozen at transaction start,
        which would let just-expired state survive one extra sweep.
        """
        outcome = self._conn.execute(
            text("DELETE FROM pending_tool_state WHERE expires_at < clock_timestamp()")
        )
        return outcome.rowcount
|
||||||
161
application/storage/db/repositories/prompts.py
Normal file
161
application/storage/db/repositories/prompts.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""Repository for the ``prompts`` table.
|
||||||
|
|
||||||
|
Covers every operation the legacy Mongo code performs on
|
||||||
|
``prompts_collection``:
|
||||||
|
|
||||||
|
1. ``insert_one`` in prompts/routes.py (create)
|
||||||
|
2. ``find`` by user in prompts/routes.py (list)
|
||||||
|
3. ``find_one`` by id+user in prompts/routes.py (get single)
|
||||||
|
4. ``find_one`` by id only in stream_processor.py (get content for rendering)
|
||||||
|
5. ``update_one`` in prompts/routes.py (update name+content)
|
||||||
|
6. ``delete_one`` in prompts/routes.py (delete)
|
||||||
|
7. ``find_one`` + ``insert_one`` in seeder.py (upsert by user+name+content)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class PromptsRepository:
    """Postgres-backed replacement for Mongo ``prompts_collection``.

    Covers every legacy operation: create, list by user, get by id
    (user-scoped, and unscoped for rendering), update and delete — each
    also addressable by the original Mongo ObjectId during the dual-write
    window — plus the seeder's find-or-create.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        user_id: str,
        name: str,
        content: str,
        *,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert a prompt row and return it as a dict.

        ``legacy_mongo_id`` carries the original Mongo ObjectId during
        dual-write so both stores can address the same prompt.
        """
        sql = """
            INSERT INTO prompts (user_id, name, content, legacy_mongo_id)
            VALUES (:user_id, :name, :content, :legacy_mongo_id)
            RETURNING *
        """
        result = self._conn.execute(
            text(sql),
            {
                "user_id": user_id,
                "name": name,
                "content": content,
                "legacy_mongo_id": legacy_mongo_id,
            },
        )
        return row_to_dict(result.fetchone())

    def get(self, prompt_id: str, user_id: str) -> Optional[dict]:
        """Fetch one prompt by id, scoped to its owner."""
        result = self._conn.execute(
            text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": prompt_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch a prompt by the original Mongo ObjectId string.

        When ``user_id`` is given the lookup is additionally owner-scoped.
        """
        sql = "SELECT * FROM prompts WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
        """Fetch prompt content by ID without user scoping.

        Used only by stream_processor to render a prompt whose owner is
        not known at call time. Do NOT use in user-facing routes.
        """
        result = self._conn.execute(
            text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid)"),
            {"id": prompt_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return the user's prompts, oldest first."""
        result = self._conn.execute(
            text("SELECT * FROM prompts WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, prompt_id: str, user_id: str, name: str, content: str) -> bool:
        """Update name/content; returns True when a row matched.

        Previously returned None; reporting the rowcount makes this
        consistent with :meth:`update_by_legacy_id` and is backward
        compatible for callers that ignore the return value.
        """
        result = self._conn.execute(
            text(
                """
                UPDATE prompts
                SET name = :name, content = :content, updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            {"id": prompt_id, "user_id": user_id, "name": name, "content": content},
        )
        return result.rowcount > 0

    def update_by_legacy_id(
        self,
        legacy_mongo_id: str,
        user_id: str,
        name: str,
        content: str,
    ) -> bool:
        """Update a prompt addressed by the Mongo ObjectId string."""
        result = self._conn.execute(
            text(
                """
                UPDATE prompts
                SET name = :name, content = :content, updated_at = now()
                WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id
                """
            ),
            {
                "legacy_id": legacy_mongo_id,
                "user_id": user_id,
                "name": name,
                "content": content,
            },
        )
        return result.rowcount > 0

    def delete(self, prompt_id: str, user_id: str) -> bool:
        """Delete one prompt; returns True when a row was removed.

        Previously returned None; see :meth:`update` for the rationale.
        """
        result = self._conn.execute(
            text("DELETE FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": prompt_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
        """Delete a prompt addressed by the Mongo ObjectId string."""
        result = self._conn.execute(
            text(
                "DELETE FROM prompts "
                "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
            ),
            {"legacy_id": legacy_mongo_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def find_or_create(self, user_id: str, name: str, content: str) -> dict:
        """Return existing prompt matching (user, name, content), or create one.

        Used by the seeder to avoid duplicating template prompts.
        NOTE(review): this find-then-insert is not race-free; it looks
        acceptable for the seeder's single-writer use, but concurrent
        callers could create duplicates — confirm before reusing elsewhere.
        """
        result = self._conn.execute(
            text(
                "SELECT * FROM prompts WHERE user_id = :user_id AND name = :name AND content = :content"
            ),
            {"user_id": user_id, "name": name, "content": content},
        )
        row = result.fetchone()
        if row is not None:
            return row_to_dict(row)
        return self.create(user_id, name, content)
|
||||||
205
application/storage/db/repositories/shared_conversations.py
Normal file
205
application/storage/db/repositories/shared_conversations.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Repository for the ``shared_conversations`` table.
|
||||||
|
|
||||||
|
Covers the sharing operations from ``shared_conversations_collections``
|
||||||
|
in Mongo:
|
||||||
|
|
||||||
|
- create a share record (with UUID, conversation_id, user, visibility flags)
|
||||||
|
- look up by uuid (public access)
|
||||||
|
- look up by conversation_id + user + flags (dedup check)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid as uuid_mod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import shared_conversations_table
|
||||||
|
|
||||||
|
|
||||||
|
class SharedConversationsRepository:
    """Postgres-backed access to the ``shared_conversations`` table.

    Replaces the Mongo ``shared_conversations_collections`` operations:
    creating share records, public lookup by uuid, and the dedup check
    on (conversation, user, visibility flags).
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        conversation_id: str,
        user_id: str,
        *,
        is_promptable: bool = False,
        first_n_queries: int = 0,
        api_key: str | None = None,
        prompt_id: str | None = None,
        chunks: int | None = None,
        share_uuid: str | None = None,
    ) -> dict:
        """Insert a share record and return it.

        ``share_uuid`` lets the dual-write caller reuse the UUID that
        Mongo received, so public ``/shared/{uuid}`` links keep resolving
        from both stores during the dual-write window.

        This method does NOT dedup. Callers needing race-free dedup on
        the logical share key should use :meth:`get_or_create`, which
        leans on the composite partial unique index from migration 0008
        to collapse concurrent requests to a single row.
        """
        row_values: dict = {
            "uuid": share_uuid or str(uuid_mod.uuid4()),
            "conversation_id": conversation_id,
            "user_id": user_id,
            "is_promptable": is_promptable,
            "first_n_queries": first_n_queries,
        }
        # Optional columns: only set when the caller supplied something.
        if api_key:
            row_values["api_key"] = api_key
        if prompt_id:
            row_values["prompt_id"] = prompt_id
        if chunks is not None:
            row_values["chunks"] = chunks

        insert_stmt = (
            pg_insert(shared_conversations_table)
            .values(**row_values)
            .returning(shared_conversations_table)
        )
        return row_to_dict(self._conn.execute(insert_stmt).fetchone())

    def get_or_create(
        self,
        conversation_id: str,
        user_id: str,
        *,
        is_promptable: bool = False,
        first_n_queries: int = 0,
        api_key: str | None = None,
        prompt_id: str | None = None,
        chunks: int | None = None,
        share_uuid: str | None = None,
    ) -> dict:
        """Race-free share create/lookup keyed on the logical dedup tuple.

        Leverages the partial unique index on ``(conversation_id, user_id,
        is_promptable, first_n_queries, COALESCE(api_key, ''))`` from
        migration 0008, so concurrent requests for the same logical share
        converge on one row; the returned ``uuid`` is the canonical public
        identifier.

        ``prompt_id`` and ``chunks`` are deliberately NOT part of the
        uniqueness key: a share is identified by "who shared what
        conversation under which visibility rules", while prompt/chunks
        are mutable properties updated last-write-wins on re-share. That
        preserves existing public ``/shared/{uuid}`` URLs when a user
        changes the prompt or chunk count, matching the Mongo
        ``find_one`` + ``update`` semantics.
        """
        stmt = text(
            """
            INSERT INTO shared_conversations
                (uuid, conversation_id, user_id, is_promptable,
                 first_n_queries, api_key, prompt_id, chunks)
            VALUES
                (CAST(:uuid AS uuid), CAST(:conversation_id AS uuid),
                 :user_id, :is_promptable, :first_n_queries,
                 :api_key, CAST(:prompt_id AS uuid), :chunks)
            ON CONFLICT (conversation_id, user_id, is_promptable,
                         first_n_queries, COALESCE(api_key, ''))
            DO UPDATE SET prompt_id = EXCLUDED.prompt_id,
                          chunks = EXCLUDED.chunks
            RETURNING *
            """
        )
        params = {
            "uuid": share_uuid or str(uuid_mod.uuid4()),
            "conversation_id": conversation_id,
            "user_id": user_id,
            "is_promptable": is_promptable,
            "first_n_queries": first_n_queries,
            "api_key": api_key,
            "prompt_id": prompt_id,
            "chunks": chunks,
        }
        return row_to_dict(self._conn.execute(stmt, params).fetchone())

    def find_by_uuid(self, share_uuid: str) -> Optional[dict]:
        """Resolve a public share link by its uuid, or None."""
        row = self._conn.execute(
            text(
                "SELECT * FROM shared_conversations "
                "WHERE uuid = CAST(:uuid AS uuid)"
            ),
            {"uuid": share_uuid},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def find_existing(
        self,
        conversation_id: str,
        user_id: str,
        is_promptable: bool,
        first_n_queries: int,
        api_key: str | None = None,
    ) -> Optional[dict]:
        """Check for an existing share with matching parameters.

        Mirrors the Mongo ``find_one`` dedup check before creating a
        share. A falsy ``api_key`` matches rows where the column is NULL.
        """
        common = (
            "SELECT * FROM shared_conversations "
            "WHERE conversation_id = CAST(:conv_id AS uuid) "
            "AND user_id = :user_id "
            "AND is_promptable = :is_promptable "
            "AND first_n_queries = :fnq "
        )
        params: dict = {
            "conv_id": conversation_id,
            "user_id": user_id,
            "is_promptable": is_promptable,
            "fnq": first_n_queries,
        }
        if api_key:
            sql = common + "AND api_key = :api_key " + "LIMIT 1"
            params["api_key"] = api_key
        else:
            sql = common + "AND api_key IS NULL " + "LIMIT 1"
        row = self._conn.execute(text(sql), params).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_conversation(self, conversation_id: str) -> list[dict]:
        """Return all share records for a conversation, newest first."""
        rows = self._conn.execute(
            text(
                "SELECT * FROM shared_conversations "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "ORDER BY created_at DESC"
            ),
            {"conv_id": conversation_id},
        ).fetchall()
        return [row_to_dict(record) for record in rows]
|
||||||
80
application/storage/db/repositories/sources.py
Normal file
80
application/storage/db/repositories/sources.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""Repository for the ``sources`` table."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, func, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import sources_table
|
||||||
|
|
||||||
|
|
||||||
|
class SourcesRepository:
    """Postgres-backed access to the ``sources`` table."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, name: str, *, user_id: str,
               type: Optional[str] = None, metadata: Optional[dict] = None) -> dict:
        """Insert a source row; ``metadata`` defaults to an empty JSONB object."""
        params = {
            "user_id": user_id,
            "name": name,
            "type": type,
            "metadata": json.dumps(metadata or {}),
        }
        stmt = text(
            """
            INSERT INTO sources (user_id, name, type, metadata)
            VALUES (:user_id, :name, :type, CAST(:metadata AS jsonb))
            RETURNING *
            """
        )
        return row_to_dict(self._conn.execute(stmt, params).fetchone())

    def get(self, source_id: str, user_id: str) -> Optional[dict]:
        """Fetch one source by id, scoped to its owner; None when absent."""
        row = self._conn.execute(
            text("SELECT * FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user(self, user_id: str) -> list[dict]:
        """Return the user's sources, newest first."""
        rows = self._conn.execute(
            text("SELECT * FROM sources WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(record) for record in rows]

    def update(self, source_id: str, user_id: str, fields: dict) -> None:
        """Apply the allowed subset of ``fields`` to one source row."""
        permitted = {"name", "type", "metadata"}
        changes: dict = {k: v for k, v in fields.items() if k in permitted}
        if not changes:
            return

        # Pass Python objects directly for JSONB columns when using
        # SQLAlchemy Core .update() — the JSONB type processor json.dumps
        # them itself; pre-serialising here would double-encode and the
        # value would round-trip as a JSON string instead of the original
        # dict.
        changes["updated_at"] = func.now()

        t = sources_table
        self._conn.execute(
            t.update()
            .where(t.c.id == source_id)
            .where(t.c.user_id == user_id)
            .values(**changes)
        )

    def delete(self, source_id: str, user_id: str) -> bool:
        """Remove one source; True when a row was deleted."""
        outcome = self._conn.execute(
            text("DELETE FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        )
        return outcome.rowcount > 0
|
||||||
58
application/storage/db/repositories/stack_logs.py
Normal file
58
application/storage/db/repositories/stack_logs.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Repository for the ``stack_logs`` table.
|
||||||
|
|
||||||
|
Covers the single operation the legacy Mongo code performs:
|
||||||
|
|
||||||
|
1. ``insert_one`` in logging.py ``_log_to_mongodb`` — append-only debug/error
|
||||||
|
activity log. The Mongo collection is ``stack_logs``; the Mongo variable
|
||||||
|
inside ``_log_to_mongodb`` is misleadingly named ``user_logs_collection``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
|
||||||
|
class StackLogsRepository:
    """Postgres-backed replacement for the Mongo ``stack_logs`` collection."""

    def __init__(self, conn: Connection) -> None:
        # Connection lifecycle (commit/rollback) is managed by the caller.
        self._conn = conn

    def insert(
        self,
        *,
        activity_id: str,
        endpoint: Optional[str] = None,
        level: Optional[str] = None,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
        query: Optional[str] = None,
        stacks: Optional[list] = None,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Append one debug/error activity row.

        ``stacks`` defaults to an empty JSON array; ``timestamp`` falls
        back to the database's ``now()`` when not supplied.
        """
        statement = text(
            """
            INSERT INTO stack_logs (activity_id, endpoint, level, user_id, api_key, query, stacks, timestamp)
            VALUES (
                :activity_id, :endpoint, :level, :user_id, :api_key, :query,
                CAST(:stacks AS jsonb),
                COALESCE(:timestamp, now())
            )
            """
        )
        row = {
            "activity_id": activity_id,
            "endpoint": endpoint,
            "level": level,
            "user_id": user_id,
            "api_key": api_key,
            "query": query,
            "stacks": json.dumps(stacks or []),
            "timestamp": timestamp,
        }
        self._conn.execute(statement, row)
|
||||||
78
application/storage/db/repositories/todos.py
Normal file
78
application/storage/db/repositories/todos.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""Repository for the ``todos`` table.
|
||||||
|
|
||||||
|
Covers the operations in ``application/agents/tools/todo_list.py``.
|
||||||
|
Note: the Mongo schema uses ``todo_id`` (sequential int) and ``status`` (text),
|
||||||
|
while the Postgres schema uses ``completed`` (boolean) and the UUID ``id`` as PK.
|
||||||
|
The repository bridges both shapes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class TodosRepository:
    """CRUD accessors for the ``todos`` table (rows keyed by user + tool).

    Note: the Mongo schema used ``todo_id`` (sequential int) and ``status``
    (text); Postgres uses the UUID ``id`` PK and a boolean ``completed``.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(self, user_id: str, tool_id: str, title: str) -> dict:
        """Insert a todo and return the created row as a dict."""
        created = self._conn.execute(
            text(
                """
                INSERT INTO todos (user_id, tool_id, title)
                VALUES (:user_id, CAST(:tool_id AS uuid), :title)
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "title": title},
        )
        return row_to_dict(created.fetchone())

    def get(self, todo_id: str, user_id: str) -> Optional[dict]:
        """Fetch one todo; ``None`` when missing or owned by another user."""
        row = self._conn.execute(
            text("SELECT * FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": todo_id, "user_id": user_id},
        ).fetchone()
        return None if row is None else row_to_dict(row)

    def list_for_user_tool(self, user_id: str, tool_id: str) -> list[dict]:
        """All todos for one user/tool pair, oldest first."""
        statement = text(
            "SELECT * FROM todos WHERE user_id = :user_id "
            "AND tool_id = CAST(:tool_id AS uuid) ORDER BY created_at"
        )
        records = self._conn.execute(
            statement, {"user_id": user_id, "tool_id": tool_id}
        ).fetchall()
        return [row_to_dict(record) for record in records]

    def update_title(self, todo_id: str, user_id: str, title: str) -> bool:
        """Rename a todo; ``True`` when a row matched."""
        outcome = self._conn.execute(
            text(
                "UPDATE todos SET title = :title, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": todo_id, "user_id": user_id, "title": title},
        )
        return outcome.rowcount > 0

    def set_completed(self, todo_id: str, user_id: str, completed: bool = True) -> bool:
        """Set the ``completed`` flag; ``True`` when a row matched."""
        outcome = self._conn.execute(
            text(
                "UPDATE todos SET completed = :completed, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": todo_id, "user_id": user_id, "completed": completed},
        )
        return outcome.rowcount > 0

    def delete(self, todo_id: str, user_id: str) -> bool:
        """Remove one todo; ``True`` when a row was actually deleted."""
        outcome = self._conn.execute(
            text("DELETE FROM todos WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": todo_id, "user_id": user_id},
        )
        return outcome.rowcount > 0
|
||||||
104
application/storage/db/repositories/token_usage.py
Normal file
104
application/storage/db/repositories/token_usage.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Repository for the ``token_usage`` table.
|
||||||
|
|
||||||
|
Covers every operation the legacy Mongo code performs on
|
||||||
|
``token_usage_collection`` / ``usage_collection``:
|
||||||
|
|
||||||
|
1. ``insert_one`` in usage.py (record per-call token counts)
|
||||||
|
2. ``aggregate`` in analytics/routes.py (time-bucketed totals)
|
||||||
|
3. ``aggregate`` in answer/routes/base.py (24h sum for rate limiting)
|
||||||
|
4. ``count_documents`` in answer/routes/base.py (24h request count)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
|
||||||
|
class TokenUsageRepository:
    """Postgres-backed replacement for Mongo ``token_usage_collection``.

    Records per-call token counts and answers the two time-range questions
    the legacy Mongo aggregations asked: total tokens and request count
    (both used for rate limiting and analytics).
    """

    def __init__(self, conn: Connection) -> None:
        # The caller owns the connection/transaction lifecycle.
        self._conn = conn

    @staticmethod
    def _range_filter(
        start: datetime,
        end: datetime,
        user_id: Optional[str],
        api_key: Optional[str],
    ) -> tuple[str, dict]:
        """Build the WHERE clause + bind params shared by the range queries.

        Extracted so ``sum_tokens_in_range`` and ``count_in_range`` cannot
        drift apart (the two previously duplicated this logic verbatim).
        Only hard-coded column predicates go into the SQL string; every
        value is passed as a bound parameter.
        """
        clauses = ["timestamp >= :start", "timestamp <= :end"]
        params: dict = {"start": start, "end": end}
        if user_id is not None:
            clauses.append("user_id = :user_id")
            params["user_id"] = user_id
        if api_key is not None:
            clauses.append("api_key = :api_key")
            params["api_key"] = api_key
        return " AND ".join(clauses), params

    def insert(
        self,
        *,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
        agent_id: Optional[str] = None,
        prompt_tokens: int = 0,
        generated_tokens: int = 0,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Record one call's token counts; ``timestamp`` defaults to ``now()``."""
        self._conn.execute(
            text(
                """
                INSERT INTO token_usage (user_id, api_key, agent_id, prompt_tokens, generated_tokens, timestamp)
                VALUES (
                    :user_id, :api_key,
                    CAST(:agent_id AS uuid),
                    :prompt_tokens, :generated_tokens,
                    COALESCE(:timestamp, now())
                )
                """
            ),
            {
                "user_id": user_id,
                "api_key": api_key,
                "agent_id": agent_id,
                "prompt_tokens": prompt_tokens,
                "generated_tokens": generated_tokens,
                "timestamp": timestamp,
            },
        )

    def sum_tokens_in_range(
        self,
        *,
        start: datetime,
        end: datetime,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> int:
        """Total (prompt + generated) tokens in the given time range."""
        where, params = self._range_filter(start, end, user_id, api_key)
        result = self._conn.execute(
            text(f"SELECT COALESCE(SUM(prompt_tokens + generated_tokens), 0) FROM token_usage WHERE {where}"),
            params,
        )
        return result.scalar()

    def count_in_range(
        self,
        *,
        start: datetime,
        end: datetime,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> int:
        """Count of token_usage rows in the given time range (for request limiting)."""
        where, params = self._range_filter(start, end, user_id, api_key)
        result = self._conn.execute(
            text(f"SELECT COUNT(*) FROM token_usage WHERE {where}"),
            params,
        )
        return result.scalar()
|
||||||
84
application/storage/db/repositories/user_logs.py
Normal file
84
application/storage/db/repositories/user_logs.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Repository for the ``user_logs`` table.
|
||||||
|
|
||||||
|
Covers every operation the legacy Mongo code performs on
|
||||||
|
``user_logs_collection``:
|
||||||
|
|
||||||
|
1. ``insert_one`` in logging.py (per-request activity log via
|
||||||
|
``_log_to_mongodb`` — note: the *Mongo* variable is confusingly named
|
||||||
|
``user_logs_collection`` but points at the ``user_logs`` Mongo
|
||||||
|
collection, not ``stack_logs``)
|
||||||
|
2. ``insert_one`` in answer/routes/base.py (per-stream log entry)
|
||||||
|
3. ``find`` with sort/skip/limit in analytics/routes.py (paginated log list)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class UserLogsRepository:
    """Postgres-backed replacement for Mongo ``user_logs_collection``."""

    def __init__(self, conn: Connection) -> None:
        # The caller owns the connection/transaction lifecycle.
        self._conn = conn

    def insert(
        self,
        *,
        user_id: Optional[str] = None,
        endpoint: Optional[str] = None,
        data: Optional[dict] = None,
        timestamp: Optional[datetime] = None,
    ) -> None:
        """Append one activity-log row; ``timestamp`` defaults to ``now()``.

        ``data`` is serialised with ``default=str`` so non-JSON values
        degrade to strings instead of raising during logging.
        """
        self._conn.execute(
            text(
                """
                INSERT INTO user_logs (user_id, endpoint, data, timestamp)
                VALUES (:user_id, :endpoint, CAST(:data AS jsonb), COALESCE(:timestamp, now()))
                """
            ),
            {
                "user_id": user_id,
                "endpoint": endpoint,
                "data": json.dumps(data, default=str) if data is not None else None,
                "timestamp": timestamp,
            },
        )

    def list_paginated(
        self,
        *,
        user_id: Optional[str] = None,
        api_key: Optional[str] = None,
        page: int = 1,
        page_size: int = 10,
    ) -> tuple[list[dict], bool]:
        """Return ``(rows, has_more)`` for the requested page.

        Mirrors the Mongo ``find(query).sort().skip().limit(page_size+1)``
        pattern used in analytics/routes.py.

        ``page`` values below 1 are clamped to 1 — otherwise a caller
        passing ``page=0`` (or negative) would produce a negative OFFSET,
        which Postgres rejects with an error.
        """
        page = max(page, 1)  # guard: negative OFFSET is a Postgres error
        clauses: list[str] = []
        params: dict = {"limit": page_size + 1, "offset": (page - 1) * page_size}
        if user_id is not None:
            clauses.append("user_id = :user_id")
            params["user_id"] = user_id
        if api_key is not None:
            # api_key lives inside the JSONB payload, not its own column.
            clauses.append("data->>'api_key' = :api_key")
            params["api_key"] = api_key
        where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
        result = self._conn.execute(
            text(
                f"SELECT * FROM user_logs {where} ORDER BY timestamp DESC LIMIT :limit OFFSET :offset"
            ),
            params,
        )
        rows = [row_to_dict(r) for r in result.fetchall()]
        # One extra row was fetched purely to detect whether a next page exists.
        has_more = len(rows) > page_size
        return rows[:page_size], has_more
|
||||||
114
application/storage/db/repositories/user_tools.py
Normal file
114
application/storage/db/repositories/user_tools.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""Repository for the ``user_tools`` table.
|
||||||
|
|
||||||
|
Covers every operation the legacy Mongo code performs on
|
||||||
|
``user_tools_collection``:
|
||||||
|
|
||||||
|
1. ``find`` by user in tools/routes.py and base.py (list all / active)
|
||||||
|
2. ``find_one`` by id in tools/routes.py and sharing.py (get single)
|
||||||
|
3. ``insert_one`` in tools/routes.py and mcp.py (create)
|
||||||
|
4. ``update_one`` in tools/routes.py and mcp.py (update fields)
|
||||||
|
5. ``delete_one`` in tools/routes.py (delete)
|
||||||
|
6. ``find`` by user+status in stream_processor.py and tool_executor.py (active tools)
|
||||||
|
7. ``find_one`` by user+name in mcp.py (upsert check)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
class UserToolsRepository:
    """Postgres-backed replacement for Mongo ``user_tools_collection``."""

    def __init__(self, conn: Connection) -> None:
        # The caller owns the connection/transaction lifecycle.
        self._conn = conn

    def create(self, user_id: str, name: str, *, config: Optional[dict] = None,
               custom_name: Optional[str] = None, display_name: Optional[str] = None,
               extra: Optional[dict] = None) -> dict:
        """Insert a new tool row. ``extra`` is merged into the config JSONB.

        The caller's ``config`` dict is copied before ``extra`` is merged
        in, so this method never mutates its arguments. (The previous
        ``cfg = config or {}; cfg.update(extra)`` wrote the merge back
        into the caller's dict whenever ``extra`` was provided.)
        """
        cfg = dict(config) if config else {}
        if extra:
            cfg.update(extra)
        result = self._conn.execute(
            text(
                """
                INSERT INTO user_tools (user_id, name, custom_name, display_name, config)
                VALUES (:user_id, :name, :custom_name, :display_name, CAST(:config AS jsonb))
                RETURNING *
                """
            ),
            {
                "user_id": user_id,
                "name": name,
                "custom_name": custom_name,
                "display_name": display_name,
                "config": json.dumps(cfg),
            },
        )
        return row_to_dict(result.fetchone())

    def get(self, tool_id: str, user_id: str) -> Optional[dict]:
        """Fetch one tool by id; ``None`` when missing or not owned."""
        result = self._conn.execute(
            text("SELECT * FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": tool_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """All tools for a user, oldest first."""
        result = self._conn.execute(
            text("SELECT * FROM user_tools WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, tool_id: str, user_id: str, fields: dict) -> None:
        """Update arbitrary fields on a tool row.

        ``fields`` maps column names to new values. Only ``name``,
        ``custom_name``, ``display_name``, and ``config`` are allowed;
        unknown keys are dropped. ``None`` values leave the existing
        column value unchanged (COALESCE semantics), so this cannot be
        used to null out a column.
        """
        allowed = {"name", "custom_name", "display_name", "config"}
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return
        params: dict = {
            "id": tool_id,
            "user_id": user_id,
            "name": filtered.get("name"),
            "custom_name": filtered.get("custom_name"),
            "display_name": filtered.get("display_name"),
            # Dicts are serialised here; pre-serialised strings pass through.
            "config": (
                json.dumps(filtered["config"])
                if "config" in filtered and isinstance(filtered["config"], dict)
                else filtered.get("config")
            ),
        }
        self._conn.execute(
            text(
                """
                UPDATE user_tools
                SET
                    name = COALESCE(:name, name),
                    custom_name = COALESCE(:custom_name, custom_name),
                    display_name = COALESCE(:display_name, display_name),
                    config = COALESCE(CAST(:config AS jsonb), config),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            params,
        )

    def delete(self, tool_id: str, user_id: str) -> bool:
        """Remove one tool; ``True`` when a row was actually deleted."""
        result = self._conn.execute(
            text("DELETE FROM user_tools WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": tool_id, "user_id": user_id},
        )
        return result.rowcount > 0
|
||||||
245
application/storage/db/repositories/users.py
Normal file
245
application/storage/db/repositories/users.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Repository for the ``users`` table.
|
||||||
|
|
||||||
|
Covers every operation the legacy Mongo code performs on
|
||||||
|
``users_collection``:
|
||||||
|
|
||||||
|
1. ``ensure_user_doc`` in ``application/api/user/base.py`` (upsert + get)
|
||||||
|
2. Pin/unpin agents in ``application/api/user/agents/routes.py`` (add/remove
|
||||||
|
on ``agent_preferences.pinned``)
|
||||||
|
3. Share accept/reject in ``application/api/user/agents/sharing.py`` (add/
|
||||||
|
bulk-remove on ``agent_preferences.shared_with_me``)
|
||||||
|
4. Cascade delete of an agent id from both arrays at once
|
||||||
|
|
||||||
|
All array mutations are implemented as single atomic UPDATE statements
|
||||||
|
using JSONB operators (``jsonb_set``, ``jsonb_array_elements``, ``@>``)
|
||||||
|
so there is no read-modify-write race between concurrent writers on the
|
||||||
|
same user row.
|
||||||
|
|
||||||
|
The repository takes a ``Connection`` and does not manage its own
|
||||||
|
transactions. Callers are responsible for wrapping writes in
|
||||||
|
``with engine.begin() as conn:`` (production) or the test fixture's
|
||||||
|
rollback-per-test connection (tests).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_PREFERENCES = '{"pinned": [], "shared_with_me": []}'
|
||||||
|
|
||||||
|
|
||||||
|
class UsersRepository:
    """Postgres-backed replacement for Mongo ``users_collection`` writes/reads."""

    def __init__(self, conn: Connection) -> None:
        # The connection is caller-owned; no commit/rollback happens here.
        self._conn = conn

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------
    def get(self, user_id: str) -> Optional[dict]:
        """Return the user row as a dict, or ``None`` if missing.

        Args:
            user_id: Auth-provider ``sub`` (opaque string).
        """
        result = self._conn.execute(
            text("SELECT * FROM users WHERE user_id = :user_id"),
            {"user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    # ------------------------------------------------------------------
    # Upsert
    # ------------------------------------------------------------------
    def upsert(self, user_id: str) -> dict:
        """Ensure a row exists for ``user_id`` and return it.

        Matches Mongo's ``find_one_and_update(..., $setOnInsert, upsert=True,
        return_document=AFTER)`` semantics: if the row exists, preferences
        are preserved untouched; if it doesn't, a new row is created with
        default preferences.

        The ``DO UPDATE SET user_id = EXCLUDED.user_id`` branch is a
        deliberate no-op that lets ``RETURNING *`` fire on both the insert
        and conflict paths (``DO NOTHING`` would suppress the returning).
        """
        result = self._conn.execute(
            text(
                """
                INSERT INTO users (user_id, agent_preferences)
                VALUES (:user_id, CAST(:default_prefs AS jsonb))
                ON CONFLICT (user_id) DO UPDATE
                SET user_id = EXCLUDED.user_id
                RETURNING *
                """
            ),
            {"user_id": user_id, "default_prefs": _DEFAULT_PREFERENCES},
        )
        return row_to_dict(result.fetchone())

    # ------------------------------------------------------------------
    # Pinned agents
    # ------------------------------------------------------------------
    def add_pinned(self, user_id: str, agent_id: str) -> None:
        """Idempotently append ``agent_id`` to ``agent_preferences.pinned``.

        Uses ``@>`` containment so a duplicate add is a no-op rather than a
        silent double-insert. The whole update is a single atomic statement
        so concurrent add_pinned calls on the same user cannot interleave
        into a read-modify-write race.
        """
        self._append_to_jsonb_array(user_id, "pinned", agent_id)

    def remove_pinned(self, user_id: str, agent_id: str) -> None:
        """Remove ``agent_id`` from ``agent_preferences.pinned`` if present."""
        self._remove_from_jsonb_array(user_id, "pinned", [agent_id])

    def remove_pinned_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
        """Remove every id in ``agent_ids`` from ``agent_preferences.pinned``.

        No-op if the list is empty. Unknown ids are silently ignored so
        callers can pass the full "stale" set without pre-filtering.
        """
        ids = list(agent_ids)
        if not ids:
            return
        self._remove_from_jsonb_array(user_id, "pinned", ids)

    # ------------------------------------------------------------------
    # Shared-with-me agents
    # ------------------------------------------------------------------
    def add_shared(self, user_id: str, agent_id: str) -> None:
        """Idempotently append ``agent_id`` to ``agent_preferences.shared_with_me``."""
        self._append_to_jsonb_array(user_id, "shared_with_me", agent_id)

    def remove_shared_bulk(self, user_id: str, agent_ids: Iterable[str]) -> None:
        """Bulk-remove from ``agent_preferences.shared_with_me``. Empty list is a no-op."""
        ids = list(agent_ids)
        if not ids:
            return
        self._remove_from_jsonb_array(user_id, "shared_with_me", ids)

    # ------------------------------------------------------------------
    # Combined removal — called when an agent is hard-deleted
    # ------------------------------------------------------------------
    def remove_agent_from_all(self, user_id: str, agent_id: str) -> None:
        """Remove ``agent_id`` from BOTH pinned and shared_with_me atomically.

        Mirrors the Mongo ``$pull`` that targets both nested array fields
        in one ``update_one`` — see ``application/api/user/agents/routes.py``
        around the agent-delete path.

        This intentionally duplicates (rather than calls twice) the logic
        of ``_remove_from_jsonb_array`` as nested ``jsonb_set`` calls so
        both arrays are rewritten inside one UPDATE statement.
        """
        # (elem #>> '{}') extracts each JSONB array element as plain text
        # so it can be compared against the bound :agent_id string.
        self._conn.execute(
            text(
                """
                UPDATE users
                SET
                    agent_preferences = jsonb_set(
                        jsonb_set(
                            agent_preferences,
                            '{pinned}',
                            COALESCE(
                                (
                                    SELECT jsonb_agg(elem)
                                    FROM jsonb_array_elements(
                                        COALESCE(agent_preferences->'pinned', '[]'::jsonb)
                                    ) AS elem
                                    WHERE (elem #>> '{}') != :agent_id
                                ),
                                '[]'::jsonb
                            )
                        ),
                        '{shared_with_me}',
                        COALESCE(
                            (
                                SELECT jsonb_agg(elem)
                                FROM jsonb_array_elements(
                                    COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
                                ) AS elem
                                WHERE (elem #>> '{}') != :agent_id
                            ),
                            '[]'::jsonb
                        )
                    ),
                    updated_at = now()
                WHERE user_id = :user_id
                """
            ),
            {"user_id": user_id, "agent_id": agent_id},
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _append_to_jsonb_array(self, user_id: str, key: str, agent_id: str) -> None:
        """Idempotent append of ``agent_id`` to ``agent_preferences.<key>``.

        The ``key`` argument is NOT user input — it's hard-coded by the
        calling method (``pinned`` / ``shared_with_me``). It goes into the
        SQL literal because ``jsonb_set`` requires a path literal, not a
        bind parameter. This is safe as long as callers never pass
        untrusted strings for ``key``.
        """
        # Allow-list guard backs up the "key is trusted" assumption above.
        if key not in ("pinned", "shared_with_me"):
            raise ValueError(f"unsupported jsonb key: {key!r}")
        self._conn.execute(
            text(
                f"""
                UPDATE users
                SET
                    agent_preferences = jsonb_set(
                        agent_preferences,
                        '{{{key}}}',
                        CASE
                            WHEN agent_preferences->'{key}' @> to_jsonb(CAST(:agent_id AS text))
                                THEN agent_preferences->'{key}'
                            ELSE
                                COALESCE(agent_preferences->'{key}', '[]'::jsonb)
                                || to_jsonb(CAST(:agent_id AS text))
                        END
                    ),
                    updated_at = now()
                WHERE user_id = :user_id
                """
            ),
            {"user_id": user_id, "agent_id": agent_id},
        )

    def _remove_from_jsonb_array(
        self, user_id: str, key: str, agent_ids: list[str]
    ) -> None:
        """Remove every id in ``agent_ids`` from ``agent_preferences.<key>``."""
        # Same allow-list guard as _append_to_jsonb_array: `key` is
        # interpolated into the SQL text, so it must come from this set.
        if key not in ("pinned", "shared_with_me"):
            raise ValueError(f"unsupported jsonb key: {key!r}")
        # Rebuilds the array via jsonb_agg over the elements NOT in
        # :agent_ids; COALESCE covers the "all elements removed" case,
        # where jsonb_agg over zero rows yields NULL.
        self._conn.execute(
            text(
                f"""
                UPDATE users
                SET
                    agent_preferences = jsonb_set(
                        agent_preferences,
                        '{{{key}}}',
                        COALESCE(
                            (
                                SELECT jsonb_agg(elem)
                                FROM jsonb_array_elements(
                                    COALESCE(agent_preferences->'{key}', '[]'::jsonb)
                                ) AS elem
                                WHERE NOT ((elem #>> '{{}}') = ANY(:agent_ids))
                            ),
                            '[]'::jsonb
                        )
                    ),
                    updated_at = now()
                WHERE user_id = :user_id
                """
            ),
            {"user_id": user_id, "agent_ids": agent_ids},
        )
|
||||||
170
application/storage/db/repositories/workflow_edges.py
Normal file
170
application/storage/db/repositories/workflow_edges.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""Repository for the ``workflow_edges`` table.
|
||||||
|
|
||||||
|
Covers bulk insert, find by version, and delete operations that the
|
||||||
|
workflow routes perform on ``workflow_edges_collection`` in Mongo.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import workflow_edges_table
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowEdgesRepository:
|
||||||
|
def __init__(self, conn: Connection) -> None:
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
workflow_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
edge_id: str,
|
||||||
|
from_node_id: str,
|
||||||
|
to_node_id: str,
|
||||||
|
*,
|
||||||
|
source_handle: str | None = None,
|
||||||
|
target_handle: str | None = None,
|
||||||
|
config: dict | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create a single edge.
|
||||||
|
|
||||||
|
``from_node_id`` and ``to_node_id`` are the Postgres **UUID PKs**
|
||||||
|
of the workflow_nodes rows (not user-provided node_id strings).
|
||||||
|
"""
|
||||||
|
values: dict = {
|
||||||
|
"workflow_id": workflow_id,
|
||||||
|
"graph_version": graph_version,
|
||||||
|
"edge_id": edge_id,
|
||||||
|
"from_node_id": from_node_id,
|
||||||
|
"to_node_id": to_node_id,
|
||||||
|
}
|
||||||
|
if source_handle is not None:
|
||||||
|
values["source_handle"] = source_handle
|
||||||
|
if target_handle is not None:
|
||||||
|
values["target_handle"] = target_handle
|
||||||
|
if config is not None:
|
||||||
|
values["config"] = config
|
||||||
|
|
||||||
|
stmt = pg_insert(workflow_edges_table).values(**values).returning(workflow_edges_table)
|
||||||
|
result = self._conn.execute(stmt)
|
||||||
|
return row_to_dict(result.fetchone())
|
||||||
|
|
||||||
|
def bulk_create(
|
||||||
|
self,
|
||||||
|
workflow_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
edges: list[dict],
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Insert multiple edges in one statement.
|
||||||
|
|
||||||
|
Each element must have ``edge_id``, ``from_node_id`` (UUID PK),
|
||||||
|
``to_node_id`` (UUID PK). Optional: ``source_handle``,
|
||||||
|
``target_handle``, ``config``.
|
||||||
|
"""
|
||||||
|
if not edges:
|
||||||
|
return []
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for e in edges:
|
||||||
|
rows.append({
|
||||||
|
"workflow_id": workflow_id,
|
||||||
|
"graph_version": graph_version,
|
||||||
|
"edge_id": e["edge_id"],
|
||||||
|
"from_node_id": e["from_node_id"],
|
||||||
|
"to_node_id": e["to_node_id"],
|
||||||
|
"source_handle": e.get("source_handle"),
|
||||||
|
"target_handle": e.get("target_handle"),
|
||||||
|
"config": e.get("config", {}),
|
||||||
|
})
|
||||||
|
|
||||||
|
stmt = pg_insert(workflow_edges_table).values(rows).returning(workflow_edges_table)
|
||||||
|
result = self._conn.execute(stmt)
|
||||||
|
return [row_to_dict(r) for r in result.fetchall()]
|
||||||
|
|
||||||
|
def find_by_version(
    self, workflow_id: str, graph_version: int,
) -> list[dict]:
    """List edges for a workflow/version, shaped to match the live API.

    Joins ``workflow_nodes`` twice so callers receive the user-provided
    node-id strings (``source_id``/``target_id``) that the Mongo code
    and the frontend use, not the internal node UUIDs. The raw UUID
    columns (``from_node_id``/``to_node_id``) are still included in
    case a caller needs them.
    """
    query = text(
        """
        SELECT e.*,
               fn.node_id AS source_id,
               tn.node_id AS target_id
        FROM workflow_edges e
        JOIN workflow_nodes fn ON fn.id = e.from_node_id
        JOIN workflow_nodes tn ON tn.id = e.to_node_id
        WHERE e.workflow_id = CAST(:wf_id AS uuid)
          AND e.graph_version = :ver
        ORDER BY e.edge_id
        """
    )
    res = self._conn.execute(query, {"wf_id": workflow_id, "ver": graph_version})
    return [row_to_dict(row) for row in res.fetchall()]
|
||||||
|
|
||||||
|
def resolve_node_id(
    self, workflow_id: str, graph_version: int, node_id: str,
) -> Optional[str]:
    """Look up the UUID PK of a node by its user-provided ``node_id``.

    Callers that receive edges in the frontend shape (``source_id`` /
    ``target_id`` are user-provided strings) use this helper to
    translate to the UUID PK before calling :meth:`create` /
    :meth:`bulk_create`.
    """
    query = text(
        "SELECT id FROM workflow_nodes "
        "WHERE workflow_id = CAST(:wf_id AS uuid) "
        "AND graph_version = :ver AND node_id = :node_id"
    )
    params = {"wf_id": workflow_id, "ver": graph_version, "node_id": node_id}
    row = self._conn.execute(query, params).fetchone()
    return str(row[0]) if row else None
|
||||||
|
|
||||||
|
def delete_by_workflow(self, workflow_id: str) -> int:
    """Delete every edge row belonging to the workflow; returns the count."""
    res = self._conn.execute(
        text(
            "DELETE FROM workflow_edges "
            "WHERE workflow_id = CAST(:wf_id AS uuid)"
        ),
        {"wf_id": workflow_id},
    )
    return res.rowcount
|
||||||
|
|
||||||
|
def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
    """Delete edges for exactly one graph version; returns the count."""
    res = self._conn.execute(
        text(
            "DELETE FROM workflow_edges "
            "WHERE workflow_id = CAST(:wf_id AS uuid) "
            "AND graph_version = :ver"
        ),
        {"wf_id": workflow_id, "ver": graph_version},
    )
    return res.rowcount
|
||||||
|
|
||||||
|
def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
    """Delete all edges for a workflow except the specified version."""
    res = self._conn.execute(
        text(
            "DELETE FROM workflow_edges "
            "WHERE workflow_id = CAST(:wf_id AS uuid) "
            "AND graph_version != :ver"
        ),
        {"wf_id": workflow_id, "ver": keep_version},
    )
    return res.rowcount
|
||||||
158
application/storage/db/repositories/workflow_nodes.py
Normal file
158
application/storage/db/repositories/workflow_nodes.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Repository for the ``workflow_nodes`` table.
|
||||||
|
|
||||||
|
Covers bulk insert, find by version, and delete operations that the
|
||||||
|
workflow routes perform on ``workflow_nodes_collection`` in Mongo.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import workflow_nodes_table
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodesRepository:
    """Postgres access layer for the ``workflow_nodes`` table.

    Mirrors the bulk insert, find-by-version and delete operations the
    workflow routes perform on ``workflow_nodes_collection`` in Mongo.
    """

    def __init__(self, conn: Connection) -> None:
        # Queries run on the caller-supplied connection; transaction
        # management stays with the caller.
        self._conn = conn

    def create(
        self,
        workflow_id: str,
        graph_version: int,
        node_id: str,
        node_type: str,
        *,
        title: str | None = None,
        description: str | None = None,
        position: dict | None = None,
        config: dict | None = None,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert one node row and return it as a dict.

        Keyword-only columns are written only when supplied, so DB-side
        defaults apply otherwise.
        """
        values: dict = {
            "workflow_id": workflow_id,
            "graph_version": graph_version,
            "node_id": node_id,
            "node_type": node_type,
        }
        optional = {
            "title": title,
            "description": description,
            "position": position,
            "config": config,
            "legacy_mongo_id": legacy_mongo_id,
        }
        values.update({key: val for key, val in optional.items() if val is not None})

        insert_stmt = (
            pg_insert(workflow_nodes_table)
            .values(**values)
            .returning(workflow_nodes_table)
        )
        res = self._conn.execute(insert_stmt)
        return row_to_dict(res.fetchone())

    def bulk_create(
        self,
        workflow_id: str,
        graph_version: int,
        nodes: list[dict],
    ) -> list[dict]:
        """Insert multiple nodes in one statement.

        Each element of ``nodes`` should have at least ``node_id`` and
        ``node_type``; optional keys: ``title``, ``description``,
        ``position``, ``config``.
        """
        if not nodes:
            return []

        rows = [
            {
                "workflow_id": workflow_id,
                "graph_version": graph_version,
                "node_id": node["node_id"],
                "node_type": node["node_type"],
                "title": node.get("title"),
                "description": node.get("description"),
                "position": node.get("position", {"x": 0, "y": 0}),
                "config": node.get("config", {}),
                "legacy_mongo_id": node.get("legacy_mongo_id"),
            }
            for node in nodes
        ]

        insert_stmt = (
            pg_insert(workflow_nodes_table)
            .values(rows)
            .returning(workflow_nodes_table)
        )
        res = self._conn.execute(insert_stmt)
        return [row_to_dict(row) for row in res.fetchall()]

    def find_by_version(
        self, workflow_id: str, graph_version: int,
    ) -> list[dict]:
        """Return all nodes for a workflow/version, ordered by ``node_id``."""
        query = text(
            "SELECT * FROM workflow_nodes "
            "WHERE workflow_id = CAST(:wf_id AS uuid) "
            "AND graph_version = :ver "
            "ORDER BY node_id"
        )
        res = self._conn.execute(query, {"wf_id": workflow_id, "ver": graph_version})
        return [row_to_dict(row) for row in res.fetchall()]

    def find_node(
        self, workflow_id: str, graph_version: int, node_id: str,
    ) -> Optional[dict]:
        """Find a single node by its user-provided ``node_id``."""
        query = text(
            "SELECT * FROM workflow_nodes "
            "WHERE workflow_id = CAST(:wf_id AS uuid) "
            "AND graph_version = :ver AND node_id = :nid"
        )
        params = {"wf_id": workflow_id, "ver": graph_version, "nid": node_id}
        row = self._conn.execute(query, params).fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
        """Find a node by the original Mongo ObjectId string."""
        row = self._conn.execute(
            text("SELECT * FROM workflow_nodes WHERE legacy_mongo_id = :legacy_id"),
            {"legacy_id": legacy_mongo_id},
        ).fetchone()
        return row_to_dict(row) if row is not None else None

    def delete_by_workflow(self, workflow_id: str) -> int:
        """Delete every node row belonging to the workflow; returns the count."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid)"
            ),
            {"wf_id": workflow_id},
        )
        return res.rowcount

    def delete_by_version(self, workflow_id: str, graph_version: int) -> int:
        """Delete nodes for exactly one graph version; returns the count."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version = :ver"
            ),
            {"wf_id": workflow_id, "ver": graph_version},
        )
        return res.rowcount

    def delete_other_versions(self, workflow_id: str, keep_version: int) -> int:
        """Delete all nodes for a workflow except the specified version."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflow_nodes "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "AND graph_version != :ver"
            ),
            {"wf_id": workflow_id, "ver": keep_version},
        )
        return res.rowcount
|
||||||
83
application/storage/db/repositories/workflow_runs.py
Normal file
83
application/storage/db/repositories/workflow_runs.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Repository for the ``workflow_runs`` table.
|
||||||
|
|
||||||
|
In Mongo, workflow_runs_collection only has ``insert_one`` — runs are
|
||||||
|
written once after workflow execution completes and never updated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import workflow_runs_table
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRunsRepository:
    """Postgres access layer for the ``workflow_runs`` table.

    Runs are written once after workflow execution completes and are
    never updated afterwards (mirrors Mongo's insert-only usage of
    ``workflow_runs_collection``).
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        workflow_id: str,
        user_id: str,
        status: str,
        *,
        inputs: dict | None = None,
        result: dict | None = None,
        steps: list | None = None,
        started_at=None,
        ended_at=None,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert one run row and return it as a dict.

        Omitted keyword columns are left to DB-side defaults.
        """
        values: dict = {
            "workflow_id": workflow_id,
            "user_id": user_id,
            "status": status,
        }
        optional = {
            "inputs": inputs,
            "result": result,
            "steps": steps,
            "started_at": started_at,
            "ended_at": ended_at,
            "legacy_mongo_id": legacy_mongo_id,
        }
        values.update({key: val for key, val in optional.items() if val is not None})

        insert_stmt = (
            pg_insert(workflow_runs_table)
            .values(**values)
            .returning(workflow_runs_table)
        )
        return row_to_dict(self._conn.execute(insert_stmt).fetchone())

    def get(self, run_id: str) -> Optional[dict]:
        """Fetch one run by its UUID PK, or ``None`` when absent."""
        row = self._conn.execute(
            text("SELECT * FROM workflow_runs WHERE id = CAST(:id AS uuid)"),
            {"id": run_id},
        ).fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
        """Fetch a workflow run by the original Mongo ObjectId string."""
        row = self._conn.execute(
            text("SELECT * FROM workflow_runs WHERE legacy_mongo_id = :legacy_id"),
            {"legacy_id": legacy_mongo_id},
        ).fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_workflow(self, workflow_id: str) -> list[dict]:
        """All runs for a workflow, newest ``started_at`` first."""
        rows = self._conn.execute(
            text(
                "SELECT * FROM workflow_runs "
                "WHERE workflow_id = CAST(:wf_id AS uuid) "
                "ORDER BY started_at DESC"
            ),
            {"wf_id": workflow_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]
|
||||||
125
application/storage/db/repositories/workflows.py
Normal file
125
application/storage/db/repositories/workflows.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Repository for the ``workflows`` table.
|
||||||
|
|
||||||
|
Covers CRUD on workflow metadata:
|
||||||
|
|
||||||
|
- create / get / list / update / delete
|
||||||
|
- graph version management
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import row_to_dict
|
||||||
|
from application.storage.db.models import workflows_table
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowsRepository:
    """Postgres access layer for the ``workflows`` table.

    Covers CRUD on workflow metadata (create / get / list / update /
    delete) plus graph version management.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        user_id: str,
        name: str,
        description: str | None = None,
        *,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Insert a workflow row and return it as a dict."""
        values: dict = {"user_id": user_id, "name": name}
        if description is not None:
            values["description"] = description
        if legacy_mongo_id is not None:
            values["legacy_mongo_id"] = legacy_mongo_id

        insert_stmt = (
            pg_insert(workflows_table)
            .values(**values)
            .returning(workflows_table)
        )
        return row_to_dict(self._conn.execute(insert_stmt).fetchone())

    def get(self, workflow_id: str, user_id: str) -> Optional[dict]:
        """Fetch a workflow scoped to its owner, or ``None``."""
        row = self._conn.execute(
            text(
                "SELECT * FROM workflows "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": workflow_id, "user_id": user_id},
        ).fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_id(self, workflow_id: str) -> Optional[dict]:
        """Fetch a workflow by ID without user check (for internal use)."""
        row = self._conn.execute(
            text("SELECT * FROM workflows WHERE id = CAST(:id AS uuid)"),
            {"id": workflow_id},
        ).fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: str | None = None,
    ) -> Optional[dict]:
        """Fetch a workflow by its original Mongo ObjectId string."""
        sql = "SELECT * FROM workflows WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            # Optionally narrow the lookup to a single owner.
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        row = self._conn.execute(text(sql), params).fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        """All workflows owned by ``user_id``, newest first."""
        rows = self._conn.execute(
            text(
                "SELECT * FROM workflows "
                "WHERE user_id = :user_id ORDER BY created_at DESC"
            ),
            {"user_id": user_id},
        ).fetchall()
        return [row_to_dict(row) for row in rows]

    def update(self, workflow_id: str, user_id: str, fields: dict) -> bool:
        """Update whitelisted columns; returns ``True`` when a row changed.

        Only ``name``, ``description`` and ``current_graph_version`` may
        be set. Column names are interpolated into the SQL, which is safe
        only because they come from this fixed allowlist — never from
        caller input; values are still bound as parameters.
        """
        allowed = {"name", "description", "current_graph_version"}
        filtered = {key: val for key, val in fields.items() if key in allowed}
        if not filtered:
            return False

        set_parts = [f"{col} = :{col}" for col in filtered]
        set_parts.append("updated_at = now()")
        params = {**filtered, "id": workflow_id, "user_id": user_id}

        sql = (
            f"UPDATE workflows SET {', '.join(set_parts)} "
            "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
        )
        res = self._conn.execute(text(sql), params)
        return res.rowcount > 0

    def increment_graph_version(self, workflow_id: str, user_id: str) -> Optional[int]:
        """Atomically increment ``current_graph_version`` and return the new value."""
        row = self._conn.execute(
            text(
                "UPDATE workflows "
                "SET current_graph_version = current_graph_version + 1, "
                "    updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id "
                "RETURNING current_graph_version"
            ),
            {"id": workflow_id, "user_id": user_id},
        ).fetchone()
        return row[0] if row else None

    def delete(self, workflow_id: str, user_id: str) -> bool:
        """Delete one workflow scoped to its owner; ``True`` if a row was removed."""
        res = self._conn.execute(
            text(
                "DELETE FROM workflows "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": workflow_id, "user_id": user_id},
        )
        return res.rowcount > 0
|
||||||
@@ -21,10 +21,19 @@ class LocalStorage(BaseStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_full_path(self, path: str) -> str:
|
def _get_full_path(self, path: str) -> str:
|
||||||
"""Get absolute path by combining base_dir and path."""
|
"""Get absolute path by combining base_dir and path.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the resolved path escapes base_dir (path traversal).
|
||||||
|
"""
|
||||||
if os.path.isabs(path):
|
if os.path.isabs(path):
|
||||||
return path
|
resolved = os.path.realpath(path)
|
||||||
return os.path.join(self.base_dir, path)
|
else:
|
||||||
|
resolved = os.path.realpath(os.path.join(self.base_dir, path))
|
||||||
|
base = os.path.realpath(self.base_dir)
|
||||||
|
if not resolved.startswith(base + os.sep) and resolved != base:
|
||||||
|
raise ValueError(f"Path traversal detected: {path}")
|
||||||
|
return resolved
|
||||||
|
|
||||||
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
def save_file(self, file_data: BinaryIO, path: str, **kwargs) -> dict:
|
||||||
"""Save a file to local storage."""
|
"""Save a file to local storage."""
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user