Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-08 06:53:40 +00:00)

Compare commits: 0.16.0...dependabot (108 commits)
Commits in this comparison (SHA1 only):
acfa972c40, 3ceabed8ad, 422a4b139e, e85935eed0, 6a69b8aca0, 33c2cc9660, 175d4d5a68, 6c3ead1071,
d23f88f825, da1df515f7, 671a9d75ad, 1c829667ff, 3ab0ebb16d, 988c4a5a15, 01db8b2c41, ef19da9516,
cc1275c3f9, 14c2f4890f, b3aec36aa2, 50f62beaeb, 423b4c6494, 54f615c59d, 223b3de66e, 4db9622ef5,
e8d1bbfb68, aff1345ae4, ee430aff1e, 81b6ee5daa, ebb7938d1b, 7f6360b4ff, c68f18a0ae, 684b29e73c,
a1efea81d0, 9eb34262e0, 951bdb8365, c18f85a050, 5ecb174567, ed7212d016, f82acdab5d, 361aebc34c,
bf194c1a0f, 54c396750b, 9adebfec69, 92c321f163, e3d36b9e52, 8950e11208, 5de0132a65, b92ca91512,
8e0b2844a2, 0969db5e30, af335a27e8, 0ae3139284, 7529ca3dd6, 1b813320f1, 02012e9a0b, c2f027265a,
0ae615c10e, 881d0da344, 1376de6bae, 362ebfcc0a, bc77eed3d8, 1f346588e7, 2fed5c882b, aa938d76d7,
2940628aa6, 7f23928134, 20e17c84c7, 389ddf6068, 1e2443fb90, 6387bd1892, 7d22724d1c, f6f12f6895,
934127f323, 1780e3cc91, 5e7fab2f34, 92ae76f95e, 18755bdd9b, 0f20adcbf4, 18e2a829c9, cd44501a71,
f8ebdf3fd4, 7c6fca18ad, 5fab798707, cb30a24e05, 530761d08c, 73fbc28744, b5b6538762, a9761061fc,
9388996a15, 875868b7e5, 502819ae52, cada1a44fc, 6192767451, 5c3e6eca54, 59d9d4ac50, 3931ccccee,
55717043f6, ececcb8b17, 420e9d3dd5, 749eed3d0b, bd03a513e3, fcdb4fb5e8, e787c896eb, 23aeaff5db,
689dd79597, 0c15af90b1, cdd6ff6557, cdb71a54f0
@@ -34,3 +34,9 @@ MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
 #or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
 #Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
 MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
+
+# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
+# Standard Postgres URI — `postgres://` and `postgresql://` both work.
+# Leave unset while the migration is still being rolled out; the app will
+# fall back to MongoDB for user data until POSTGRES_URI is configured.
+# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
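A minimal sketch of the fallback behaviour the comments above describe; the selection actually lives in the application's settings/storage layer, so the function name and environment lookup here are illustrative assumptions only.

```python
import os


def user_data_backend() -> str:
    """Illustrative only: use Postgres for user data when POSTGRES_URI is set,
    otherwise fall back to MongoDB, as the .env comment describes."""
    return "postgres" if os.environ.get("POSTGRES_URI") else "mongodb"
```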
.github/INCIDENT_RESPONSE.md (new file, vendored, 99 lines)
@@ -0,0 +1,99 @@
# DocsGPT Incident Response Plan (IRP)

This playbook describes how maintainers respond to confirmed or suspected security incidents.

- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)

## Severity

| Severity | Definition | Typical examples |
|---|---|---|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | Key leakage; prompt injection enabling cross-tenant access. |
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |

## Response workflow

### 1) Triage (target: initial response within 48 hours)

1. Acknowledge report.
2. Validate on latest release and `main`.
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.

### 2) Investigation

1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
3. Request a CVE through GHSA for **Medium+** issues.

### 3) Containment, fix, and disclosure

1. Implement and test fix in private security workflow (GHSA private fork/branch).
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.

### 4) Post-incident

1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
3. Track follow-up hardening actions with owners/dates.
4. Update this IRP and related runbooks as needed.

## Scenario playbooks

### Supply-chain compromise

1. Freeze releases and investigate blast radius.
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
4. Publish advisory with exact affected versions and required user actions.

### Data exposure

1. Determine scope (users, documents, keys, logs, time window).
2. Disable affected path or hotfix immediately for managed cloud.
3. Notify affected users with concrete remediation steps (for example, rotate keys).
4. Continue through standard fix/disclosure workflow.

### Critical regression with security impact

1. Identify introducing change (`git bisect` if needed).
2. Publish workaround within 24 hours (for example, pin to known-good version).
3. Ship patch release with regression test and close incident with public summary.

## AI-specific guidance

Treat confirmed AI-specific abuse as security incidents:

- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
- Cross-tenant retrieval/isolation failure -> **High**
- API key disclosure in output -> **High**

## Secret rotation quick reference

| Secret | Standard rotation action |
|---|---|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
| Database credentials | Rotate in DB platform; redeploy with new secrets |
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |

## Maintenance

Review this document:

- after every **Critical/High** incident, and
- at least annually.

Changes should be proposed via pull request to `main`.
.github/THREAT_MODEL.md (new file, vendored, 144 lines)
@@ -0,0 +1,144 @@
# DocsGPT Public Threat Model

**Classification:** Public
**Last updated:** 2026-04-15
**Applies to:** Open-source and self-hosted DocsGPT deployments

## 1) Overview

DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.

Core components:
- Backend API (`application/`)
- Workers/ingestion (`application/worker.py` and related modules)
- Datastores (MongoDB/Redis/vector stores)
- Frontend (`frontend/`)
- Optional extensions/integrations (`extensions/`)

## 2) Scope and assumptions

In scope:
- Application-level threats in this repository.
- Local and internet-exposed self-hosted deployments.

Assumptions:
- Internet-facing instances enable auth and use strong secrets.
- Datastores/internal services are not publicly exposed.

Out of scope:
- Cloud hardware/provider compromise.
- Security guarantees of external LLM vendors.
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).

## 3) Security objectives

- Protect document/conversation confidentiality.
- Preserve integrity of prompts, agents, tools, and indexed data.
- Maintain API/worker availability.
- Enforce tenant isolation in authenticated deployments.

## 4) Assets

- Documents, attachments, chunks/embeddings, summaries.
- Conversations, agents, workflows, prompt templates.
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
- Operational capacity (worker throughput, queue depth, model quota/cost).

## 5) Trust boundaries and untrusted input

Trust boundaries:
- Internet ↔ Frontend
- Frontend ↔ Backend API
- Backend ↔ Workers/internal APIs
- Backend/workers ↔ Datastores
- Backend ↔ External LLM/connectors/remote URLs

Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.

## 6) Main attack surfaces

1. Auth/authz paths and sharing tokens.
2. File upload + parsing pipeline.
3. Remote URL fetching and connectors (SSRF risk).
4. Agent/tool execution from LLM output.
5. Template/workflow rendering.
6. Frontend rendering + token storage.
7. Internal service endpoints (`INTERNAL_KEY`).
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).

## 7) Key threats and expected mitigations

### A. Auth/authz misconfiguration

- Threat: weak/no auth or leaked tokens leads to broad data access.
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.

### B. Untrusted file ingestion

- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.

### C. SSRF/outbound abuse

- Threat: URL loaders/tools access private/internal/metadata endpoints.
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists (illustrative sketch below).
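A minimal sketch of the private/link-local range check these mitigations describe, using only the standard library. The function name and policy are illustrative assumptions, not DocsGPT's actual `validate_url` implementation, and a real mitigation must also re-run the check after every redirect.

```python
import ipaddress
import socket
from urllib.parse import urlparse


def is_private_destination(url: str) -> bool:
    """Resolve the URL's host and flag loopback, private, link-local, or
    reserved addresses (the classic SSRF targets: 169.254.169.254, 10.x, ::1)."""
    host = urlparse(url).hostname or ""
    try:
        infos = socket.getaddrinfo(host, None)
    except socket.gaierror:
        return True  # unresolvable host: treat as unsafe
    for info in infos:
        ip = ipaddress.ip_address(info[4][0])
        if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved:
            return True
    return False
```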
### D. Prompt injection + tool abuse

- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
- Assumption: the model cannot be relied on to "choose correctly" under adversarial input.
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.

### E. Dangerous tool capability chaining (SQL/API/MCP)

- Threat: write-capable SQL credentials allow destructive queries.
- Threat: API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
- Threat: remote MCP tools may expose privileged operations.
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging (see the allowlist sketch below).
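A minimal sketch of the destination-allowlist idea from these mitigations. The host list, function name, and error type are illustrative assumptions, not DocsGPT's actual tool-policy API.

```python
from urllib.parse import urlparse

# Hypothetical per-tool policy: hosts the generic API tool is allowed to call.
API_TOOL_ALLOWED_HOSTS = {"api.example-crm.internal", "status.example.com"}


def check_api_tool_destination(url: str) -> None:
    """Refuse outbound calls to hosts the admin has not explicitly allowlisted."""
    host = (urlparse(url).hostname or "").lower()
    if host not in API_TOOL_ALLOWED_HOSTS:
        raise PermissionError(f"API tool destination not allowlisted: {host!r}")
```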
### F. Frontend/XSS + token theft

- Threat: XSS can steal local tokens and call APIs.
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.

### G. Internal endpoint exposure

- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
- Mitigations: fail closed, require strong random keys, keep internal APIs private.

### H. DoS and cost abuse

- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.

## 8) Example attacker stories

- Internet-exposed deployment runs with weak/no auth and receives unauthorized data access/abuse.
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
- Crafted archive attempts path traversal during extraction.
- Malicious URL/redirect chain targets internal services.
- Poisoned document causes data exfiltration through tool calls.
- Over-privileged SQL/API/MCP tool performs destructive side effects.

## 9) Severity calibration

- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
- **Medium:** DoS/cost amplification and non-critical information disclosure.
- **Low:** minor hardening gaps with limited impact.

## 10) Baseline controls for public deployments

1. Enforce authentication and secure defaults.
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys).
3. Restrict CORS and front the API with a hardened proxy.
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
5. Enforce URL+redirect SSRF protections and egress restrictions.
6. Apply upload/archive/parsing hardening.
7. Require least-privilege tool credentials and auditable tool execution.
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
9. Keep dependencies/images patched and scanned.
10. Validate multi-tenant isolation with explicit tests.

## 11) Maintenance

Review this model after major auth, ingestion, connector, tool, or workflow changes.

## References

- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
- [DocsGPT SECURITY.md](../SECURITY.md)
@@ -1,46 +1,80 @@
(Vale vocabulary accept list: the previous 46 unsorted entries are replaced by this extended, roughly alphabetized list, one entry per line in the file:)
Agentic, Anthropic's, api, APIs, Atlassian, automations, autoescaping, Autoescaping, backfill, backfills,
bool, boolean, brave_web_search, chatbot, Chatwoot, config, configs, CSVs, dev, diarization,
Docling, docsgpt, docstrings, Entra, env, enqueues, EOL, ESLint, feedbacks, Figma,
GPUs, Groq, hardcode, hardcoding, Idempotency, JSONPath, kubectl, Lightsail, llama_cpp, llm,
LLM, LLMs, LMDeploy, Milvus, Mixtral, namespace, namespaces, needs_auth, Nextra, Novita,
npm, OAuth, Ollama, opencode, parsable, passthrough, PDFs, pgvector, Postgres, Premade,
Pydantic, pytest, Qdrant, qdrant, Repo, repo, Sanitization, SDKs, SGLang, Shareability,
Signup, Supabase, UIs, uncomment, URl, vectorstore, Vite, VSCode, VSCode's, widget's
.github/workflows/bandit.yaml (vendored, 2 changed lines)
@@ -21,7 +21,7 @@ jobs:
         uses: actions/checkout@v4

       - name: Set up Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v6
         with:
           python-version: '3.12'
       - name: Install dependencies
.github/workflows/pytest.yml (vendored, 2 changed lines)
@@ -14,7 +14,7 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v6
        with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
.github/workflows/vale.yml (vendored, 22 changed lines)
@@ -11,7 +11,6 @@ on:

 permissions:
   contents: read
-  pull-requests: write

 jobs:
   vale:
@@ -20,11 +19,16 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v4

-      - name: Vale linter
-        uses: errata-ai/vale-action@v2
-        with:
-          files: docs
-          fail_on_error: false
-          version: 3.0.5
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      - name: Install Vale
+        run: |
+          curl -fsSL -o vale.tar.gz \
+            https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
+          tar -xzf vale.tar.gz
+          sudo mv vale /usr/local/bin/vale
+          vale --version
+
+      - name: Sync Vale packages
+        run: vale sync
+
+      - name: Run Vale
+        run: vale --minAlertLevel=error docs
.github/workflows/zizmor.yml (new file, vendored, 25 lines)
@@ -0,0 +1,25 @@
name: GitHub Actions Security Analysis

on:
  push:
    branches: ["master"]
  pull_request:
    branches: ["**"]

permissions: {}

jobs:
  zizmor:
    runs-on: ubuntu-latest

    permissions:
      security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.

    steps:
      - name: Checkout repository
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false

      - name: Run zizmor 🌈
        uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
.gitignore (vendored, 11 changed lines)
@@ -108,6 +108,8 @@ celerybeat.pid
 # Environments
 .env
 .venv
+# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
+CLAUDE.md
 env/
 venv/
 ENV/
@@ -181,5 +183,14 @@ application/vectors/

 node_modules/
 .vscode/settings.json
+.vscode/sftp.json
 /models/
 model/
+
+# E2E test artifacts
+.e2e-tmp/
+/tmp/docsgpt-e2e/
+tests/e2e/node_modules/
+tests/e2e/playwright-report/
+tests/e2e/test-results/
+tests/e2e/.e2e-last-run.json
@@ -1,5 +1,7 @@
 MinAlertLevel = warning
 StylesPath = .github/styles
+Vocab = DocsGPT
+
 [*.{md,mdx}]
 BasedOnStyles = DocsGPT

AGENTS.md (10 changed lines)
@@ -10,9 +10,15 @@
 For feature work, do **not** assume the environment needs to be recreated.

 - Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
-- Check whether MongoDB is already running.
+- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
 - Check whether Redis is already running.
-- Reuse what is already working. Do not stop or recreate MongoDB, Redis, or the Python environment unless the task is environment setup or troubleshooting.
+- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.

+> MongoDB is **not** required for the default install. It is only needed if
+> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
+> or is running the one-shot `scripts/db/backfill.py` to migrate existing
+> user data from the legacy Mongo-based install. In those cases, `pymongo`
+> is available as an optional extra, not a core dependency.
+
 ## Normal local development commands

SECURITY.md (12 changed lines)
@@ -2,9 +2,7 @@

 ## Supported Versions

-Supported Versions:
+Security patches target the latest release and the `main` branch. We recommend always running the most recent version.

-Currently, we support security patches by committing changes and bumping the version published on Github.
-
 ## Reporting a Vulnerability

@@ -14,7 +12,11 @@ https://github.com/arc53/DocsGPT/security
 Then click **Report a vulnerability**.

-Alternatively:
+Alternatively, email us at: security@arc53.com

-security@arc53.com
+We aim to acknowledge reports within 48 hours.
+
+## Incident Handling
+
+For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
@@ -3,13 +3,12 @@ import uuid
 from collections import Counter
 from typing import Dict, List, Optional, Tuple

-from bson.objectid import ObjectId
-
 from application.agents.tools.tool_action_parser import ToolActionParser
 from application.agents.tools.tool_manager import ToolManager
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
 from application.security.encryption import decrypt_credentials
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.user_tools import UserToolsRepository
+from application.storage.db.session import db_readonly

 logger = logging.getLogger(__name__)

@@ -51,30 +50,28 @@ class ToolExecutor:
         return tools

     def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
-        mongo = MongoDB.get_client()
-        db = mongo[settings.MONGO_DB_NAME]
-        agents_collection = db["agents"]
-        tools_collection = db["user_tools"]
-
-        agent_data = agents_collection.find_one({"key": api_key})
-        tool_ids = agent_data.get("tools", []) if agent_data else []
-
-        tools = (
-            tools_collection.find(
-                {"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
-            )
-            if tool_ids
-            else []
-        )
-        tools = list(tools)
-        return {str(tool["_id"]): tool for tool in tools} if tools else {}
+        # Per-operation session: the answer pipeline spans a long-lived
+        # generator; wrapping it in a single connection would pin a PG
+        # conn for the whole stream. Open, fetch, close.
+        with db_readonly() as conn:
+            agent_data = AgentsRepository(conn).find_by_key(api_key)
+            tool_ids = agent_data.get("tools", []) if agent_data else []
+            if not tool_ids:
+                return {}
+            tools_repo = UserToolsRepository(conn)
+            tools: List[Dict] = []
+            owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
+            for tid in tool_ids:
+                row = None
+                if owner:
+                    row = tools_repo.get_any(str(tid), owner)
+                if row is not None:
+                    tools.append(row)
+        return {str(tool["id"]): tool for tool in tools} if tools else {}

     def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
-        mongo = MongoDB.get_client()
-        db = mongo[settings.MONGO_DB_NAME]
-        user_tools_collection = db["user_tools"]
-        user_tools = user_tools_collection.find({"user": user, "status": True})
-        user_tools = list(user_tools)
+        with db_readonly() as conn:
+            user_tools = UserToolsRepository(conn).list_active_for_user(user)
         return {str(i): tool for i, tool in enumerate(user_tools)}

     def merge_client_tools(
@@ -354,6 +351,17 @@ class ToolExecutor:
             headers=headers, query_params=query_params,
         )

+        if tool is None:
+            error_message = (
+                f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
+                "missing 'id' on tool row."
+            )
+            logger.error(error_message)
+            tool_call_data["result"] = error_message
+            yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
+            self.tool_calls.append(tool_call_data)
+            return error_message, call_id
+
         resolved_arguments = (
             {"query_params": query_params, "headers": headers, "body": body}
             if tool_data["name"] == "api_tool"
@@ -440,7 +448,16 @@ class ToolExecutor:
             tool_config.update(decrypted)
             tool_config["auth_credentials"] = decrypted
             tool_config.pop("encrypted_credentials", None)
-        tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
+        row_id = tool_data.get("id")
+        if not row_id:
+            logger.error(
+                "Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
+                "skipping load to avoid binding a non-UUID downstream.",
+                tool_data.get("name"),
+                tool_id,
+            )
+            return None
+        tool_config["tool_id"] = str(row_id)
         if self.conversation_id:
             tool_config["conversation_id"] = self.conversation_id
         if tool_data["name"] == "mcp_tool":
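The comments in the hunk above explain why the executor opens a short per-operation read session instead of holding one connection across the whole answer stream. A self-contained sketch of that difference, assuming only that `db_readonly()` is a context manager yielding a connection (as it is used above); the placeholder helpers are illustrative, not DocsGPT APIs.

```python
from contextlib import contextmanager


@contextmanager
def db_readonly():
    """Placeholder standing in for application.storage.db.session.db_readonly."""
    conn = object()  # pretend pooled connection
    try:
        yield conn
    finally:
        pass  # connection would be returned to the pool here


def fetch_context(conn, query):  # placeholder for the repository lookups above
    return [f"row for {query}"]


def call_llm(rows):  # placeholder for the streaming LLM call
    yield from rows


# Anti-pattern: the connection stays checked out while the stream is consumed.
def stream_answer_holding_conn(query):
    with db_readonly() as conn:
        rows = fetch_context(conn, query)
        yield from call_llm(rows)  # conn still held here


# Pattern used in the hunk above: open, fetch, close, then stream.
def stream_answer_short_session(query):
    with db_readonly() as conn:
        rows = fetch_context(conn, query)
    yield from call_llm(rows)  # conn already released
```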
@@ -73,7 +73,7 @@ class BraveSearchTool(Tool):
             "X-Subscription-Token": self.token,
         }

-        response = requests.get(url, params=params, headers=headers)
+        response = requests.get(url, params=params, headers=headers, timeout=100)

         if response.status_code == 200:
             return {
@@ -118,7 +118,7 @@ class BraveSearchTool(Tool):
             "X-Subscription-Token": self.token,
         }

-        response = requests.get(url, params=params, headers=headers)
+        response = requests.get(url, params=params, headers=headers, timeout=100)

         if response.status_code == 200:
             return {
|||||||
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
|
|||||||
returns price in USD.
|
returns price in USD.
|
||||||
"""
|
"""
|
||||||
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
|
||||||
response = requests.get(url)
|
response = requests.get(url, timeout=100)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if currency.upper() in data:
|
if currency.upper() in data:
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ class InternalSearchTool(Tool):
|
|||||||
return self._retriever
|
return self._retriever
|
||||||
|
|
||||||
def _get_directory_structure(self) -> Optional[Dict]:
|
def _get_directory_structure(self) -> Optional[Dict]:
|
||||||
"""Load directory structure from MongoDB for the configured sources."""
|
"""Load directory structure from Postgres for the configured sources."""
|
||||||
if self._dir_structure_loaded:
|
if self._dir_structure_loaded:
|
||||||
return self._directory_structure
|
return self._directory_structure
|
||||||
|
|
||||||
@@ -59,35 +59,39 @@ class InternalSearchTool(Tool):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from bson.objectid import ObjectId
|
# Per-operation session: this tool runs inside the answer
|
||||||
from application.core.mongo_db import MongoDB
|
# generator hot path, so we open a short-lived read
|
||||||
|
# connection for the batch lookup and release immediately.
|
||||||
mongo = MongoDB.get_client()
|
from application.storage.db.repositories.sources import (
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
SourcesRepository,
|
||||||
sources_collection = db["sources"]
|
)
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
|
|
||||||
if isinstance(active_docs, str):
|
if isinstance(active_docs, str):
|
||||||
active_docs = [active_docs]
|
active_docs = [active_docs]
|
||||||
|
|
||||||
|
decoded_token = self.config.get("decoded_token") or {}
|
||||||
|
user_id = decoded_token.get("sub") if decoded_token else None
|
||||||
|
|
||||||
merged_structure = {}
|
merged_structure = {}
|
||||||
for doc_id in active_docs:
|
with db_readonly() as conn:
|
||||||
try:
|
repo = SourcesRepository(conn)
|
||||||
source_doc = sources_collection.find_one(
|
for doc_id in active_docs:
|
||||||
{"_id": ObjectId(doc_id)}
|
try:
|
||||||
)
|
source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
|
||||||
if not source_doc:
|
if not source_doc:
|
||||||
continue
|
continue
|
||||||
dir_str = source_doc.get("directory_structure")
|
dir_str = source_doc.get("directory_structure")
|
||||||
if dir_str:
|
if dir_str:
|
||||||
if isinstance(dir_str, str):
|
if isinstance(dir_str, str):
|
||||||
dir_str = json.loads(dir_str)
|
dir_str = json.loads(dir_str)
|
||||||
source_name = source_doc.get("name", doc_id)
|
source_name = source_doc.get("name", doc_id)
|
||||||
if len(active_docs) > 1:
|
if len(active_docs) > 1:
|
||||||
merged_structure[source_name] = dir_str
|
merged_structure[source_name] = dir_str
|
||||||
else:
|
else:
|
||||||
merged_structure = dir_str
|
merged_structure = dir_str
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
|
logger.debug(f"Could not load dir structure for {doc_id}: {e}")
|
||||||
|
|
||||||
self._directory_structure = merged_structure if merged_structure else None
|
self._directory_structure = merged_structure if merged_structure else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -357,32 +361,48 @@ INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)
|
|||||||
|
|
||||||
|
|
||||||
def sources_have_directory_structure(source: Dict) -> bool:
|
def sources_have_directory_structure(source: Dict) -> bool:
|
||||||
"""Check if any of the active sources have directory_structure in MongoDB."""
|
"""Check if any of the active sources have a ``directory_structure`` row."""
|
||||||
active_docs = source.get("active_docs", [])
|
active_docs = source.get("active_docs", [])
|
||||||
if not active_docs:
|
if not active_docs:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from bson.objectid import ObjectId
|
# TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
|
||||||
from application.core.mongo_db import MongoDB
|
# scoping, but callers in the agent build path don't always
|
||||||
|
# thread the decoded token through here. Use a direct
|
||||||
|
# short-lived SQL lookup instead of the repo until the call
|
||||||
|
# sites are updated to propagate user context.
|
||||||
|
from sqlalchemy import text as _text
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
from application.storage.db.session import db_readonly
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
sources_collection = db["sources"]
|
|
||||||
|
|
||||||
if isinstance(active_docs, str):
|
if isinstance(active_docs, str):
|
||||||
active_docs = [active_docs]
|
active_docs = [active_docs]
|
||||||
|
|
||||||
for doc_id in active_docs:
|
with db_readonly() as conn:
|
||||||
try:
|
for doc_id in active_docs:
|
||||||
source_doc = sources_collection.find_one(
|
try:
|
||||||
{"_id": ObjectId(doc_id)},
|
value = str(doc_id)
|
||||||
{"directory_structure": 1},
|
if len(value) == 36 and "-" in value:
|
||||||
)
|
row = conn.execute(
|
||||||
if source_doc and source_doc.get("directory_structure"):
|
_text(
|
||||||
return True
|
"SELECT directory_structure FROM sources "
|
||||||
except Exception:
|
"WHERE id = CAST(:id AS uuid)"
|
||||||
continue
|
),
|
||||||
|
{"id": value},
|
||||||
|
).fetchone()
|
||||||
|
else:
|
||||||
|
row = conn.execute(
|
||||||
|
_text(
|
||||||
|
"SELECT directory_structure FROM sources "
|
||||||
|
"WHERE legacy_mongo_id = :lid"
|
||||||
|
),
|
||||||
|
{"lid": value},
|
||||||
|
).fetchone()
|
||||||
|
if row is not None and row[0]:
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Could not check directory structure: {e}")
|
logger.debug(f"Could not check directory structure: {e}")
|
||||||
|
|
||||||
|
|||||||
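The raw-SQL branch above routes each lookup by whether the identifier looks like a native Postgres UUID or a legacy Mongo ObjectId string. A self-contained sketch of that routing heuristic; the 36-characters-with-hyphens test comes from the hunk, while the helper name and the stricter `uuid.UUID` parse are illustrative assumptions.

```python
import uuid


def classify_source_id(value: str) -> str:
    """Return which column a source identifier should match against:
    'id' for native Postgres UUIDs, 'legacy_mongo_id' for 24-hex ObjectIds
    or anything else carried over from the Mongo era."""
    value = str(value)
    if len(value) == 36 and "-" in value:
        try:
            uuid.UUID(value)  # stricter than the length/hyphen check alone
            return "id"
        except ValueError:
            pass
    return "legacy_mongo_id"


assert classify_source_id("6c9f2f0e-8f2a-4c67-9c1e-2b1a9f6d4e55") == "id"
assert classify_source_id("64fe1c2ab0f86a5d3c9e7712") == "legacy_mongo_id"
```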
@@ -22,16 +22,12 @@ from redis import Redis
 from application.agents.tools.base import Tool
 from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
 from application.cache import get_redis_instance
-from application.core.mongo_db import MongoDB
 from application.core.settings import settings
 from application.core.url_validation import SSRFError, validate_url
 from application.security.encryption import decrypt_credentials

 logger = logging.getLogger(__name__)

-mongo = MongoDB.get_client()
-db = mongo[settings.MONGO_DB_NAME]
-
 _mcp_clients_cache = {}

@@ -161,7 +157,6 @@ class MCPTool(Tool):
                 scopes=self.oauth_scopes,
                 redis_client=redis_client,
                 redirect_uri=self.redirect_uri,
-                db=db,
                 user_id=self.user_id,
             )
         else:
@@ -171,7 +166,6 @@ class MCPTool(Tool):
                 redis_client=redis_client,
                 redirect_uri=self.redirect_uri,
                 task_id=self.oauth_task_id,
-                db=db,
                 user_id=self.user_id,
             )
         elif self.auth_type == "bearer":
@@ -491,7 +485,7 @@ class MCPTool(Tool):

     def _test_oauth_connection(self) -> Dict:
         storage = DBTokenStorage(
-            server_url=self.server_url, user_id=self.user_id, db_client=db
+            server_url=self.server_url, user_id=self.user_id,
         )
         loop = asyncio.new_event_loop()
         try:
@@ -683,7 +677,6 @@ class DocsGPTOAuth(OAuthClientProvider):
         scopes: str | list[str] | None = None,
         client_name: str = "DocsGPT-MCP",
         user_id=None,
-        db=None,
         additional_client_metadata: dict[str, Any] | None = None,
         skip_redirect_validation: bool = False,
     ):
@@ -692,7 +685,6 @@ class DocsGPTOAuth(OAuthClientProvider):
         self.redis_prefix = redis_prefix
         self.task_id = task_id
         self.user_id = user_id
-        self.db = db

         parsed_url = urlparse(mcp_url)
         self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -711,7 +703,6 @@ class DocsGPTOAuth(OAuthClientProvider):
         storage = DBTokenStorage(
             server_url=self.server_base_url,
             user_id=self.user_id,
-            db_client=self.db,
             expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
         )

@@ -853,54 +844,95 @@ class DBTokenStorage(TokenStorage):
         self,
         server_url: str,
         user_id: str,
-        db_client,
         expected_redirect_uri: Optional[str] = None,
     ):
         self.server_url = server_url
         self.user_id = user_id
-        self.db_client = db_client
         self.expected_redirect_uri = expected_redirect_uri
-        self.collection = db_client["connector_sessions"]

     @staticmethod
     def get_base_url(url: str) -> str:
         parsed = urlparse(url)
         return f"{parsed.scheme}://{parsed.netloc}"

-    def get_db_key(self) -> dict:
-        return {
-            "server_url": self.get_base_url(self.server_url),
-            "user_id": self.user_id,
-        }
+    def _pg_provider(self) -> str:
+        return f"mcp:{self.get_base_url(self.server_url)}"
+
+    def _fetch_session_data(self) -> dict:
+        """Read the JSONB ``session_data`` blob for this MCP server row."""
+        from application.storage.db.repositories.connector_sessions import (
+            ConnectorSessionsRepository,
+        )
+        from application.storage.db.session import db_readonly
+
+        base_url = self.get_base_url(self.server_url)
+        with db_readonly() as conn:
+            row = ConnectorSessionsRepository(conn).get_by_user_and_server_url(
+                self.user_id, base_url,
+            )
+        if not row:
+            return {}
+        data = row.get("session_data") or {}
+        if isinstance(data, str):
+            try:
+                data = json.loads(data)
+            except ValueError:
+                return {}
+        return data if isinstance(data, dict) else {}

     async def get_tokens(self) -> OAuthToken | None:
-        doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
-        if not doc or "tokens" not in doc:
+        data = await asyncio.to_thread(self._fetch_session_data)
+        if not data or "tokens" not in data:
             return None
         try:
-            return OAuthToken.model_validate(doc["tokens"])
+            return OAuthToken.model_validate(data["tokens"])
         except ValidationError as e:
             logger.error("Could not load tokens: %s", e)
             return None

-    async def set_tokens(self, tokens: OAuthToken) -> None:
-        await asyncio.to_thread(
-            self.collection.update_one,
-            self.get_db_key(),
-            {"$set": {"tokens": tokens.model_dump()}},
-            True,
-        )
-        logger.info("Saved tokens for %s", self.get_base_url(self.server_url))
+    def _merge(self, patch: dict) -> None:
+        """Shallow-merge ``patch`` into this row's ``session_data``.
+
+        Threads ``server_url`` through to the repository so it lands in
+        the scalar column — ``get_by_user_and_server_url`` needs that to
+        resolve the row (``NULL = 'https://...'`` is UNKNOWN in SQL).
+        """
+        from application.storage.db.repositories.connector_sessions import (
+            ConnectorSessionsRepository,
+        )
+        from application.storage.db.session import db_session
+
+        base_url = self.get_base_url(self.server_url)
+        with db_session() as conn:
+            ConnectorSessionsRepository(conn).merge_session_data(
+                self.user_id, self._pg_provider(), base_url, patch,
+            )
+
+    def _delete(self) -> None:
+        from application.storage.db.repositories.connector_sessions import (
+            ConnectorSessionsRepository,
+        )
+        from application.storage.db.session import db_session
+
+        with db_session() as conn:
+            ConnectorSessionsRepository(conn).delete(
+                self.user_id, self._pg_provider(),
+            )
+
+    async def set_tokens(self, tokens: OAuthToken) -> None:
+        base_url = self.get_base_url(self.server_url)
+        token_dump = tokens.model_dump()
+        await asyncio.to_thread(self._merge, {"tokens": token_dump})
+        logger.info("Saved tokens for %s", base_url)

     async def get_client_info(self) -> OAuthClientInformationFull | None:
-        doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
-        if not doc or "client_info" not in doc:
-            logger.debug(
-                "No client_info in DB for %s", self.get_base_url(self.server_url)
-            )
+        data = await asyncio.to_thread(self._fetch_session_data)
+        base_url = self.get_base_url(self.server_url)
+        if not data or "client_info" not in data:
+            logger.debug("No client_info in DB for %s", base_url)
             return None
         try:
-            client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
+            client_info = OAuthClientInformationFull.model_validate(data["client_info"])
             if self.expected_redirect_uri:
                 stored_uris = [
                     str(uri).rstrip("/") for uri in client_info.redirect_uris
@@ -909,14 +941,16 @@ class DBTokenStorage(TokenStorage):
                 if expected_uri not in stored_uris:
                     logger.warning(
                         "Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
-                        self.get_base_url(self.server_url),
+                        base_url,
                         expected_uri,
                         stored_uris,
                     )
+                    # Drop ``tokens`` and ``client_info`` from the JSONB
+                    # blob via merge_session_data's ``None``-drops-key
+                    # semantics — preserves the row + any other keys.
                     await asyncio.to_thread(
-                        self.collection.update_one,
-                        self.get_db_key(),
-                        {"$unset": {"client_info": "", "tokens": ""}},
+                        self._merge,
+                        {"tokens": None, "client_info": None},
                     )
                     return None
             return client_info
@@ -931,22 +965,37 @@ class DBTokenStorage(TokenStorage):

     async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
         serialized_info = self._serialize_client_info(client_info.model_dump())
+        base_url = self.get_base_url(self.server_url)
         await asyncio.to_thread(
-            self.collection.update_one,
-            self.get_db_key(),
-            {"$set": {"client_info": serialized_info}},
-            True,
+            self._merge, {"client_info": serialized_info},
         )
-        logger.info("Saved client info for %s", self.get_base_url(self.server_url))
+        logger.info("Saved client info for %s", base_url)

     async def clear(self) -> None:
-        await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
+        await asyncio.to_thread(self._delete)
         logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))

     @classmethod
-    async def clear_all(cls, db_client) -> None:
-        collection = db_client["connector_sessions"]
-        await asyncio.to_thread(collection.delete_many, {})
+    async def clear_all(cls, db_client=None) -> None:
+        """Delete every MCP-tagged connector session row.
+
+        ``db_client`` retained for call-site compatibility but unused —
+        storage is Postgres-only now.
+        """
+        from sqlalchemy import text
+
+        from application.storage.db.session import db_session
+
+        def _delete_all() -> None:
+            with db_session() as conn:
+                conn.execute(
+                    text(
+                        "DELETE FROM connector_sessions "
+                        "WHERE provider LIKE 'mcp:%'"
+                    )
+                )
+
+        await asyncio.to_thread(_delete_all)
         logger.info("Cleared all OAuth client cache data.")

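The `_merge` docstring and the redirect-mismatch branch above rely on `merge_session_data` treating a `None` value as "drop this key" while shallow-merging everything else. A self-contained sketch of those semantics; it is illustrative only, and the real repository method also persists the result to the `connector_sessions` JSONB column.

```python
def merge_session_data(existing: dict, patch: dict) -> dict:
    """Shallow-merge patch into existing; a None value removes that key."""
    merged = dict(existing)
    for key, value in patch.items():
        if value is None:
            merged.pop(key, None)
        else:
            merged[key] = value
    return merged


session = {"tokens": {"access_token": "abc"}, "client_info": {"client_id": "xyz"}, "extra": 1}
# Redirect-URI mismatch: drop tokens and client_info but keep the row's other
# keys, which is exactly what the hunk above does with {"tokens": None, ...}.
session = merge_session_data(session, {"tokens": None, "client_info": None})
assert session == {"extra": 1}
```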
@@ -1,12 +1,14 @@
|
|||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
import re
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from .base import Tool
|
from .base import Tool
|
||||||
from application.core.mongo_db import MongoDB
|
from application.storage.db.repositories.memories import MemoriesRepository
|
||||||
from application.core.settings import settings
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MemoryTool(Tool):
|
class MemoryTool(Tool):
|
||||||
@@ -27,7 +29,7 @@ class MemoryTool(Tool):
|
|||||||
self.user_id: Optional[str] = user_id
|
self.user_id: Optional[str] = user_id
|
||||||
|
|
||||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
# In production, tool_id is the UUID string from user_tools.id.
|
||||||
if tool_config and "tool_id" in tool_config:
|
if tool_config and "tool_id" in tool_config:
|
||||||
self.tool_id = tool_config["tool_id"]
|
self.tool_id = tool_config["tool_id"]
|
||||||
elif user_id:
|
elif user_id:
|
||||||
@@ -37,8 +39,35 @@ class MemoryTool(Tool):
|
|||||||
# Last resort fallback (shouldn't happen in normal use)
|
# Last resort fallback (shouldn't happen in normal use)
|
||||||
self.tool_id = str(uuid.uuid4())
|
self.tool_id = str(uuid.uuid4())
|
||||||
|
|
||||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
def _pg_enabled(self) -> bool:
|
||||||
self.collection = db["memories"]
|
"""Return True if this MemoryTool's tool_id is a real ``user_tools.id``.
|
||||||
|
|
||||||
|
The ``memories`` PG table has a UUID foreign key to ``user_tools``.
|
||||||
|
The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
|
||||||
|
has no row in ``user_tools``, so any storage operation would fail
|
||||||
|
the foreign-key check. After the Postgres cutover Postgres is the
|
||||||
|
only store, so for the sentinel case there is nowhere to read or
|
||||||
|
write — operations become no-ops and the tool returns an
|
||||||
|
explanatory error to the caller.
|
||||||
|
"""
|
||||||
|
tool_id = getattr(self, "tool_id", None)
|
||||||
|
if not tool_id or not isinstance(tool_id, str):
|
||||||
|
return False
|
||||||
|
if tool_id.startswith("default_"):
|
||||||
|
logger.debug(
|
||||||
|
"Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
|
||||||
|
tool_id,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
|
||||||
|
if not looks_like_uuid(tool_id):
|
||||||
|
logger.debug(
|
||||||
|
"Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
|
||||||
|
tool_id,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Action implementations
|
# Action implementations
|
||||||
@@ -56,6 +85,12 @@ class MemoryTool(Tool):
         if not self.user_id:
             return "Error: MemoryTool requires a valid user_id."
 
+        if not self._pg_enabled():
+            return (
+                "Error: MemoryTool is not configured with a persistent tool_id; "
+                "memory storage is unavailable for this session."
+            )
+
         if action_name == "view":
             return self._view(
                 kwargs.get("path", "/"),
@@ -282,14 +317,10 @@ class MemoryTool(Tool):
         # Ensure path ends with / for proper prefix matching
         search_path = path if path.endswith("/") else path + "/"
 
-        # Find all files that start with this directory path
-        query = {
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": {"$regex": f"^{re.escape(search_path)}"}
-        }
-
-        docs = list(self.collection.find(query, {"path": 1}))
+        with db_readonly() as conn:
+            docs = MemoriesRepository(conn).list_by_prefix(
+                self.user_id, self.tool_id, search_path
+            )
 
         if not docs:
             return f"Directory: {path}\n(empty)"
@@ -310,7 +341,10 @@ class MemoryTool(Tool):
 
     def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
         """View file contents with optional line range."""
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
+        with db_readonly() as conn:
+            doc = MemoriesRepository(conn).get_by_path(
+                self.user_id, self.tool_id, path
+            )
 
         if not doc or not doc.get("content"):
             return f"Error: File not found: {path}"
@@ -344,16 +378,10 @@ class MemoryTool(Tool):
         if validated_path == "/" or validated_path.endswith("/"):
             return "Error: Cannot create a file at directory path."
 
-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
-            {
-                "$set": {
-                    "content": file_text,
-                    "updated_at": datetime.now()
-                }
-            },
-            upsert=True
-        )
+        with db_session() as conn:
+            MemoriesRepository(conn).upsert(
+                self.user_id, self.tool_id, validated_path, file_text
+            )
 
         return f"File created: {validated_path}"
@@ -366,30 +394,29 @@ class MemoryTool(Tool):
         if not old_str:
             return "Error: old_str is required."
 
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
 
         if not doc or not doc.get("content"):
             return f"Error: File not found: {validated_path}"
 
         current_content = str(doc["content"])
 
         # Check if old_str exists (case-insensitive)
         if old_str.lower() not in current_content.lower():
             return f"Error: String '{old_str}' not found in file."
 
-        # Replace the string (case-insensitive)
+        # Case-insensitive replace
         import re as regex_module
-        updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
+        updated_content = regex_module.sub(
+            regex_module.escape(old_str),
+            new_str,
+            current_content,
+            flags=regex_module.IGNORECASE,
+        )
 
-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
-            {
-                "$set": {
-                    "content": updated_content,
-                    "updated_at": datetime.now()
-                }
-            }
-        )
+        repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
 
         return f"File updated: {validated_path}"
@@ -402,31 +429,25 @@ class MemoryTool(Tool):
         if not insert_text:
             return "Error: insert_text is required."
 
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
 
         if not doc or not doc.get("content"):
             return f"Error: File not found: {validated_path}"
 
         current_content = str(doc["content"])
         lines = current_content.split("\n")
 
         # Convert to 0-indexed
         index = insert_line - 1
         if index < 0 or index > len(lines):
             return f"Error: Invalid line number. File has {len(lines)} lines."
 
         lines.insert(index, insert_text)
         updated_content = "\n".join(lines)
 
-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
-            {
-                "$set": {
-                    "content": updated_content,
-                    "updated_at": datetime.now()
-                }
-            }
-        )
+        repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)
 
         return f"Text inserted at line {insert_line} in {validated_path}"
@@ -438,39 +459,36 @@ class MemoryTool(Tool):
 
         if validated_path == "/":
             # Delete all files for this user and tool
-            result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
-            return f"Deleted {result.deleted_count} file(s) from memory."
+            with db_session() as conn:
+                deleted = MemoriesRepository(conn).delete_all(
+                    self.user_id, self.tool_id
+                )
+            return f"Deleted {deleted} file(s) from memory."
 
         # Check if it's a directory (ends with /)
         if validated_path.endswith("/"):
-            # Delete all files in directory
-            result = self.collection.delete_many({
-                "user_id": self.user_id,
-                "tool_id": self.tool_id,
-                "path": {"$regex": f"^{re.escape(validated_path)}"}
-            })
-            return f"Deleted directory and {result.deleted_count} file(s)."
+            with db_session() as conn:
+                deleted = MemoriesRepository(conn).delete_by_prefix(
+                    self.user_id, self.tool_id, validated_path
+                )
+            return f"Deleted directory and {deleted} file(s)."
 
-        # Try to delete as directory first (without trailing slash)
-        # Check if any files start with this path + /
+        # Try as directory first (without trailing slash)
         search_path = validated_path + "/"
-        directory_result = self.collection.delete_many({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": {"$regex": f"^{re.escape(search_path)}"}
-        })
-
-        if directory_result.deleted_count > 0:
-            return f"Deleted directory and {directory_result.deleted_count} file(s)."
-
-        # Delete single file
-        result = self.collection.delete_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_path
-        })
-
-        if result.deleted_count:
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            directory_deleted = repo.delete_by_prefix(
+                self.user_id, self.tool_id, search_path
+            )
+            if directory_deleted > 0:
+                return f"Deleted directory and {directory_deleted} file(s)."
+
+            # Otherwise delete a single file
+            file_deleted = repo.delete_by_path(
+                self.user_id, self.tool_id, validated_path
+            )
+
+        if file_deleted:
             return f"Deleted: {validated_path}"
         return f"Error: File not found: {validated_path}"
@@ -485,62 +503,46 @@ class MemoryTool(Tool):
         if validated_old == "/" or validated_new == "/":
             return "Error: Cannot rename root directory."
 
-        # Check if renaming a directory
+        # Directory rename: do all path updates inside one transaction so
+        # the rename is atomic from the caller's perspective.
        if validated_old.endswith("/"):
             # Ensure validated_new also ends with / for proper path replacement
             if not validated_new.endswith("/"):
                 validated_new = validated_new + "/"
 
-            # Find all files in the old directory
-            docs = list(self.collection.find({
-                "user_id": self.user_id,
-                "tool_id": self.tool_id,
-                "path": {"$regex": f"^{re.escape(validated_old)}"}
-            }))
-
-            if not docs:
-                return f"Error: Directory not found: {validated_old}"
-
-            # Update paths for all files
-            for doc in docs:
-                old_file_path = doc["path"]
-                new_file_path = old_file_path.replace(validated_old, validated_new, 1)
-
-                self.collection.update_one(
-                    {"_id": doc["_id"]},
-                    {"$set": {"path": new_file_path, "updated_at": datetime.now()}}
-                )
+            with db_session() as conn:
+                repo = MemoriesRepository(conn)
+                docs = repo.list_by_prefix(
+                    self.user_id, self.tool_id, validated_old
+                )
+                if not docs:
+                    return f"Error: Directory not found: {validated_old}"
+
+                for doc in docs:
+                    old_file_path = doc["path"]
+                    new_file_path = old_file_path.replace(
+                        validated_old, validated_new, 1
+                    )
+                    repo.update_path(
+                        self.user_id, self.tool_id, old_file_path, new_file_path
+                    )
 
             return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"
 
-        # Rename single file
-        doc = self.collection.find_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_old
-        })
-
-        if not doc:
-            return f"Error: File not found: {validated_old}"
-
-        # Check if new path already exists
-        existing = self.collection.find_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_new
-        })
-
-        if existing:
-            return f"Error: File already exists at {validated_new}"
-
-        # Delete the old document and create a new one with the new path
-        self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
-        self.collection.insert_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_new,
-            "content": doc.get("content", ""),
-            "updated_at": datetime.now()
-        })
+        # Single-file rename: lookup, collision check, and update in one txn.
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
+            if not doc:
+                return f"Error: File not found: {validated_old}"
+
+            existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
+            if existing:
+                return f"Error: File already exists at {validated_new}"
+
+            repo.update_path(
+                self.user_id, self.tool_id, validated_old, validated_new
+            )
 
         return f"Renamed: {validated_old} -> {validated_new}"
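Throughout these hunks the Mongo collection handle is replaced by the `db_session()` / `db_readonly()` context managers from `application.storage.db.session`, whose implementation is not part of this excerpt. A rough sketch of the semantics the call sites assume (SQLAlchemy Core is itself an assumption here, suggested by the `CAST(:tool_id AS uuid)` bound-parameter style the docstrings mention):

    from contextlib import contextmanager

    from sqlalchemy import create_engine

    # Hypothetical engine; the real module builds it from settings.POSTGRES_URI.
    _engine = create_engine("postgresql+psycopg://docsgpt:docsgpt@localhost:5432/docsgpt")

    @contextmanager
    def db_session():
        # One transaction per block: commit on clean exit, roll back on error.
        with _engine.begin() as conn:
            yield conn

    @contextmanager
    def db_readonly():
        # Plain connection for reads; nothing to commit.
        with _engine.connect() as conn:
            yield conn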
@@ -1,10 +1,16 @@
-from datetime import datetime
 from typing import Any, Dict, List, Optional
 import uuid
 
 from .base import Tool
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
+from application.storage.db.repositories.notes import NotesRepository
+from application.storage.db.session import db_readonly, db_session
+
+
+# Stable synthetic title used in the Postgres ``notes.title`` column.
+# The notes tool stores one note per (user_id, tool_id); there is no
+# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
+# constant alongside the actual note body in ``content``.
+_NOTE_TITLE = "note"
 
 
 class NotesTool(Tool):
@@ -25,7 +31,6 @@ class NotesTool(Tool):
         self.user_id: Optional[str] = user_id
 
         # Get tool_id from configuration (passed from user_tools._id in production)
-        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
         if tool_config and "tool_id" in tool_config:
             self.tool_id = tool_config["tool_id"]
         elif user_id:
@@ -35,11 +40,25 @@ class NotesTool(Tool):
             # Last resort fallback (shouldn't happen in normal use)
             self.tool_id = str(uuid.uuid4())
 
-        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
-        self.collection = db["notes"]
-
         self._last_artifact_id: Optional[str] = None
 
+    def _pg_enabled(self) -> bool:
+        """Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
+
+        ``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
+        ``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
+        fallback is neither a UUID nor a ``user_tools`` row, so any DB
+        operation would crash. Mirror MemoryTool's guard and no-op.
+        """
+        tool_id = getattr(self, "tool_id", None)
+        if not tool_id or not isinstance(tool_id, str):
+            return False
+        if tool_id.startswith("default_"):
+            return False
+        from application.storage.db.base_repository import looks_like_uuid
+
+        return looks_like_uuid(tool_id)
+
     # -----------------------------
     # Action implementations
     # -----------------------------
@@ -54,7 +73,13 @@ class NotesTool(Tool):
             A human-readable string result.
         """
         if not self.user_id:
             return "Error: NotesTool requires a valid user_id."
 
+        if not self._pg_enabled():
+            return (
+                "Error: NotesTool is not configured with a persistent "
+                "tool_id; note storage is unavailable for this session."
+            )
+
         self._last_artifact_id = None
 
@@ -135,37 +160,45 @@ class NotesTool(Tool):
     # -----------------------------
     # Internal helpers (single-note)
     # -----------------------------
+    def _fetch_note(self) -> Optional[dict]:
+        """Read the note row for this (user, tool) from Postgres."""
+        with db_readonly() as conn:
+            return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)
+
     def _get_note(self) -> str:
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
-        if not doc or not doc.get("note"):
+        doc = self._fetch_note()
+        # ``content`` is the PG column; expose as ``note`` to callers via the
+        # textual return value. Frontends that read the artifact via the
+        # repo dict get ``content`` (PG-native) plus the artifact id below.
+        body = (doc or {}).get("content")
+        if not doc or not body:
             return "No note found."
-        if doc.get("_id") is not None:
-            self._last_artifact_id = str(doc.get("_id"))
-        return str(doc["note"])
+        if doc.get("id") is not None:
+            self._last_artifact_id = str(doc.get("id"))
+        return str(body)
 
     def _overwrite_note(self, content: str) -> str:
         content = (content or "").strip()
         if not content:
             return "Note content required."
-        result = self.collection.find_one_and_update(
-            {"user_id": self.user_id, "tool_id": self.tool_id},
-            {"$set": {"note": content, "updated_at": datetime.utcnow()}},
-            upsert=True,
-            return_document=True,
-        )
-        if result and result.get("_id") is not None:
-            self._last_artifact_id = str(result.get("_id"))
+        with db_session() as conn:
+            row = NotesRepository(conn).upsert(
+                self.user_id, self.tool_id, _NOTE_TITLE, content
+            )
+        if row and row.get("id") is not None:
+            self._last_artifact_id = str(row.get("id"))
         return "Note saved."
 
     def _str_replace(self, old_str: str, new_str: str) -> str:
         if not old_str:
             return "old_str is required."
 
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
-        if not doc or not doc.get("note"):
+        doc = self._fetch_note()
+        existing = (doc or {}).get("content")
+        if not doc or not existing:
             return "No note found."
 
-        current_note = str(doc["note"])
+        current_note = str(existing)
 
         # Case-insensitive search
         if old_str.lower() not in current_note.lower():
@@ -175,24 +208,24 @@ class NotesTool(Tool):
         import re
         updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)
 
-        result = self.collection.find_one_and_update(
-            {"user_id": self.user_id, "tool_id": self.tool_id},
-            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
-            return_document=True,
-        )
-        if result and result.get("_id") is not None:
-            self._last_artifact_id = str(result.get("_id"))
+        with db_session() as conn:
+            row = NotesRepository(conn).upsert(
+                self.user_id, self.tool_id, _NOTE_TITLE, updated_note
+            )
+        if row and row.get("id") is not None:
+            self._last_artifact_id = str(row.get("id"))
         return "Note updated."
 
     def _insert(self, line_number: int, text: str) -> str:
         if not text:
             return "Text is required."
 
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
-        if not doc or not doc.get("note"):
+        doc = self._fetch_note()
+        existing = (doc or {}).get("content")
+        if not doc or not existing:
             return "No note found."
 
-        current_note = str(doc["note"])
+        current_note = str(existing)
         lines = current_note.split("\n")
 
         # Convert to 0-indexed and validate
@@ -203,21 +236,23 @@ class NotesTool(Tool):
         lines.insert(index, text)
         updated_note = "\n".join(lines)
 
-        result = self.collection.find_one_and_update(
-            {"user_id": self.user_id, "tool_id": self.tool_id},
-            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
-            return_document=True,
-        )
-        if result and result.get("_id") is not None:
-            self._last_artifact_id = str(result.get("_id"))
+        with db_session() as conn:
+            row = NotesRepository(conn).upsert(
+                self.user_id, self.tool_id, _NOTE_TITLE, updated_note
+            )
+        if row and row.get("id") is not None:
+            self._last_artifact_id = str(row.get("id"))
         return "Text inserted."
 
     def _delete_note(self) -> str:
-        doc = self.collection.find_one_and_delete(
-            {"user_id": self.user_id, "tool_id": self.tool_id}
-        )
-        if not doc:
+        # Capture the id (for artifact tracking) before deleting.
+        existing = self._fetch_note()
+        if not existing:
             return "No note found to delete."
-        if doc.get("_id") is not None:
-            self._last_artifact_id = str(doc.get("_id"))
+        with db_session() as conn:
+            deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
+        if not deleted:
+            return "No note found to delete."
+        if existing.get("id") is not None:
+            self._last_artifact_id = str(existing.get("id"))
         return "Note deleted."
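The new helpers call `NotesRepository.upsert(user_id, tool_id, _NOTE_TITLE, content)`, whose body is defined elsewhere in this PR. A sketch of the kind of statement such an upsert typically issues, under the assumption of one row per `(user_id, tool_id)` enforced by a unique constraint and an `updated_at` column (both assumptions of this sketch, not facts from the diff):

    from sqlalchemy import text

    UPSERT_NOTE = text(
        """
        INSERT INTO notes (user_id, tool_id, title, content)
        VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content)
        ON CONFLICT (user_id, tool_id)
        DO UPDATE SET content = EXCLUDED.content, updated_at = now()
        RETURNING id, user_id, tool_id, title, content
        """
    )

    def upsert_note(conn, user_id: str, tool_id: str, title: str, content: str) -> dict:
        params = {"user_id": user_id, "tool_id": tool_id, "title": title, "content": content}
        return dict(conn.execute(UPSERT_NOTE, params).mappings().one())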
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
         if self.token:
             headers["Authorization"] = f"Basic {self.token}"
         data = message.encode("utf-8")
-        response = requests.post(url, headers=headers, data=data)
+        response = requests.post(url, headers=headers, data=data, timeout=100)
         return {"status_code": response.status_code, "message": "Message sent"}
 
     def get_actions_metadata(self):
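The only change here (and in the Telegram tool below) is the `timeout=100` argument: without it, `requests.post` can block indefinitely on an unresponsive host and stall the agent run. If a caller should get a result instead of an exception when the limit is hit, a wrapper along these lines would do (a hypothetical helper, not part of the diff):

    import requests

    def post_with_timeout(url, data=None, headers=None, timeout=100):
        # Hypothetical wrapper: report a timeout as a status payload instead of raising.
        try:
            response = requests.post(url, headers=headers, data=data, timeout=timeout)
            return {"status_code": response.status_code, "message": "Message sent"}
        except requests.Timeout:
            return {"status_code": None, "message": f"Request timed out after {timeout}s"}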
@@ -1,6 +1,6 @@
 import logging
 
-import psycopg2
+import psycopg
 
 from application.agents.tools.base import Tool
 
@@ -33,7 +33,7 @@ class PostgresTool(Tool):
         """
         conn = None
         try:
-            conn = psycopg2.connect(self.connection_string)
+            conn = psycopg.connect(self.connection_string)
             cur = conn.cursor()
             cur.execute(sql_query)
             conn.commit()
@@ -60,7 +60,7 @@ class PostgresTool(Tool):
                 "response_data": response_data,
             }
 
-        except psycopg2.Error as e:
+        except psycopg.Error as e:
             error_message = f"Database error: {e}"
             logger.error("PostgreSQL execute_sql error: %s", e)
             return {
@@ -78,7 +78,7 @@ class PostgresTool(Tool):
         """
         conn = None
         try:
-            conn = psycopg2.connect(self.connection_string)
+            conn = psycopg.connect(self.connection_string)
             cur = conn.cursor()
 
             cur.execute(
@@ -120,7 +120,7 @@ class PostgresTool(Tool):
                 "schema": schema_data,
             }
 
-        except psycopg2.Error as e:
+        except psycopg.Error as e:
             error_message = f"Database error: {e}"
             logger.error("PostgreSQL get_schema error: %s", e)
             return {
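This file only swaps `psycopg2` for `psycopg` (psycopg 3); `connect`, `cursor`, and the `Error` base class keep the same names, so the call sites are otherwise untouched. For reference, the same logic in psycopg 3's context-manager idiom would look roughly like this (a sketch of an alternative style, not what the diff does):

    import psycopg

    def run_query(connection_string: str, sql_query: str) -> list:
        # Leaving the connect() block commits on success and rolls back on error.
        with psycopg.connect(connection_string) as conn:
            with conn.cursor() as cur:
                cur.execute(sql_query)
                return cur.fetchall() if cur.description else []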
@@ -31,14 +31,14 @@ class TelegramTool(Tool):
         logger.debug("Sending Telegram message to chat_id=%s", chat_id)
         url = f"https://api.telegram.org/bot{self.token}/sendMessage"
         payload = {"chat_id": chat_id, "text": text}
-        response = requests.post(url, data=payload)
+        response = requests.post(url, data=payload, timeout=100)
         return {"status_code": response.status_code, "message": "Message sent"}
 
     def _send_image(self, image_url, chat_id):
         logger.debug("Sending Telegram image to chat_id=%s", chat_id)
         url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
         payload = {"chat_id": chat_id, "photo": image_url}
-        response = requests.post(url, data=payload)
+        response = requests.post(url, data=payload, timeout=100)
         return {"status_code": response.status_code, "message": "Image sent"}
 
     def get_actions_metadata(self):
@@ -1,10 +1,19 @@
-from datetime import datetime
 from typing import Any, Dict, List, Optional
 import uuid
 
 from .base import Tool
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
+from application.storage.db.repositories.todos import TodosRepository
+from application.storage.db.session import db_readonly, db_session
+
+
+def _status_from_completed(completed: Any) -> str:
+    """Translate the PG ``completed`` boolean to the legacy status string.
+
+    The frontend (and prior LLM-facing tool output) expects
+    ``"open"`` / ``"completed"``. Keeping that contract at the tool
+    boundary insulates callers from the schema change.
+    """
+    return "completed" if bool(completed) else "open"
 
 
 class TodoListTool(Tool):
@@ -25,7 +34,6 @@ class TodoListTool(Tool):
         self.user_id: Optional[str] = user_id
 
         # Get tool_id from configuration (passed from user_tools._id in production)
-        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
        if tool_config and "tool_id" in tool_config:
             self.tool_id = tool_config["tool_id"]
         elif user_id:
@@ -35,11 +43,27 @@ class TodoListTool(Tool):
             # Last resort fallback (shouldn't happen in normal use)
             self.tool_id = str(uuid.uuid4())
 
-        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
-        self.collection = db["todos"]
-
         self._last_artifact_id: Optional[str] = None
 
+    def _pg_enabled(self) -> bool:
+        """Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
+
+        The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
+        the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
+        ``default_{uid}`` fallback is neither a UUID nor a row in
+        ``user_tools`` — binding it would fail with ``invalid input syntax
+        for type uuid``, and even if it did not, the FK would reject it.
+        Mirror the MemoryTool guard and no-op in that case.
+        """
+        tool_id = getattr(self, "tool_id", None)
+        if not tool_id or not isinstance(tool_id, str):
+            return False
+        if tool_id.startswith("default_"):
+            return False
+        from application.storage.db.base_repository import looks_like_uuid
+
+        return looks_like_uuid(tool_id)
+
     # -----------------------------
     # Action implementations
     # -----------------------------
@@ -56,6 +80,12 @@ class TodoListTool(Tool):
         if not self.user_id:
             return "Error: TodoListTool requires a valid user_id."
 
+        if not self._pg_enabled():
+            return (
+                "Error: TodoListTool is not configured with a persistent "
+                "tool_id; todo storage is unavailable for this session."
+            )
+
         self._last_artifact_id = None
 
         if action_name == "list":
@@ -191,28 +221,10 @@ class TodoListTool(Tool):
 
         return None
 
-    def _get_next_todo_id(self) -> int:
-        """Get the next sequential todo_id for this user and tool.
-
-        Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
-        With 5-10 todos max, scanning is negligible.
-        """
-        query = {"user_id": self.user_id, "tool_id": self.tool_id}
-        todos = list(self.collection.find(query, {"todo_id": 1}))
-
-        # Find the maximum todo_id
-        max_id = 0
-        for todo in todos:
-            todo_id = self._coerce_todo_id(todo.get("todo_id"))
-            if todo_id is not None:
-                max_id = max(max_id, todo_id)
-
-        return max_id + 1
-
     def _list(self) -> str:
         """List all todos for the user."""
-        query = {"user_id": self.user_id, "tool_id": self.tool_id}
-        todos = list(self.collection.find(query))
+        with db_readonly() as conn:
+            todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)
 
         if not todos:
             return "No todos found."
@@ -221,7 +233,7 @@ class TodoListTool(Tool):
         for doc in todos:
             todo_id = doc.get("todo_id")
             title = doc.get("title", "Untitled")
-            status = doc.get("status", "open")
+            status = _status_from_completed(doc.get("completed"))
 
             line = f"[{todo_id}] {title} ({status})"
             result_lines.append(line)
@@ -229,27 +241,23 @@ class TodoListTool(Tool):
         return "\n".join(result_lines)
 
     def _create(self, title: str) -> str:
-        """Create a new todo item."""
+        """Create a new todo item.
+
+        ``TodosRepository.create`` allocates the per-tool monotonic
+        ``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
+        scoped to ``tool_id``), so we no longer need a separate read-then-
+        write step here.
+        """
         title = (title or "").strip()
         if not title:
             return "Error: Title is required."
 
-        now = datetime.now()
-        todo_id = self._get_next_todo_id()
-
-        doc = {
-            "todo_id": todo_id,
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "title": title,
-            "status": "open",
-            "created_at": now,
-            "updated_at": now,
-        }
-        insert_result = self.collection.insert_one(doc)
-        inserted_id = getattr(insert_result, "inserted_id", None) or doc.get("_id")
-        if inserted_id is not None:
-            self._last_artifact_id = str(inserted_id)
+        with db_session() as conn:
+            row = TodosRepository(conn).create(self.user_id, self.tool_id, title)
+
+        todo_id = row.get("todo_id")
+        if row.get("id") is not None:
+            self._last_artifact_id = str(row.get("id"))
         return f"Todo created with ID {todo_id}: {title}"
 
     def _get(self, todo_id: Optional[Any]) -> str:
@@ -258,21 +266,21 @@ class TodoListTool(Tool):
         if parsed_todo_id is None:
             return "Error: todo_id must be a positive integer."
 
-        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
-        doc = self.collection.find_one(query)
+        with db_readonly() as conn:
+            doc = TodosRepository(conn).get_by_tool_and_todo_id(
+                self.user_id, self.tool_id, parsed_todo_id
+            )
 
         if not doc:
             return f"Error: Todo with ID {parsed_todo_id} not found."
 
-        if doc.get("_id") is not None:
-            self._last_artifact_id = str(doc.get("_id"))
+        if doc.get("id") is not None:
+            self._last_artifact_id = str(doc.get("id"))
 
         title = doc.get("title", "Untitled")
-        status = doc.get("status", "open")
+        status = _status_from_completed(doc.get("completed"))
 
-        result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
-
-        return result
+        return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
 
     def _update(self, todo_id: Optional[Any], title: str) -> str:
         """Update a todo's title by ID."""
@@ -284,16 +292,19 @@ class TodoListTool(Tool):
         if not title:
             return "Error: Title is required."
 
-        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
-        doc = self.collection.find_one_and_update(
-            query,
-            {"$set": {"title": title, "updated_at": datetime.now()}},
-        )
-        if not doc:
-            return f"Error: Todo with ID {parsed_todo_id} not found."
+        with db_session() as conn:
+            repo = TodosRepository(conn)
+            existing = repo.get_by_tool_and_todo_id(
+                self.user_id, self.tool_id, parsed_todo_id
+            )
+            if not existing:
+                return f"Error: Todo with ID {parsed_todo_id} not found."
+            repo.update_title_by_tool_and_todo_id(
+                self.user_id, self.tool_id, parsed_todo_id, title
+            )
 
-        if doc.get("_id") is not None:
-            self._last_artifact_id = str(doc.get("_id"))
+        if existing.get("id") is not None:
+            self._last_artifact_id = str(existing.get("id"))
 
         return f"Todo {parsed_todo_id} updated to: {title}"
 
@@ -303,16 +314,17 @@ class TodoListTool(Tool):
         if parsed_todo_id is None:
             return "Error: todo_id must be a positive integer."
 
-        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
-        doc = self.collection.find_one_and_update(
-            query,
-            {"$set": {"status": "completed", "updated_at": datetime.now()}},
-        )
-        if not doc:
-            return f"Error: Todo with ID {parsed_todo_id} not found."
+        with db_session() as conn:
+            repo = TodosRepository(conn)
+            existing = repo.get_by_tool_and_todo_id(
+                self.user_id, self.tool_id, parsed_todo_id
+            )
+            if not existing:
+                return f"Error: Todo with ID {parsed_todo_id} not found."
+            repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)
 
-        if doc.get("_id") is not None:
-            self._last_artifact_id = str(doc.get("_id"))
+        if existing.get("id") is not None:
+            self._last_artifact_id = str(existing.get("id"))
 
         return f"Todo {parsed_todo_id} marked as completed."
 
@@ -322,12 +334,18 @@ class TodoListTool(Tool):
         if parsed_todo_id is None:
             return "Error: todo_id must be a positive integer."
 
-        query = {"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id}
-        doc = self.collection.find_one_and_delete(query)
-        if not doc:
-            return f"Error: Todo with ID {parsed_todo_id} not found."
+        with db_session() as conn:
+            repo = TodosRepository(conn)
+            existing = repo.get_by_tool_and_todo_id(
+                self.user_id, self.tool_id, parsed_todo_id
+            )
+            if not existing:
+                return f"Error: Todo with ID {parsed_todo_id} not found."
+            repo.delete_by_tool_and_todo_id(
+                self.user_id, self.tool_id, parsed_todo_id
+            )
 
-        if doc.get("_id") is not None:
-            self._last_artifact_id = str(doc.get("_id"))
+        if existing.get("id") is not None:
+            self._last_artifact_id = str(existing.get("id"))
 
         return f"Todo {parsed_todo_id} deleted."
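The `_create` docstring above describes the id allocation that moved into `TodosRepository.create`: the per-tool `todo_id` comes from `COALESCE(MAX(todo_id), 0) + 1` computed in the same transaction as the insert. A sketch of what such a statement can look like (illustrative only; the shipped repository is defined elsewhere in this PR, and safety under concurrent creates additionally assumes a unique constraint on `(tool_id, todo_id)`):

    from sqlalchemy import text

    CREATE_TODO = text(
        """
        INSERT INTO todos (user_id, tool_id, todo_id, title, completed)
        SELECT :user_id,
               CAST(:tool_id AS uuid),
               COALESCE(MAX(t.todo_id), 0) + 1,
               :title,
               FALSE
        FROM todos t
        WHERE t.tool_id = CAST(:tool_id AS uuid)
        RETURNING id, todo_id, title, completed
        """
    )

    def create_todo(conn, user_id: str, tool_id: str, title: str) -> dict:
        params = {"user_id": user_id, "tool_id": tool_id, "title": title}
        return dict(conn.execute(CREATE_TODO, params).mappings().one())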
@@ -12,9 +12,13 @@ from application.agents.workflows.schemas import (
     WorkflowRun,
 )
 from application.agents.workflows.workflow_engine import WorkflowEngine
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
 from application.logging import log_activity, LogContext
+from application.storage.db.base_repository import looks_like_uuid
+from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
+from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
+from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
+from application.storage.db.repositories.workflows import WorkflowsRepository
+from application.storage.db.session import db_readonly, db_session
 
 logger = logging.getLogger(__name__)
 
@@ -103,10 +107,8 @@ class WorkflowAgent(BaseAgent):
 
     def _load_from_database(self) -> Optional[WorkflowGraph]:
         try:
-            from bson.objectid import ObjectId
-
-            if not self.workflow_id or not ObjectId.is_valid(self.workflow_id):
-                logger.error(f"Invalid workflow ID: {self.workflow_id}")
+            if not self.workflow_id:
+                logger.error("Missing workflow ID for load")
                 return None
             owner_id = self.workflow_owner
             if not owner_id and isinstance(self.decoded_token, dict):
@@ -117,61 +119,61 @@ class WorkflowAgent(BaseAgent):
                 )
                 return None
 
-            mongo = MongoDB.get_client()
-            db = mongo[settings.MONGO_DB_NAME]
-
-            workflows_coll = db["workflows"]
-            workflow_nodes_coll = db["workflow_nodes"]
-            workflow_edges_coll = db["workflow_edges"]
-
-            workflow_doc = workflows_coll.find_one(
-                {"_id": ObjectId(self.workflow_id), "user": owner_id}
-            )
-            if not workflow_doc:
-                logger.error(
-                    f"Workflow {self.workflow_id} not found or inaccessible for user {owner_id}"
-                )
-                return None
-            workflow = Workflow(**workflow_doc)
-            graph_version = workflow_doc.get("current_graph_version", 1)
-            try:
-                graph_version = int(graph_version)
-                if graph_version <= 0:
-                    graph_version = 1
-            except (ValueError, TypeError):
-                graph_version = 1
-
-            nodes_docs = list(
-                workflow_nodes_coll.find(
-                    {"workflow_id": self.workflow_id, "graph_version": graph_version}
-                )
-            )
-            if not nodes_docs and graph_version == 1:
-                nodes_docs = list(
-                    workflow_nodes_coll.find(
-                        {
-                            "workflow_id": self.workflow_id,
-                            "graph_version": {"$exists": False},
-                        }
-                    )
-                )
-            nodes = [WorkflowNode(**doc) for doc in nodes_docs]
-
-            edges_docs = list(
-                workflow_edges_coll.find(
-                    {"workflow_id": self.workflow_id, "graph_version": graph_version}
-                )
-            )
-            if not edges_docs and graph_version == 1:
-                edges_docs = list(
-                    workflow_edges_coll.find(
-                        {
-                            "workflow_id": self.workflow_id,
-                            "graph_version": {"$exists": False},
-                        }
-                    )
-                )
-            edges = [WorkflowEdge(**doc) for doc in edges_docs]
+            with db_readonly() as conn:
+                wf_repo = WorkflowsRepository(conn)
+                if looks_like_uuid(self.workflow_id):
+                    workflow_row = wf_repo.get(self.workflow_id, owner_id)
+                else:
+                    workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
+                if workflow_row is None:
+                    logger.error(
+                        f"Workflow {self.workflow_id} not found or inaccessible "
+                        f"for user {owner_id}"
+                    )
+                    return None
+                pg_workflow_id = str(workflow_row["id"])
+                graph_version = workflow_row.get("current_graph_version", 1)
+                try:
+                    graph_version = int(graph_version)
+                    if graph_version <= 0:
+                        graph_version = 1
+                except (ValueError, TypeError):
+                    graph_version = 1
+
+                node_rows = WorkflowNodesRepository(conn).find_by_version(
+                    pg_workflow_id, graph_version,
+                )
+                edge_rows = WorkflowEdgesRepository(conn).find_by_version(
+                    pg_workflow_id, graph_version,
+                )
+
+            workflow = Workflow(
+                name=workflow_row.get("name"),
+                description=workflow_row.get("description"),
+            )
+            nodes = [
+                WorkflowNode(
+                    id=n["node_id"],
+                    workflow_id=pg_workflow_id,
+                    type=n["node_type"],
+                    title=n.get("title") or "Node",
+                    description=n.get("description"),
+                    position=n.get("position") or {"x": 0, "y": 0},
+                    config=n.get("config") or {},
+                )
+                for n in node_rows
+            ]
+            edges = [
+                WorkflowEdge(
+                    id=e["edge_id"],
+                    workflow_id=pg_workflow_id,
+                    source=e.get("source_id"),
+                    target=e.get("target_id"),
+                    sourceHandle=e.get("source_handle"),
+                    targetHandle=e.get("target_handle"),
+                )
+                for e in edge_rows
+            ]
 
             return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
         except Exception as e:
@@ -181,13 +183,13 @@ class WorkflowAgent(BaseAgent):
     def _save_workflow_run(self, query: str) -> None:
         if not self._engine:
             return
+        owner_id = self.workflow_owner
+        if not owner_id and isinstance(self.decoded_token, dict):
+            owner_id = self.decoded_token.get("sub")
         try:
-            mongo = MongoDB.get_client()
-            db = mongo[settings.MONGO_DB_NAME]
-            workflow_runs_coll = db["workflow_runs"]
-
             run = WorkflowRun(
                 workflow_id=self.workflow_id or "unknown",
+                user=owner_id,
                 status=self._determine_run_status(),
                 inputs={"query": query},
                 outputs=self._serialize_state(self._engine.state),
@@ -196,7 +198,28 @@ class WorkflowAgent(BaseAgent):
                 completed_at=datetime.now(timezone.utc),
             )
 
-            workflow_runs_coll.insert_one(run.to_mongo_doc())
+            if not self.workflow_id or not owner_id:
+                return
+            with db_session() as conn:
+                wf_repo = WorkflowsRepository(conn)
+                if looks_like_uuid(self.workflow_id):
+                    workflow_row = wf_repo.get(self.workflow_id, owner_id)
+                else:
+                    workflow_row = wf_repo.get_by_legacy_id(
+                        self.workflow_id, owner_id,
+                    )
+                if workflow_row is None:
+                    return
+                WorkflowRunsRepository(conn).create(
+                    str(workflow_row["id"]),
+                    owner_id,
+                    run.status.value,
+                    inputs=run.inputs,
+                    result=run.outputs,
+                    steps=[step.model_dump(mode="json") for step in run.steps],
+                    started_at=run.created_at,
+                    ended_at=run.completed_at,
+                )
         except Exception as e:
             logger.error(f"Failed to save workflow run: {e}")
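Both `_load_from_database` and `_save_workflow_run` now branch on whether `self.workflow_id` is a Postgres UUID or a legacy Mongo ObjectId string kept for backward compatibility. The branch could be factored into one helper; a small sketch using only names introduced by this PR (a suggestion, not part of the diff):

    from typing import Optional

    from application.storage.db.base_repository import looks_like_uuid
    from application.storage.db.repositories.workflows import WorkflowsRepository

    def resolve_workflow(conn, workflow_id: str, owner_id: str) -> Optional[dict]:
        """Return the workflow row for either a Postgres UUID or a legacy id."""
        repo = WorkflowsRepository(conn)
        if looks_like_uuid(workflow_id):
            return repo.get(workflow_id, owner_id)
        return repo.get_by_legacy_id(workflow_id, owner_id)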
@@ -2,7 +2,6 @@ from datetime import datetime, timezone
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union
 
-from bson import ObjectId
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 
 
@@ -81,24 +80,7 @@ class WorkflowEdgeCreate(BaseModel):
 
 
 class WorkflowEdge(WorkflowEdgeCreate):
-    mongo_id: Optional[str] = Field(None, alias="_id")
-
-    @field_validator("mongo_id", mode="before")
-    @classmethod
-    def convert_objectid(cls, v: Any) -> Optional[str]:
-        if isinstance(v, ObjectId):
-            return str(v)
-        return v
-
-    def to_mongo_doc(self) -> Dict[str, Any]:
-        return {
-            "id": self.id,
-            "workflow_id": self.workflow_id,
-            "source_id": self.source_id,
-            "target_id": self.target_id,
-            "source_handle": self.source_handle,
-            "target_handle": self.target_handle,
-        }
+    pass
 
 
 class WorkflowNodeCreate(BaseModel):
@@ -120,25 +102,7 @@ class WorkflowNodeCreate(BaseModel):
 
 
 class WorkflowNode(WorkflowNodeCreate):
-    mongo_id: Optional[str] = Field(None, alias="_id")
-
-    @field_validator("mongo_id", mode="before")
-    @classmethod
-    def convert_objectid(cls, v: Any) -> Optional[str]:
-        if isinstance(v, ObjectId):
-            return str(v)
-        return v
-
-    def to_mongo_doc(self) -> Dict[str, Any]:
-        return {
-            "id": self.id,
-            "workflow_id": self.workflow_id,
-            "type": self.type.value,
-            "title": self.title,
-            "description": self.description,
-            "position": self.position.model_dump(),
-            "config": self.config,
-        }
+    pass
 
 
 class WorkflowCreate(BaseModel):
@@ -149,26 +113,10 @@ class WorkflowCreate(BaseModel):
 
 
 class Workflow(WorkflowCreate):
-    id: Optional[str] = Field(None, alias="_id")
+    id: Optional[str] = None
     created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
     updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
 
-    @field_validator("id", mode="before")
-    @classmethod
-    def convert_objectid(cls, v: Any) -> Optional[str]:
-        if isinstance(v, ObjectId):
-            return str(v)
-        return v
-
-    def to_mongo_doc(self) -> Dict[str, Any]:
-        return {
-            "name": self.name,
-            "description": self.description,
-            "user": self.user,
-            "created_at": self.created_at,
-            "updated_at": self.updated_at,
-        }
-
 
 class WorkflowGraph(BaseModel):
     workflow: Workflow
@@ -209,29 +157,12 @@ class WorkflowRunCreate(BaseModel):
 
 class WorkflowRun(BaseModel):
     model_config = ConfigDict(extra="allow")
-    id: Optional[str] = Field(None, alias="_id")
+    id: Optional[str] = None
     workflow_id: str
+    user: Optional[str] = None
     status: ExecutionStatus = ExecutionStatus.PENDING
     inputs: Dict[str, str] = Field(default_factory=dict)
     outputs: Dict[str, Any] = Field(default_factory=dict)
     steps: List[NodeExecutionLog] = Field(default_factory=list)
     created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
     completed_at: Optional[datetime] = None
-
-    @field_validator("id", mode="before")
-    @classmethod
-    def convert_objectid(cls, v: Any) -> Optional[str]:
-        if isinstance(v, ObjectId):
-            return str(v)
-        return v
-
-    def to_mongo_doc(self) -> Dict[str, Any]:
-        return {
-            "workflow_id": self.workflow_id,
-            "status": self.status.value,
-            "inputs": self.inputs,
-            "outputs": self.outputs,
-            "steps": [step.model_dump() for step in self.steps],
-            "created_at": self.created_at,
-            "completed_at": self.completed_at,
-        }
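With the `_id` aliases, ObjectId validators, and `to_mongo_doc` helpers gone, these are plain Pydantic v2 models: the agent constructs them directly and hands the serialized fields to `WorkflowRunsRepository.create`. A minimal usage sketch mirroring the agent code above (the owner value is a placeholder; `status` keeps its `ExecutionStatus.PENDING` default):

    from datetime import datetime, timezone

    run = WorkflowRun(
        workflow_id="unknown",
        user="user-123",  # placeholder owner id
        inputs={"query": "hello"},
        outputs={"answer": "..."},
        completed_at=datetime.now(timezone.utc),
    )
    # The repository receives run.status.value plus the JSON-serialized steps,
    # exactly as _save_workflow_run does above.
    serialized_steps = [step.model_dump(mode="json") for step in run.steps]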
@@ -200,6 +200,9 @@ class WorkflowEngine:
 
         node_config = AgentNodeConfig(**node.config.get("config", node.config))
 
+        if node_config.sources:
+            self._retrieve_node_sources(node_config)
+
         if node_config.prompt_template:
             formatted_prompt = self._format_template(node_config.prompt_template)
         else:
@@ -455,6 +458,29 @@ class WorkflowEngine:
         docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
         return docs, docs_together
 
+    def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
+        """Retrieve documents from the node's sources for template resolution."""
+        from application.retriever.retriever_creator import RetrieverCreator
+
+        query = self.state.get("query", "")
+        if not query:
+            return
+
+        try:
+            retriever = RetrieverCreator.create_retriever(
+                node_config.retriever or "classic",
+                source={"active_docs": node_config.sources},
+                chat_history=[],
+                prompt="",
+                chunks=int(node_config.chunks) if node_config.chunks else 2,
+                decoded_token=self.agent.decoded_token,
+            )
+            docs = retriever.search(query)
+            if docs:
+                self.agent.retrieved_docs = docs
+        except Exception:
+            logger.exception("Failed to retrieve docs for workflow node")
+
     def get_execution_summary(self) -> List[NodeExecutionLog]:
         return [
             NodeExecutionLog(
application/alembic.ini (new file, 52 lines)
@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
#     alembic -c application/alembic.ini upgrade head

[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os

# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =

[post_write_hooks]

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
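The header comment above shows the CLI form (`alembic -c application/alembic.ini upgrade head`). If migrations ever need to run from Python instead, for example from a setup script, the programmatic equivalent is roughly the following (a sketch; nothing in this PR wires migrations into app startup):

    from alembic import command
    from alembic.config import Config

    def upgrade_to_head(ini_path: str = "application/alembic.ini") -> None:
        cfg = Config(ini_path)  # env.py still injects POSTGRES_URI as the URL
        command.upgrade(cfg, "head")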
application/alembic/env.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""Alembic environment for the DocsGPT user-data Postgres database.

The URL is pulled from ``application.core.settings`` rather than
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
running app and ``alembic`` CLI invocations.
"""

import sys
from logging.config import fileConfig
from pathlib import Path

# Make the project root importable regardless of cwd. env.py lives at
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from alembic import context  # noqa: E402
from sqlalchemy import engine_from_config, pool  # noqa: E402

from application.core.settings import settings  # noqa: E402
from application.storage.db.models import metadata as target_metadata  # noqa: E402

config = context.config

# Populate the runtime URL from settings.
if settings.POSTGRES_URI:
    config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)

if config.config_file_name is not None:
    fileConfig(config.config_file_name)


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode (emits SQL without a live DB)."""
    url = config.get_main_option("sqlalchemy.url")
    if not url:
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True,
    )
    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode against a live connection."""
    if not config.get_main_option("sqlalchemy.url"):
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
        future=True,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True,
        )
        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
26
application/alembic/script.py.mako
Normal file
26
application/alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
927
application/alembic/versions/0001_initial.py
Normal file
927
application/alembic/versions/0001_initial.py
Normal file
@@ -0,0 +1,927 @@
|
|||||||
|
"""0001 initial schema — consolidated Phase-1..3 baseline.
|
||||||
|
|
||||||
|
Revision ID: 0001_initial
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-04-13
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "0001_initial"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Extensions
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
|
||||||
|
op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Trigger functions
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION set_updated_at() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
NEW.updated_at = now();
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION ensure_user_exists() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
IF NEW.user_id IS NOT NULL THEN
|
||||||
|
INSERT INTO users (user_id) VALUES (NEW.user_id)
|
||||||
|
ON CONFLICT (user_id) DO NOTHING;
|
||||||
|
END IF;
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
UPDATE conversation_messages
|
||||||
|
SET attachments = array_remove(attachments, OLD.id)
|
||||||
|
WHERE OLD.id = ANY(attachments);
|
||||||
|
RETURN OLD;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
UPDATE agents
|
||||||
|
SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
|
||||||
|
WHERE OLD.id = ANY(extra_source_ids);
|
||||||
|
RETURN OLD;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
agent_id_text text := OLD.id::text;
|
||||||
|
BEGIN
|
||||||
|
UPDATE users
|
||||||
|
SET agent_preferences = jsonb_set(
|
||||||
|
jsonb_set(
|
||||||
|
agent_preferences,
|
||||||
|
'{pinned}',
|
||||||
|
COALESCE((
|
||||||
|
SELECT jsonb_agg(e)
|
||||||
|
FROM jsonb_array_elements(
|
||||||
|
COALESCE(agent_preferences->'pinned', '[]'::jsonb)
|
||||||
|
) e
|
||||||
|
WHERE (e #>> '{}') <> agent_id_text
|
||||||
|
), '[]'::jsonb)
|
||||||
|
),
|
||||||
|
'{shared_with_me}',
|
||||||
|
COALESCE((
|
||||||
|
SELECT jsonb_agg(e)
|
||||||
|
FROM jsonb_array_elements(
|
||||||
|
COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
|
||||||
|
) e
|
||||||
|
WHERE (e #>> '{}') <> agent_id_text
|
||||||
|
), '[]'::jsonb)
|
||||||
|
)
|
||||||
|
WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
|
||||||
|
OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
|
||||||
|
RETURN OLD;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
IF NEW.user_id IS NULL THEN
|
||||||
|
SELECT user_id INTO NEW.user_id
|
||||||
|
FROM conversations
|
||||||
|
WHERE id = NEW.conversation_id;
|
||||||
|
END IF;
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Tables
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL UNIQUE,
|
||||||
|
agent_preferences JSONB NOT NULL
|
||||||
|
DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE prompts (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE user_tools (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
custom_name TEXT,
|
||||||
|
display_name TEXT,
|
||||||
|
description TEXT,
|
||||||
|
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
actions JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
status BOOLEAN NOT NULL DEFAULT true,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE token_usage (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id TEXT,
|
||||||
|
api_key TEXT,
|
||||||
|
agent_id UUID,
|
||||||
|
prompt_tokens INTEGER NOT NULL DEFAULT 0,
|
||||||
|
generated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
|
||||||
|
"CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE user_logs (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id TEXT,
|
||||||
|
endpoint TEXT,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
data JSONB,
|
||||||
|
mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE stack_logs (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
activity_id TEXT NOT NULL,
|
||||||
|
endpoint TEXT,
|
||||||
|
level TEXT,
|
||||||
|
user_id TEXT,
|
||||||
|
api_key TEXT,
|
||||||
|
query TEXT,
|
||||||
|
stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE agent_folders (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE sources (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
language TEXT,
|
||||||
|
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
model TEXT,
|
||||||
|
type TEXT,
|
||||||
|
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
retriever TEXT,
|
||||||
|
sync_frequency TEXT,
|
||||||
|
tokens TEXT,
|
||||||
|
file_path TEXT,
|
||||||
|
remote_data JSONB,
|
||||||
|
directory_structure JSONB,
|
||||||
|
file_name_map JSONB,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE agents (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
agent_type TEXT,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
key CITEXT UNIQUE,
|
||||||
|
image TEXT,
|
||||||
|
source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
|
||||||
|
extra_source_ids UUID[] NOT NULL DEFAULT '{}',
|
||||||
|
chunks INTEGER,
|
||||||
|
retriever TEXT,
|
||||||
|
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||||
|
tools JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
json_schema JSONB,
|
||||||
|
models JSONB,
|
||||||
|
default_model_id TEXT,
|
||||||
|
folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
|
||||||
|
workflow_id UUID,
|
||||||
|
limited_token_mode BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
token_limit INTEGER,
|
||||||
|
limited_request_mode BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
request_limit INTEGER,
|
||||||
|
allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
shared BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
shared_token CITEXT UNIQUE,
|
||||||
|
shared_metadata JSONB,
|
||||||
|
incoming_webhook_token CITEXT UNIQUE,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
last_used_at TIMESTAMPTZ,
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
|
||||||
|
"FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE attachments (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
filename TEXT NOT NULL,
|
||||||
|
upload_path TEXT NOT NULL,
|
||||||
|
mime_type TEXT,
|
||||||
|
size BIGINT,
|
||||||
|
content TEXT,
|
||||||
|
token_count INTEGER,
|
||||||
|
openai_file_id TEXT,
|
||||||
|
google_file_uri TEXT,
|
||||||
|
metadata JSONB,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE memories (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||||
|
path TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE todos (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||||
|
todo_id INTEGER,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
completed BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE notes (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE connector_sessions (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
provider TEXT NOT NULL,
|
||||||
|
server_url TEXT,
|
||||||
|
session_token TEXT UNIQUE,
|
||||||
|
user_email TEXT,
|
||||||
|
status TEXT,
|
||||||
|
token_info JSONB,
|
||||||
|
session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
expires_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE conversations (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
|
||||||
|
name TEXT,
|
||||||
|
api_key TEXT,
|
||||||
|
is_shared_usage BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
shared_token TEXT,
|
||||||
|
date TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
|
||||||
|
compression_metadata JSONB,
|
||||||
|
legacy_mongo_id TEXT,
|
||||||
|
CONSTRAINT conversations_api_key_nonempty_chk
|
||||||
|
CHECK (api_key IS NULL OR api_key <> '')
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE conversation_messages (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||||
|
position INTEGER NOT NULL,
|
||||||
|
prompt TEXT,
|
||||||
|
response TEXT,
|
||||||
|
thought TEXT,
|
||||||
|
sources JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
|
||||||
|
model_id TEXT,
|
||||||
|
message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
feedback JSONB,
|
||||||
|
timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE shared_conversations (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
is_promptable BOOLEAN NOT NULL DEFAULT false,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
uuid UUID NOT NULL,
|
||||||
|
first_n_queries INTEGER NOT NULL DEFAULT 0,
|
||||||
|
api_key TEXT,
|
||||||
|
prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
|
||||||
|
chunks INTEGER,
|
||||||
|
CONSTRAINT shared_conversations_api_key_nonempty_chk
|
||||||
|
CHECK (api_key IS NULL OR api_key <> '')
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE pending_tool_state (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
messages JSONB NOT NULL,
|
||||||
|
pending_tool_calls JSONB NOT NULL,
|
||||||
|
tools_dict JSONB NOT NULL,
|
||||||
|
tool_schemas JSONB NOT NULL,
|
||||||
|
agent_config JSONB NOT NULL,
|
||||||
|
client_tools JSONB,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflows (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
current_graph_version INTEGER NOT NULL DEFAULT 1,
|
||||||
|
legacy_mongo_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
# Backfill the agents.workflow_id FK now that workflows exists.
|
||||||
|
# The column was created without a FK (forward reference to a table
|
||||||
|
# that hadn't been declared yet); add the constraint here so workflow
|
||||||
|
# deletion still cascades through to agent unset.
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
|
||||||
|
"FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflow_nodes (
|
||||||
|
id UUID DEFAULT gen_random_uuid() NOT NULL,
|
||||||
|
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||||
|
graph_version INTEGER NOT NULL,
|
||||||
|
node_type TEXT NOT NULL,
|
||||||
|
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
node_id TEXT NOT NULL,
|
||||||
|
title TEXT,
|
||||||
|
description TEXT,
|
||||||
|
position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
|
||||||
|
legacy_mongo_id TEXT,
|
||||||
|
PRIMARY KEY (id),
|
||||||
|
CONSTRAINT workflow_nodes_id_wf_ver_key
|
||||||
|
UNIQUE (id, workflow_id, graph_version)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflow_edges (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||||
|
graph_version INTEGER NOT NULL,
|
||||||
|
from_node_id UUID NOT NULL,
|
||||||
|
to_node_id UUID NOT NULL,
|
||||||
|
config JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||||
|
edge_id TEXT NOT NULL,
|
||||||
|
source_handle TEXT,
|
||||||
|
target_handle TEXT,
|
||||||
|
CONSTRAINT workflow_edges_from_node_fk
|
||||||
|
FOREIGN KEY (from_node_id, workflow_id, graph_version)
|
||||||
|
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
|
||||||
|
CONSTRAINT workflow_edges_to_node_fk
|
||||||
|
FOREIGN KEY (to_node_id, workflow_id, graph_version)
|
||||||
|
REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE workflow_runs (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
ended_at TIMESTAMPTZ,
|
||||||
|
result JSONB,
|
||||||
|
inputs JSONB,
|
||||||
|
steps JSONB NOT NULL DEFAULT '[]'::jsonb,
|
||||||
|
legacy_mongo_id TEXT,
|
||||||
|
CONSTRAINT workflow_runs_status_chk
|
||||||
|
CHECK (status IN ('pending', 'running', 'completed', 'failed'))
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Indexes
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
|
||||||
|
op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
|
||||||
|
op.execute("CREATE INDEX agents_status_idx ON agents (status);")
|
||||||
|
op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
|
||||||
|
op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
|
||||||
|
op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
|
||||||
|
"ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
|
||||||
|
"ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
# MCP and OAuth connectors share the ``provider`` slot, so the
|
||||||
|
# dedup key is ``(user_id, server_url, provider)``: MCP rows
|
||||||
|
# differentiate by server_url (one per MCP server), OAuth rows
|
||||||
|
# have server_url = NULL and differentiate by provider alone.
|
||||||
|
# COALESCE lets NULL server_url participate in the constraint.
|
||||||
|
"CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
|
||||||
|
"ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX connector_sessions_expiry_idx "
|
||||||
|
"ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX connector_sessions_server_url_idx "
|
||||||
|
"ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
|
||||||
|
"ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
|
||||||
|
"ON conversation_messages (conversation_id, position);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX conversation_messages_user_ts_idx "
|
||||||
|
"ON conversation_messages (user_id, timestamp DESC);"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
|
||||||
|
op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX conversations_shared_token_uidx "
|
||||||
|
"ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX conversations_api_key_date_idx "
|
||||||
|
"ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
|
||||||
|
"ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX memories_user_tool_path_uidx "
|
||||||
|
"ON memories (user_id, tool_id, path);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
|
||||||
|
"ON memories (user_id, path) WHERE tool_id IS NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX memories_path_prefix_idx "
|
||||||
|
"ON memories (user_id, tool_id, path text_pattern_ops);"
|
||||||
|
)
|
||||||
|
op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")
|
||||||
|
|
||||||
|
op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
|
||||||
|
op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
|
||||||
|
"ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
|
||||||
|
"ON pending_tool_state (conversation_id, user_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
|
||||||
|
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
|
||||||
|
op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
|
||||||
|
"ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
|
||||||
|
"ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
|
||||||
|
"ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
|
||||||
|
"ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
|
||||||
|
op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")
|
||||||
|
|
||||||
|
op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
|
||||||
|
op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
|
||||||
|
"ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
|
||||||
|
op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
|
||||||
|
"ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
|
||||||
|
"ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
|
||||||
|
op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
|
||||||
|
"ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
|
||||||
|
"ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
|
||||||
|
op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
|
||||||
|
"ON workflow_edges (workflow_id, graph_version, edge_id);"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
|
||||||
|
"ON workflow_nodes (workflow_id, graph_version, node_id);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
|
||||||
|
"ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
|
||||||
|
op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX workflow_runs_status_started_idx "
|
||||||
|
"ON workflow_runs (status, started_at DESC);"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
|
||||||
|
"ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
|
||||||
|
"ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# user_id foreign keys (deferrable so backfills can stage rows)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
user_fk_tables = (
|
||||||
|
"agent_folders",
|
||||||
|
"agents",
|
||||||
|
"attachments",
|
||||||
|
"connector_sessions",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"memories",
|
||||||
|
"notes",
|
||||||
|
"pending_tool_state",
|
||||||
|
"prompts",
|
||||||
|
"shared_conversations",
|
||||||
|
"sources",
|
||||||
|
"stack_logs",
|
||||||
|
"todos",
|
||||||
|
"token_usage",
|
||||||
|
"user_logs",
|
||||||
|
"user_tools",
|
||||||
|
"workflow_runs",
|
||||||
|
"workflows",
|
||||||
|
)
|
||||||
|
for table in user_fk_tables:
|
||||||
|
op.execute(
|
||||||
|
f"ALTER TABLE {table} "
|
||||||
|
f"ADD CONSTRAINT {table}_user_id_fk "
|
||||||
|
f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
|
||||||
|
f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Triggers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
updated_at_tables = (
|
||||||
|
"agent_folders",
|
||||||
|
"agents",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"memories",
|
||||||
|
"notes",
|
||||||
|
"prompts",
|
||||||
|
"sources",
|
||||||
|
"todos",
|
||||||
|
"user_tools",
|
||||||
|
"users",
|
||||||
|
"workflows",
|
||||||
|
)
|
||||||
|
for table in updated_at_tables:
|
||||||
|
op.execute(
|
||||||
|
f"CREATE TRIGGER {table}_set_updated_at "
|
||||||
|
f"BEFORE UPDATE ON {table} "
|
||||||
|
f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
|
||||||
|
f"EXECUTE FUNCTION set_updated_at();"
|
||||||
|
)
|
||||||
|
|
||||||
|
ensure_user_tables = (
|
||||||
|
"agent_folders",
|
||||||
|
"agents",
|
||||||
|
"attachments",
|
||||||
|
"connector_sessions",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"memories",
|
||||||
|
"notes",
|
||||||
|
"pending_tool_state",
|
||||||
|
"prompts",
|
||||||
|
"shared_conversations",
|
||||||
|
"sources",
|
||||||
|
"stack_logs",
|
||||||
|
"todos",
|
||||||
|
"token_usage",
|
||||||
|
"user_logs",
|
||||||
|
"user_tools",
|
||||||
|
"workflow_runs",
|
||||||
|
"workflows",
|
||||||
|
)
|
||||||
|
for table in ensure_user_tables:
|
||||||
|
op.execute(
|
||||||
|
f"CREATE TRIGGER {table}_ensure_user "
|
||||||
|
f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
|
||||||
|
f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER conversation_messages_fill_user "
|
||||||
|
"BEFORE INSERT ON conversation_messages "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER attachments_cleanup_message_refs "
|
||||||
|
"AFTER DELETE ON attachments "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER agents_cleanup_user_prefs "
|
||||||
|
"AFTER DELETE ON agents "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"CREATE TRIGGER sources_cleanup_agent_extra_refs "
|
||||||
|
"AFTER DELETE ON sources "
|
||||||
|
"FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Seed sentinel __system__ user (system/template sources attribute here)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
op.execute(
|
||||||
|
"INSERT INTO users (user_id) VALUES ('__system__') "
|
||||||
|
"ON CONFLICT (user_id) DO NOTHING;"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Nuclear downgrade: drop everything this migration created. The
|
||||||
|
# ordering drops FK-bearing children before parents; CASCADE would
|
||||||
|
# also work but explicit ordering is easier to reason about in code
|
||||||
|
# review.
|
||||||
|
tables_in_drop_order = (
|
||||||
|
"workflow_edges",
|
||||||
|
"workflow_runs",
|
||||||
|
"workflow_nodes",
|
||||||
|
"workflows",
|
||||||
|
"pending_tool_state",
|
||||||
|
"shared_conversations",
|
||||||
|
"conversation_messages",
|
||||||
|
"conversations",
|
||||||
|
"connector_sessions",
|
||||||
|
"notes",
|
||||||
|
"todos",
|
||||||
|
"memories",
|
||||||
|
"attachments",
|
||||||
|
"agents",
|
||||||
|
"sources",
|
||||||
|
"agent_folders",
|
||||||
|
"stack_logs",
|
||||||
|
"user_logs",
|
||||||
|
"token_usage",
|
||||||
|
"user_tools",
|
||||||
|
"prompts",
|
||||||
|
"users",
|
||||||
|
)
|
||||||
|
for table in tables_in_drop_order:
|
||||||
|
op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
|
||||||
|
|
||||||
|
for fn in (
|
||||||
|
"conversation_messages_fill_user_id",
|
||||||
|
"cleanup_user_agent_prefs",
|
||||||
|
"cleanup_agent_extra_source_refs",
|
||||||
|
"cleanup_message_attachment_refs",
|
||||||
|
"ensure_user_exists",
|
||||||
|
"set_updated_at",
|
||||||
|
):
|
||||||
|
op.execute(f"DROP FUNCTION IF EXISTS {fn}();")
|
||||||
@@ -14,10 +14,13 @@ from application.core.model_utils import (
|
|||||||
get_provider_from_model_id,
|
get_provider_from_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.error import sanitize_api_error
|
from application.error import sanitize_api_error
|
||||||
from application.llm.llm_creator import LLMCreator
|
from application.llm.llm_creator import LLMCreator
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||||
|
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
from application.utils import check_required_fields
|
from application.utils import check_required_fields
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -30,10 +33,6 @@ class BaseAnswerResource:
|
|||||||
"""Shared base class for answer endpoints"""
|
"""Shared base class for answer endpoints"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
self.db = db
|
|
||||||
self.user_logs_collection = db["user_logs"]
|
|
||||||
self.default_model_id = get_default_model_id()
|
self.default_model_id = get_default_model_id()
|
||||||
self.conversation_service = ConversationService()
|
self.conversation_service = ConversationService()
|
||||||
|
|
||||||
@@ -91,8 +90,8 @@ class BaseAnswerResource:
|
|||||||
api_key = agent_config.get("user_api_key")
|
api_key = agent_config.get("user_api_key")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
return None
|
return None
|
||||||
agents_collection = self.db["agents"]
|
with db_readonly() as conn:
|
||||||
agent = agents_collection.find_one({"key": api_key})
|
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||||
|
|
||||||
if not agent:
|
if not agent:
|
||||||
return make_response(
|
return make_response(
|
||||||
@@ -113,41 +112,32 @@ class BaseAnswerResource:
|
|||||||
)
|
)
|
||||||
|
|
||||||
token_limit = int(
|
token_limit = int(
|
||||||
agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
|
agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
|
||||||
)
|
)
|
||||||
request_limit = int(
|
request_limit = int(
|
||||||
agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
|
agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
|
||||||
)
|
)
|
||||||
|
|
||||||
token_usage_collection = self.db["token_usage"]
|
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
|
||||||
end_date = datetime.datetime.now()
|
|
||||||
start_date = end_date - datetime.timedelta(hours=24)
|
start_date = end_date - datetime.timedelta(hours=24)
|
||||||
|
|
||||||
match_query = {
|
if limited_token_mode or limited_request_mode:
|
||||||
"timestamp": {"$gte": start_date, "$lte": end_date},
|
with db_readonly() as conn:
|
||||||
"api_key": api_key,
|
token_repo = TokenUsageRepository(conn)
|
||||||
}
|
if limited_token_mode:
|
||||||
|
daily_token_usage = token_repo.sum_tokens_in_range(
|
||||||
if limited_token_mode:
|
start=start_date, end=end_date, api_key=api_key,
|
||||||
token_pipeline = [
|
)
|
||||||
{"$match": match_query},
|
else:
|
||||||
{
|
daily_token_usage = 0
|
||||||
"$group": {
|
if limited_request_mode:
|
||||||
"_id": None,
|
daily_request_usage = token_repo.count_in_range(
|
||||||
"total_tokens": {
|
start=start_date, end=end_date, api_key=api_key,
|
||||||
"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
|
)
|
||||||
},
|
else:
|
||||||
}
|
daily_request_usage = 0
|
||||||
},
|
|
||||||
]
|
|
||||||
token_result = list(token_usage_collection.aggregate(token_pipeline))
|
|
||||||
daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
|
|
||||||
else:
|
else:
|
||||||
daily_token_usage = 0
|
daily_token_usage = 0
|
||||||
if limited_request_mode:
|
|
||||||
daily_request_usage = token_usage_collection.count_documents(match_query)
|
|
||||||
else:
|
|
||||||
daily_request_usage = 0
|
daily_request_usage = 0
|
||||||
if not limited_token_mode and not limited_request_mode:
|
if not limited_token_mode and not limited_request_mode:
|
||||||
return None
|
return None
|
||||||
@@ -467,7 +457,18 @@ class BaseAnswerResource:
|
|||||||
for key, value in log_data.items():
|
for key, value in log_data.items():
|
||||||
if isinstance(value, str) and len(value) > 10000:
|
if isinstance(value, str) and len(value) > 10000:
|
||||||
log_data[key] = value[:10000]
|
log_data[key] = value[:10000]
|
||||||
self.user_logs_collection.insert_one(log_data)
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
UserLogsRepository(conn).insert(
|
||||||
|
user_id=log_data.get("user"),
|
||||||
|
endpoint="stream_answer",
|
||||||
|
data=log_data,
|
||||||
|
)
|
||||||
|
except Exception as log_err:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to persist stream_answer user log: {log_err}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
data = json.dumps({"type": "end"})
|
data = json.dumps({"type": "end"})
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
|
|||||||
@@ -4,11 +4,10 @@ from typing import Any, Dict, List
|
|||||||
from flask import make_response, request
|
from flask import make_response, request
|
||||||
from flask_restx import fields, Resource
|
from flask_restx import fields, Resource
|
||||||
|
|
||||||
from bson.dbref import DBRef
|
|
||||||
|
|
||||||
from application.api.answer.routes.base import answer_ns
|
from application.api.answer.routes.base import answer_ns
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -18,12 +17,6 @@ logger = logging.getLogger(__name__)
|
|||||||
class SearchResource(Resource):
|
class SearchResource(Resource):
|
||||||
"""Fast search endpoint for retrieving relevant documents"""
|
"""Fast search endpoint for retrieving relevant documents"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
self.db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
self.agents_collection = self.db["agents"]
|
|
||||||
|
|
||||||
search_model = answer_ns.model(
|
search_model = answer_ns.model(
|
||||||
"SearchModel",
|
"SearchModel",
|
||||||
{
|
{
|
||||||
@@ -40,37 +33,23 @@ class SearchResource(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
|
def _get_sources_from_api_key(self, api_key: str) -> List[str]:
|
||||||
"""Get source IDs connected to the API key/agent.
|
"""Get source IDs connected to the API key/agent."""
|
||||||
|
with db_readonly() as conn:
|
||||||
"""
|
agent_data = AgentsRepository(conn).find_by_key(api_key)
|
||||||
agent_data = self.agents_collection.find_one({"key": api_key})
|
|
||||||
if not agent_data:
|
if not agent_data:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
source_ids = []
|
source_ids: List[str] = []
|
||||||
|
# extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
|
||||||
|
extra = agent_data.get("extra_source_ids") or []
|
||||||
|
for src in extra:
|
||||||
|
if src:
|
||||||
|
source_ids.append(str(src))
|
||||||
|
|
||||||
# Handle multiple sources (only if non-empty)
|
|
||||||
sources = agent_data.get("sources", [])
|
|
||||||
if sources and isinstance(sources, list) and len(sources) > 0:
|
|
||||||
for source_ref in sources:
|
|
||||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
|
||||||
if source_ref == "default":
|
|
||||||
continue
|
|
||||||
elif isinstance(source_ref, DBRef):
|
|
||||||
source_doc = self.db.dereference(source_ref)
|
|
||||||
if source_doc:
|
|
||||||
source_ids.append(str(source_doc["_id"]))
|
|
||||||
|
|
||||||
# Handle single source (legacy) - check if sources was empty or didn't yield results
|
|
||||||
if not source_ids:
|
if not source_ids:
|
||||||
source = agent_data.get("source")
|
single = agent_data.get("source_id")
|
||||||
if isinstance(source, DBRef):
|
if single:
|
||||||
source_doc = self.db.dereference(source)
|
source_ids.append(str(single))
|
||||||
if source_doc:
|
|
||||||
source_ids.append(str(source_doc["_id"]))
|
|
||||||
# Skip "default" - it's a placeholder, not an actual vectorstore
|
|
||||||
elif source and source != "default":
|
|
||||||
source_ids.append(source)
|
|
||||||
|
|
||||||
return source_ids
|
return source_ids
|
||||||
|
|
||||||
@@ -161,7 +140,8 @@ class SearchResource(Resource):
|
|||||||
return make_response({"error": "api_key is required"}, 400)
|
return make_response({"error": "api_key is required"}, 400)
|
||||||
|
|
||||||
# Validate API key
|
# Validate API key
|
||||||
agent = self.agents_collection.find_one({"key": api_key})
|
with db_readonly() as conn:
|
||||||
|
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||||
if not agent:
|
if not agent:
|
||||||
return make_response({"error": "Invalid API key"}, 401)
|
return make_response({"error": "Invalid API key"}, 401)
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,20 @@
|
|||||||
"""Service for saving and restoring tool-call continuation state.
|
"""Service for saving and restoring tool-call continuation state.
|
||||||
|
|
||||||
When a stream pauses (tool needs approval or client-side execution),
|
When a stream pauses (tool needs approval or client-side execution),
|
||||||
the full execution state is persisted to MongoDB so the client can
|
the full execution state is persisted to Postgres so the client can
|
||||||
resume later by sending tool_actions.
|
resume later by sending tool_actions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from bson import ObjectId
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||||
from application.core.mongo_db import MongoDB
|
from application.storage.db.repositories.pending_tool_state import (
|
||||||
from application.core.settings import settings
|
PendingToolStateRepository,
|
||||||
|
)
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -21,8 +23,13 @@ PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
|||||||
|
|
||||||
|
|
||||||
def _make_serializable(obj: Any) -> Any:
|
def _make_serializable(obj: Any) -> Any:
|
||||||
"""Recursively convert MongoDB ObjectIds and other non-JSON types."""
|
"""Recursively coerce non-JSON values into JSON-safe forms.
|
||||||
if isinstance(obj, ObjectId):
|
|
||||||
|
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
|
||||||
|
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
|
||||||
|
our writers produce them anymore.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, UUID):
|
||||||
return str(obj)
|
return str(obj)
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||||
@@ -34,25 +41,13 @@ def _make_serializable(obj: Any) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
class ContinuationService:
|
class ContinuationService:
|
||||||
"""Manages pending tool-call state in MongoDB."""
|
"""Manages pending tool-call state in Postgres."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
mongo = MongoDB.get_client()
|
# No-op constructor retained for call-site compatibility. State
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
# lives in Postgres now; each operation opens its own short-lived
|
||||||
self.collection = db["pending_tool_state"]
|
# session rather than holding a connection on the service.
|
||||||
self._ensure_indexes()
|
pass
|
||||||
|
|
||||||
def _ensure_indexes(self):
|
|
||||||
try:
|
|
||||||
self.collection.create_index(
|
|
||||||
"expires_at", expireAfterSeconds=0
|
|
||||||
)
|
|
||||||
self.collection.create_index(
|
|
||||||
[("conversation_id", 1), ("user", 1)], unique=True
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# Indexes may already exist or mongomock doesn't support TTL
|
|
||||||
pass
|
|
||||||
|
|
||||||
def save_state(
|
def save_state(
|
||||||
self,
|
self,
|
||||||
@@ -67,6 +62,10 @@ class ContinuationService:
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""Save execution state for later continuation.
|
"""Save execution state for later continuation.
|
||||||
|
|
||||||
|
``conversation_id`` may be a Postgres UUID or the legacy Mongo
|
||||||
|
``ObjectId`` string — the latter is resolved via
|
||||||
|
``conversations.legacy_mongo_id`` to find the matching row.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conversation_id: The conversation this state belongs to.
|
conversation_id: The conversation this state belongs to.
|
||||||
user: Owner user ID.
|
user: Owner user ID.
|
||||||
@@ -78,36 +77,40 @@ class ContinuationService:
|
|||||||
client_tools: Client-provided tool schemas for client-side execution.
|
client_tools: Client-provided tool schemas for client-side execution.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The string ID of the saved state document.
|
The string ID (conversation_id as provided) of the saved state.
|
||||||
"""
|
"""
|
||||||
now = datetime.datetime.now(datetime.timezone.utc)
|
with db_session() as conn:
|
||||||
expires_at = now + datetime.timedelta(seconds=PENDING_STATE_TTL_SECONDS)
|
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||||
|
if conv is not None:
|
||||||
|
pg_conv_id = conv["id"]
|
||||||
|
elif looks_like_uuid(conversation_id):
|
||||||
|
pg_conv_id = conversation_id
|
||||||
|
else:
|
||||||
|
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
|
||||||
|
# would raise and poison the save. Surface the mismatch so
|
||||||
|
# the caller can decide (the stream loop in routes/base.py
|
||||||
|
# already wraps this in try/except).
+                raise ValueError(
+                    f"Cannot save continuation state: conversation_id "
+                    f"{conversation_id!r} is neither a PG UUID nor a "
+                    f"backfilled legacy Mongo id."
+                )
+            PendingToolStateRepository(conn).save_state(
+                pg_conv_id,
+                user,
+                messages=_make_serializable(messages),
+                pending_tool_calls=_make_serializable(pending_tool_calls),
+                tools_dict=_make_serializable(tools_dict),
+                tool_schemas=_make_serializable(tool_schemas),
+                agent_config=_make_serializable(agent_config),
+                client_tools=_make_serializable(client_tools) if client_tools else None,
+            )
-        doc = {
-            "conversation_id": conversation_id,
-            "user": user,
-            "messages": _make_serializable(messages),
-            "pending_tool_calls": _make_serializable(pending_tool_calls),
-            "tools_dict": _make_serializable(tools_dict),
-            "tool_schemas": _make_serializable(tool_schemas),
-            "agent_config": _make_serializable(agent_config),
-            "client_tools": _make_serializable(client_tools) if client_tools else None,
-            "created_at": now,
-            "expires_at": expires_at,
-        }
-
-        # Upsert — only one pending state per conversation per user
-        result = self.collection.replace_one(
-            {"conversation_id": conversation_id, "user": user},
-            doc,
-            upsert=True,
-        )
-        state_id = str(result.upserted_id) if result.upserted_id else conversation_id
         logger.info(
             f"Saved continuation state for conversation {conversation_id} "
             f"with {len(pending_tool_calls)} pending tool call(s)"
         )
-        return state_id
+        return conversation_id

     def load_state(
         self, conversation_id: str, user: str
@@ -117,25 +120,38 @@ class ContinuationService:
         Returns:
             The state dict, or None if no pending state exists.
         """
-        doc = self.collection.find_one(
-            {"conversation_id": conversation_id, "user": user}
-        )
+        with db_readonly() as conn:
+            conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
+            if conv is not None:
+                pg_conv_id = conv["id"]
+            elif looks_like_uuid(conversation_id):
+                pg_conv_id = conversation_id
+            else:
+                # Unresolvable legacy ObjectId → no state can exist for it.
+                return None
+            doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
         if not doc:
             return None
-        doc["_id"] = str(doc["_id"])
         return doc

     def delete_state(self, conversation_id: str, user: str) -> bool:
         """Delete pending state after successful resumption.

         Returns:
-            True if a document was deleted.
+            True if a row was deleted.
         """
-        result = self.collection.delete_one(
-            {"conversation_id": conversation_id, "user": user}
-        )
-        if result.deleted_count:
+        with db_session() as conn:
+            conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
+            if conv is not None:
+                pg_conv_id = conv["id"]
+            elif looks_like_uuid(conversation_id):
+                pg_conv_id = conversation_id
+            else:
+                # Unresolvable legacy ObjectId → nothing to delete.
+                return False
+            deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
+        if deleted:
             logger.info(
                 f"Deleted continuation state for conversation {conversation_id}"
             )
-        return result.deleted_count > 0
+        return deleted
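Both `load_state` and `delete_state` above repeat the same conversation-id resolution: try the legacy-Mongo backfill lookup, fall back to treating the value as a native PG UUID, and bail out otherwise. A minimal sketch of that branch as a standalone helper; `resolve_pg_conversation_id` is a hypothetical name, and the local `looks_like_uuid` below only approximates what the imported helper is assumed to do.

```python
from typing import Optional
from uuid import UUID


def looks_like_uuid(value: str) -> bool:
    # Assumed behaviour of the imported helper: True only for valid UUID strings.
    try:
        UUID(str(value))
        return True
    except (ValueError, TypeError):
        return False


def resolve_pg_conversation_id(repo, conversation_id: str) -> Optional[str]:
    """Hypothetical helper mirroring the branch used in load_state/delete_state."""
    conv = repo.get_by_legacy_id(conversation_id)   # backfilled legacy Mongo id?
    if conv is not None:
        return conv["id"]
    if looks_like_uuid(conversation_id):            # already a native PG UUID
        return conversation_id
    return None                                     # unresolvable: no state exists
```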
@@ -1,44 +1,51 @@
+"""Conversation persistence service backed by Postgres.
+
+Handles create / append / update / compression for conversations during
+the answer-streaming path. Connections are opened per-operation rather
+than held for the duration of a stream.
+"""
+
 import logging
 from datetime import datetime, timezone
 from typing import Any, Dict, List, Optional

-from application.core.mongo_db import MongoDB
+from sqlalchemy import text as sql_text
+
 from application.core.settings import settings
-from bson import ObjectId
+from application.storage.db.base_repository import looks_like_uuid
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.conversations import ConversationsRepository
+from application.storage.db.session import db_readonly, db_session


 logger = logging.getLogger(__name__)


 class ConversationService:
-    def __init__(self):
-        mongo = MongoDB.get_client()
-        db = mongo[settings.MONGO_DB_NAME]
-        self.conversations_collection = db["conversations"]
-        self.agents_collection = db["agents"]
-
     def get_conversation(
         self, conversation_id: str, user_id: str
     ) -> Optional[Dict[str, Any]]:
-        """Retrieve a conversation with proper access control"""
+        """Retrieve a conversation with owner-or-shared access control.
+
+        Returns a dict in the legacy Mongo shape — ``queries`` is a list
+        of message dicts (prompt/response/...) — for compatibility with
+        the streaming pipeline that consumes this shape.
+        """
         if not conversation_id or not user_id:
             return None
         try:
-            conversation = self.conversations_collection.find_one(
-                {
-                    "_id": ObjectId(conversation_id),
-                    "$or": [{"user": user_id}, {"shared_with": user_id}],
-                }
-            )
-
-            if not conversation:
-                logger.warning(
-                    f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
-                )
-                return None
-            conversation["_id"] = str(conversation["_id"])
-            return conversation
+            with db_readonly() as conn:
+                repo = ConversationsRepository(conn)
+                conv = repo.get_any(conversation_id, user_id)
+                if conv is None:
+                    logger.warning(
+                        f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
+                    )
+                    return None
+                messages = repo.get_messages(str(conv["id"]))
+                conv["queries"] = messages
+                conv["_id"] = str(conv["id"])
+                return conv
         except Exception as e:
             logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
             return None
@@ -62,7 +69,11 @@ class ConversationService:
         attachment_ids: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
     ) -> str:
-        """Save or update a conversation in the database"""
+        """Save or update a conversation in Postgres.
+
+        Returns the string conversation id (PG UUID as string, or the
+        caller-provided id if it was already a UUID).
+        """
         if decoded_token is None:
             raise ValueError("Invalid or missing authentication token")
         user_id = decoded_token.get("sub")
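For orientation, the legacy Mongo shape that `get_conversation` still returns looks roughly like the sketch below. The key names come from the code above; the values are invented for illustration.

```python
# Illustrative only: shape of the dict returned by get_conversation().
conversation = {
    "id": "7c9e6679-7425-40de-944b-e07fc1f90ae7",   # PG UUID primary key
    "_id": "7c9e6679-7425-40de-944b-e07fc1f90ae7",  # same value, stringified into the old slot
    "user": "auth0|12345",
    "name": "New Conversation",
    "queries": [                                     # populated from repo.get_messages(...)
        {
            "prompt": "What is DocsGPT?",
            "response": "DocsGPT is an open-source RAG assistant.",
            "thought": "",
            "sources": [],
            "tool_calls": [],
        }
    ],
}
```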
@@ -70,78 +81,47 @@ class ConversationService:
             raise ValueError("User ID not found in token")
         current_time = datetime.now(timezone.utc)

-        # clean up in sources array such that we save max 1k characters for text part
+        # Trim huge inline source text to a reasonable max before persist.
         for source in sources:
             if "text" in source and isinstance(source["text"], str):
                 source["text"] = source["text"][:1000]

+        message_payload = {
+            "prompt": question,
+            "response": response,
+            "thought": thought,
+            "sources": sources,
+            "tool_calls": tool_calls,
+            "attachments": attachment_ids,
+            "model_id": model_id,
+            "timestamp": current_time,
+        }
+        if metadata:
+            message_payload["metadata"] = metadata
+
         if conversation_id is not None and index is not None:
-            # Update existing conversation with new query
-            result = self.conversations_collection.update_one(
-                {
-                    "_id": ObjectId(conversation_id),
-                    "user": user_id,
-                    f"queries.{index}": {"$exists": True},
-                },
-                {
-                    "$set": {
-                        f"queries.{index}.prompt": question,
-                        f"queries.{index}.response": response,
-                        f"queries.{index}.thought": thought,
-                        f"queries.{index}.sources": sources,
-                        f"queries.{index}.tool_calls": tool_calls,
-                        f"queries.{index}.timestamp": current_time,
-                        f"queries.{index}.attachments": attachment_ids,
-                        f"queries.{index}.model_id": model_id,
-                        **(
-                            {f"queries.{index}.metadata": metadata}
-                            if metadata
-                            else {}
-                        ),
-                    }
-                },
-            )
-
-            if result.matched_count == 0:
-                raise ValueError("Conversation not found or unauthorized")
-            self.conversations_collection.update_one(
-                {
-                    "_id": ObjectId(conversation_id),
-                    "user": user_id,
-                    f"queries.{index}": {"$exists": True},
-                },
-                {"$push": {"queries": {"$each": [], "$slice": index + 1}}},
-            )
+            with db_session() as conn:
+                repo = ConversationsRepository(conn)
+                conv = repo.get_any(conversation_id, user_id)
+                if conv is None:
+                    raise ValueError("Conversation not found or unauthorized")
+                conv_pg_id = str(conv["id"])
+                repo.update_message_at(conv_pg_id, index, message_payload)
+                repo.truncate_after(conv_pg_id, index)
             return conversation_id
         elif conversation_id:
-            # Append new message to existing conversation
-            result = self.conversations_collection.update_one(
-                {"_id": ObjectId(conversation_id), "user": user_id},
-                {
-                    "$push": {
-                        "queries": {
-                            "prompt": question,
-                            "response": response,
-                            "thought": thought,
-                            "sources": sources,
-                            "tool_calls": tool_calls,
-                            "timestamp": current_time,
-                            "attachments": attachment_ids,
-                            "model_id": model_id,
-                            **({"metadata": metadata} if metadata else {}),
-                        }
-                    }
-                },
-            )
-
-            if result.matched_count == 0:
-                raise ValueError("Conversation not found or unauthorized")
+            with db_session() as conn:
+                repo = ConversationsRepository(conn)
+                conv = repo.get_any(conversation_id, user_id)
+                if conv is None:
+                    raise ValueError("Conversation not found or unauthorized")
+                conv_pg_id = str(conv["id"])
+                # append_message expects 'metadata' key either way; normalise.
+                append_payload = dict(message_payload)
+                append_payload.setdefault("metadata", metadata or {})
+                repo.append_message(conv_pg_id, append_payload)
             return conversation_id
         else:
-            # Create new conversation
-
             messages_summary = [
                 {
                     "role": "system",
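The three branches of the save path now reduce to a handful of repository calls. A condensed sketch of the dispatch, assuming the `ConversationsRepository` methods shown in this hunk (`update_message_at`, `truncate_after`, `append_message`, `create`); the function and argument names below are illustrative, not the actual method.

```python
def persist(repo, conversation_id, index, payload, user_id, name):
    """Sketch of the three save_conversation branches."""
    if conversation_id is not None and index is not None:
        # Regenerating an earlier turn: overwrite it and drop everything after it,
        # the PG equivalent of Mongo's $push with $slice: index + 1.
        repo.update_message_at(conversation_id, index, payload)
        repo.truncate_after(conversation_id, index)
        return conversation_id
    if conversation_id:
        # Normal follow-up turn: append to the existing conversation.
        repo.append_message(conversation_id, payload)
        return conversation_id
    # First turn: create the conversation row, then append the first message.
    conv = repo.create(user_id, name)
    repo.append_message(str(conv["id"]), payload)
    return str(conv["id"])
```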
@@ -163,70 +143,64 @@ class ConversationService:
             if not completion or not completion.strip():
                 completion = question[:50] if question else "New Conversation"

-            query_doc = {
-                "prompt": question,
-                "response": response,
-                "thought": thought,
-                "sources": sources,
-                "tool_calls": tool_calls,
-                "timestamp": current_time,
-                "attachments": attachment_ids,
-                "model_id": model_id,
-            }
-            if metadata:
-                query_doc["metadata"] = metadata
-
-            conversation_data = {
-                "user": user_id,
-                "date": current_time,
-                "name": completion,
-                "queries": [query_doc],
-            }
-
+            resolved_api_key: Optional[str] = None
+            resolved_agent_id: Optional[str] = None
+
             if api_key:
-                if agent_id:
-                    conversation_data["agent_id"] = agent_id
-                    if is_shared_usage:
-                        conversation_data["is_shared_usage"] = is_shared_usage
-                        conversation_data["shared_token"] = shared_token
-                agent = self.agents_collection.find_one({"key": api_key})
+                with db_readonly() as conn:
+                    agent = AgentsRepository(conn).find_by_key(api_key)
                 if agent:
-                    conversation_data["api_key"] = agent["key"]
-            result = self.conversations_collection.insert_one(conversation_data)
-            return str(result.inserted_id)
+                    resolved_api_key = agent.get("key")
+                if agent_id:
+                    resolved_agent_id = agent_id
+
+            with db_session() as conn:
+                repo = ConversationsRepository(conn)
+                conv = repo.create(
+                    user_id,
+                    completion,
+                    agent_id=resolved_agent_id,
+                    api_key=resolved_api_key,
+                    is_shared_usage=bool(resolved_agent_id and is_shared_usage),
+                    shared_token=(
+                        shared_token
+                        if (resolved_agent_id and is_shared_usage)
+                        else None
+                    ),
+                )
+                conv_pg_id = str(conv["id"])
+                append_payload = dict(message_payload)
+                append_payload.setdefault("metadata", metadata or {})
+                repo.append_message(conv_pg_id, append_payload)
+                return conv_pg_id

     def update_compression_metadata(
         self, conversation_id: str, compression_metadata: Dict[str, Any]
     ) -> None:
-        """
-        Update conversation with compression metadata.
-
-        Uses $push with $slice to keep only the most recent compression points,
-        preventing unbounded array growth. Since each compression incorporates
-        previous compressions, older points become redundant.
-
-        Args:
-            conversation_id: Conversation ID
-            compression_metadata: Compression point data
-        """
+        """Persist compression flags and append a compression point.
+
+        Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
+        ``compression_metadata`` but goes through the PG repo API.
+        """
         try:
-            self.conversations_collection.update_one(
-                {"_id": ObjectId(conversation_id)},
-                {
-                    "$set": {
-                        "compression_metadata.is_compressed": True,
-                        "compression_metadata.last_compression_at": compression_metadata.get(
-                            "timestamp"
-                        ),
-                    },
-                    "$push": {
-                        "compression_metadata.compression_points": {
-                            "$each": [compression_metadata],
-                            "$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
-                        }
-                    },
-                },
-            )
+            with db_session() as conn:
+                repo = ConversationsRepository(conn)
+                # conversation_id here comes from the streaming pipeline
+                # which has already resolved it; accept either UUID or
+                # legacy id for safety.
+                conv = repo.get_by_legacy_id(conversation_id)
+                conv_pg_id = (
+                    str(conv["id"]) if conv is not None else conversation_id
+                )
+                repo.set_compression_flags(
+                    conv_pg_id,
+                    is_compressed=True,
+                    last_compression_at=compression_metadata.get("timestamp"),
+                )
+                repo.append_compression_point(
+                    conv_pg_id,
+                    compression_metadata,
+                    max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
+                )
             logger.info(
                 f"Updated compression metadata for conversation {conversation_id}"
            )
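`append_compression_point` with `max_points` is expected to reproduce the old `$push` + `$slice: -N` behaviour: keep only the newest N entries. A plain-Python sketch of that invariant; the repository itself presumably enforces it in SQL/JSONB.

```python
def append_compression_point(points: list, new_point: dict, max_points: int) -> list:
    """Keep only the most recent max_points entries, newest last."""
    points = points + [new_point]
    return points[-max_points:]


history = [{"n": 1}, {"n": 2}, {"n": 3}]
history = append_compression_point(history, {"n": 4}, max_points=3)
assert history == [{"n": 2}, {"n": 3}, {"n": 4}]
```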
@@ -239,34 +213,34 @@ class ConversationService:
     def append_compression_message(
         self, conversation_id: str, compression_metadata: Dict[str, Any]
     ) -> None:
-        """
-        Append a synthetic compression summary entry into the conversation history.
-        This makes the summary visible in the DB alongside normal queries.
-        """
+        """Append a synthetic compression summary message to the conversation."""
         try:
             summary = compression_metadata.get("compressed_summary", "")
             if not summary:
                 return
-            timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
-
-            self.conversations_collection.update_one(
-                {"_id": ObjectId(conversation_id)},
-                {
-                    "$push": {
-                        "queries": {
-                            "prompt": "[Context Compression Summary]",
-                            "response": summary,
-                            "thought": "",
-                            "sources": [],
-                            "tool_calls": [],
-                            "timestamp": timestamp,
-                            "attachments": [],
-                            "model_id": compression_metadata.get("model_used"),
-                        }
-                    }
-                },
-            )
-            logger.info(f"Appended compression summary to conversation {conversation_id}")
+            timestamp = compression_metadata.get(
+                "timestamp", datetime.now(timezone.utc)
+            )
+
+            with db_session() as conn:
+                repo = ConversationsRepository(conn)
+                conv = repo.get_by_legacy_id(conversation_id)
+                conv_pg_id = (
+                    str(conv["id"]) if conv is not None else conversation_id
+                )
+                repo.append_message(conv_pg_id, {
+                    "prompt": "[Context Compression Summary]",
+                    "response": summary,
+                    "thought": "",
+                    "sources": [],
+                    "tool_calls": [],
+                    "attachments": [],
+                    "model_id": compression_metadata.get("model_used"),
+                    "timestamp": timestamp,
+                })
+            logger.info(
+                f"Appended compression summary to conversation {conversation_id}"
+            )
         except Exception as e:
             logger.error(
                 f"Error appending compression summary: {str(e)}", exc_info=True
@@ -275,20 +249,30 @@ class ConversationService:
     def get_compression_metadata(
         self, conversation_id: str
     ) -> Optional[Dict[str, Any]]:
-        """
-        Get compression metadata for a conversation.
-
-        Args:
-            conversation_id: Conversation ID
-
-        Returns:
-            Compression metadata dict or None
-        """
+        """Fetch the stored compression metadata JSONB blob for a conversation."""
         try:
-            conversation = self.conversations_collection.find_one(
-                {"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
-            )
-            return conversation.get("compression_metadata") if conversation else None
+            with db_readonly() as conn:
+                repo = ConversationsRepository(conn)
+                conv = repo.get_by_legacy_id(conversation_id)
+                if conv is None:
+                    # Fallback to UUID lookup without user scoping — the
+                    # caller already holds an authenticated conversation
+                    # id from the streaming path. Gate on id shape so a
+                    # non-UUID (legacy ObjectId that wasn't backfilled)
+                    # doesn't reach CAST — the cast raises and spams the
+                    # logs with a stack trace on every call.
+                    if not looks_like_uuid(conversation_id):
+                        return None
+                    result = conn.execute(
+                        sql_text(
+                            "SELECT compression_metadata FROM conversations "
+                            "WHERE id = CAST(:id AS uuid)"
+                        ),
+                        {"id": conversation_id},
+                    )
+                    row = result.fetchone()
+                    return row[0] if row is not None else None
+                return conv.get("compression_metadata") if conv else None
         except Exception as e:
             logger.error(
                 f"Error getting compression metadata: {str(e)}", exc_info=True
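The `looks_like_uuid` gate before `CAST(:id AS uuid)` matters because Postgres raises on a malformed UUID literal, and the surrounding except block would then log a stack trace on every call. A small sketch of the guarded query with SQLAlchemy; the table and column names are the ones in the hunk above, while the engine DSN is a placeholder.

```python
from sqlalchemy import create_engine, text

from application.storage.db.base_repository import looks_like_uuid

engine = create_engine("postgresql://localhost/docsgpt")  # placeholder DSN


def fetch_compression_metadata(conversation_id: str):
    if not looks_like_uuid(conversation_id):
        # Guard: never let a legacy Mongo ObjectId reach the uuid cast.
        return None
    with engine.connect() as conn:
        row = conn.execute(
            text(
                "SELECT compression_metadata FROM conversations "
                "WHERE id = CAST(:id AS uuid)"
            ),
            {"id": conversation_id},
        ).fetchone()
    return row[0] if row is not None else None
```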
@@ -5,10 +5,6 @@ import os
 from pathlib import Path
 from typing import Any, Dict, Optional, Set

-from bson.dbref import DBRef
-from bson.objectid import ObjectId
-
 from application.agents.agent_creator import AgentCreator
 from application.api.answer.services.compression import CompressionOrchestrator
 from application.api.answer.services.compression.token_counter import TokenCounter
@@ -20,8 +16,16 @@ from application.core.model_utils import (
     get_provider_from_model_id,
     validate_model_id,
 )
-from application.core.mongo_db import MongoDB
 from application.core.settings import settings
+from sqlalchemy import text as sql_text
+
+from application.storage.db.base_repository import looks_like_uuid, row_to_dict
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.attachments import AttachmentsRepository
+from application.storage.db.repositories.prompts import PromptsRepository
+from application.storage.db.repositories.sources import SourcesRepository
+from application.storage.db.repositories.user_tools import UserToolsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.retriever.retriever_creator import RetrieverCreator
 from application.utils import (
     calculate_doc_token_budget,
@@ -32,28 +36,41 @@ logger = logging.getLogger(__name__)


 def get_prompt(prompt_id: str, prompts_collection=None) -> str:
-    """
-    Get a prompt by preset name or MongoDB ID
-    """
+    """Get a prompt by preset name or Postgres ID (UUID or legacy ObjectId).
+
+    The ``prompts_collection`` parameter is retained for backwards
+    compatibility with call sites that still pass it positionally; it is
+    ignored post-cutover.
+    """
+    del prompts_collection  # unused — retained for call-site compatibility
+    # Callers may pass a ``uuid.UUID`` (from a PG ``prompt_id`` column) or a
+    # plain string ("default"/"creative"/legacy ObjectId). Normalise to str
+    # so both the preset lookup and the UUID-vs-legacy branching work.
+    # ``None`` / empty means "use the default prompt" — agents that never
+    # set a custom prompt land here (PG ``agents.prompt_id`` is NULL).
+    if prompt_id is None or prompt_id == "":
+        prompt_id = "default"
+    elif not isinstance(prompt_id, str):
+        prompt_id = str(prompt_id)
     current_dir = Path(__file__).resolve().parents[3]
     prompts_dir = current_dir / "prompts"

-    # Maps for classic agent types
     CLASSIC_PRESETS = {
         "default": "chat_combine_default.txt",
         "creative": "chat_combine_creative.txt",
         "strict": "chat_combine_strict.txt",
         "reduce": "chat_reduce_prompt.txt",
     }
-    # Agentic counterparts — same styles, but with search tool instructions
     AGENTIC_PRESETS = {
         "default": "agentic/default.txt",
         "creative": "agentic/creative.txt",
         "strict": "agentic/strict.txt",
     }
-    preset_mapping = {**CLASSIC_PRESETS, **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()}}
+    preset_mapping = {
+        **CLASSIC_PRESETS,
+        **{f"agentic_{k}": v for k, v in AGENTIC_PRESETS.items()},
+    }

     if prompt_id in preset_mapping:
         file_path = os.path.join(prompts_dir, preset_mapping[prompt_id])
@@ -63,14 +80,18 @@ def get_prompt(prompt_id: str, prompts_collection=None) -> str:
     except FileNotFoundError:
         raise FileNotFoundError(f"Prompt file not found: {file_path}")
     try:
-        if prompts_collection is None:
-            mongo = MongoDB.get_client()
-            db = mongo[settings.MONGO_DB_NAME]
-            prompts_collection = db["prompts"]
-        prompt_doc = prompts_collection.find_one({"_id": ObjectId(prompt_id)})
+        with db_readonly() as conn:
+            repo = PromptsRepository(conn)
+            prompt_doc = None
+            if looks_like_uuid(prompt_id):
+                prompt_doc = repo.get_for_rendering(prompt_id)
+            if prompt_doc is None:
+                prompt_doc = repo.get_by_legacy_id(prompt_id)
         if not prompt_doc:
             raise ValueError(f"Prompt with ID {prompt_id} not found")
         return prompt_doc["content"]
+    except ValueError:
+        raise
     except Exception as e:
         raise ValueError(f"Invalid prompt ID: {prompt_id}") from e
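After this change `get_prompt` accepts several id shapes. A few illustrative calls against the function above; the UUID and ObjectId values are made up, and the calls assume the preset files exist on disk and the ids refer to stored prompts.

```python
get_prompt("default")         # preset -> prompts/chat_combine_default.txt
get_prompt("agentic_strict")  # agentic preset -> prompts/agentic/strict.txt
get_prompt(None)              # agent with no custom prompt -> treated as "default"

# Hypothetical PG UUID -> PromptsRepository.get_for_rendering
get_prompt("2b0a6f3e-8f1d-4c7a-9a2b-1e5d4c3b2a10")

# Hypothetical legacy Mongo ObjectId -> get_by_legacy_id fallback
get_prompt("64a1b2c3d4e5f6a7b8c9d0e1")
```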
@@ -79,12 +100,9 @@ class StreamProcessor:
     def __init__(
         self, request_data: Dict[str, Any], decoded_token: Optional[Dict[str, Any]]
     ):
-        mongo = MongoDB.get_client()
-        self.db = mongo[settings.MONGO_DB_NAME]
-        self.agents_collection = self.db["agents"]
-        self.attachments_collection = self.db["attachments"]
-        self.prompts_collection = self.db["prompts"]
+        # Legacy attribute retained as None for any external callers that
+        # introspect the processor; all DB access uses per-op connections.
+        self.prompts_collection = None

         self.data = request_data
         self.decoded_token = decoded_token
         self.initial_user_id = (
@@ -112,6 +130,7 @@ class StreamProcessor:
         self._required_tool_actions: Optional[Dict[str, Set[Optional[str]]]] = None
         self.compressed_summary: Optional[str] = None
         self.compressed_summary_tokens: int = 0
+        self._agent_data: Optional[Dict[str, Any]] = None

     def initialize(self):
         """Initialize all required components for processing"""
@@ -243,17 +262,21 @@ class StreamProcessor:
         if not attachment_ids:
             return []
         attachments = []
-        for attachment_id in attachment_ids:
-            try:
-                attachment_doc = self.attachments_collection.find_one(
-                    {"_id": ObjectId(attachment_id), "user": user_id}
-                )
-                if attachment_doc:
-                    attachments.append(attachment_doc)
-            except Exception as e:
-                logger.error(
-                    f"Error retrieving attachment {attachment_id}: {e}", exc_info=True
-                )
+        try:
+            with db_readonly() as conn:
+                repo = AttachmentsRepository(conn)
+                for attachment_id in attachment_ids:
+                    try:
+                        attachment_doc = repo.get_any(str(attachment_id), user_id)
+                        if attachment_doc:
+                            attachments.append(attachment_doc)
+                    except Exception as e:
+                        logger.error(
+                            f"Error retrieving attachment {attachment_id}: {e}",
+                            exc_info=True,
+                        )
+        except Exception as e:
+            logger.error(f"Error opening attachments connection: {e}", exc_info=True)
         return attachments

     def _validate_and_set_model(self):
@@ -284,97 +307,127 @@ class StreamProcessor:
             self.model_id = get_default_model_id()

     def _get_agent_key(self, agent_id: Optional[str], user_id: Optional[str]) -> tuple:
-        """Get API key for agent with access control"""
+        """Get API key for agent with access control."""
         if not agent_id:
             return None, False, None
         try:
-            agent = self.agents_collection.find_one({"_id": ObjectId(agent_id)})
+            with db_readonly() as conn:
+                # Lookup without user scoping — access control is done
+                # against ``user_id`` / ``shared_with`` / ``shared`` flags
+                # right below, matching the legacy Mongo semantics.
+                repo = AgentsRepository(conn)
+                agent = None
+                if looks_like_uuid(str(agent_id)):
+                    result = conn.execute(
+                        sql_text(
+                            "SELECT * FROM agents WHERE id = CAST(:id AS uuid)"
+                        ),
+                        {"id": str(agent_id)},
+                    )
+                    row = result.fetchone()
+                    if row is not None:
+                        agent = row_to_dict(row)
+                if agent is None:
+                    agent = repo.get_by_legacy_id(str(agent_id))
             if agent is None:
                 raise Exception("Agent not found")
-            is_owner = agent.get("user") == user_id
-            is_shared_with_user = agent.get(
-                "shared_publicly", False
-            ) or user_id in agent.get("shared_with", [])
+            agent_owner = agent.get("user_id")
+            is_owner = agent_owner == user_id
+            is_shared_with_user = bool(agent.get("shared", False))

             if not (is_owner or is_shared_with_user):
                 raise Exception("Unauthorized access to the agent")
             if is_owner:
-                self.agents_collection.update_one(
-                    {"_id": ObjectId(agent_id)},
-                    {
-                        "$set": {
-                            "lastUsedAt": datetime.datetime.now(datetime.timezone.utc)
-                        }
-                    },
-                )
-            return str(agent["key"]), not is_owner, agent.get("shared_token")
+                now = datetime.datetime.now(datetime.timezone.utc)
+                try:
+                    with db_session() as conn:
+                        AgentsRepository(conn).update(
+                            str(agent["id"]), agent_owner,
+                            {"last_used_at": now},
+                        )
+                except Exception:
+                    logger.warning(
+                        "Failed to update last_used_at for agent",
+                        exc_info=True,
+                    )
+            return (
+                str(agent["key"]) if agent.get("key") else None,
+                not is_owner,
+                agent.get("shared_token"),
+            )
         except Exception as e:
             logger.error(f"Error in get_agent_key: {str(e)}", exc_info=True)
             raise

     def _get_data_from_api_key(self, api_key: str) -> Dict[str, Any]:
-        data = self.agents_collection.find_one({"key": api_key})
-        if not data:
-            raise Exception("Invalid API Key, please generate a new key", 401)
-        source = data.get("source")
-        if isinstance(source, DBRef):
-            source_doc = self.db.dereference(source)
-            if source_doc:
-                data["source"] = str(source_doc["_id"])
-                data["retriever"] = source_doc.get("retriever", data.get("retriever"))
-                data["chunks"] = source_doc.get("chunks", data.get("chunks"))
-            else:
-                data["source"] = None
-        elif source == "default":
-            data["source"] = "default"
-        else:
-            data["source"] = None
-
-        sources = data.get("sources", [])
-        if sources and isinstance(sources, list):
-            sources_list = []
-            for i, source_ref in enumerate(sources):
-                if source_ref == "default":
-                    processed_source = {
-                        "id": "default",
-                        "retriever": "classic",
-                        "chunks": data.get("chunks", "2"),
-                    }
-                    sources_list.append(processed_source)
-                elif isinstance(source_ref, DBRef):
-                    source_doc = self.db.dereference(source_ref)
-                    if source_doc:
-                        processed_source = {
-                            "id": str(source_doc["_id"]),
-                            "retriever": source_doc.get("retriever", "classic"),
-                            "chunks": source_doc.get("chunks", data.get("chunks", "2")),
-                        }
-                        sources_list.append(processed_source)
-            data["sources"] = sources_list
-        else:
-            data["sources"] = []
+        with db_readonly() as conn:
+            agent = AgentsRepository(conn).find_by_key(api_key)
+            if not agent:
+                raise Exception("Invalid API Key, please generate a new key", 401)
+            sources_repo = SourcesRepository(conn)
+            # The repo dict uses "user_id" — the streaming path expects
+            # a "user" key (legacy Mongo shape) for identity propagation.
+            data: Dict[str, Any] = dict(agent)
+            data["user"] = agent.get("user_id")
+
+            # Resolve the primary source row (if any) for retriever/chunks.
+            source_id = agent.get("source_id")
+            if source_id:
+                source_doc = sources_repo.get(str(source_id), agent.get("user_id"))
+                if source_doc:
+                    data["source"] = str(source_doc["id"])
+                    data["retriever"] = source_doc.get(
+                        "retriever", data.get("retriever")
+                    )
+                    data["chunks"] = source_doc.get("chunks", data.get("chunks"))
+                else:
+                    data["source"] = None
+            else:
+                data["source"] = None
+
+            sources_list = []
+            extra = agent.get("extra_source_ids") or []
+            if extra:
+                for sid in extra:
+                    source_doc = sources_repo.get(str(sid), agent.get("user_id"))
+                    if source_doc:
+                        sources_list.append(
+                            {
+                                "id": str(source_doc["id"]),
+                                "retriever": source_doc.get("retriever", "classic"),
+                                "chunks": source_doc.get(
+                                    "chunks", data.get("chunks", "2")
+                                ),
+                            }
+                        )
+            data["sources"] = sources_list
         data["default_model_id"] = data.get("default_model_id", "")

         return data

     def _configure_source(self):
-        """Configure the source based on agent data"""
-        api_key = self.data.get("api_key") or self.agent_key
-
-        if api_key:
-            agent_data = self._get_data_from_api_key(api_key)
+        """Configure the source based on agent data.
+
+        The literal string ``"default"`` is a placeholder meaning "no
+        ingested source" and is normalized to an empty source so that no
+        retrieval is attempted.
+        """
+        if self._agent_data:
+            agent_data = self._agent_data

             if agent_data.get("sources") and len(agent_data["sources"]) > 0:
                 source_ids = [
-                    source["id"] for source in agent_data["sources"] if source.get("id")
+                    source["id"]
+                    for source in agent_data["sources"]
+                    if source.get("id") and source["id"] != "default"
                 ]
                 if source_ids:
                     self.source = {"active_docs": source_ids}
                 else:
                     self.source = {}
-                self.all_sources = agent_data["sources"]
-            elif agent_data.get("source"):
+                self.all_sources = [
+                    s for s in agent_data["sources"] if s.get("id") != "default"
+                ]
+            elif agent_data.get("source") and agent_data["source"] != "default":
                 self.source = {"active_docs": agent_data["source"]}
                 self.all_sources = [
                     {
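A recurring pattern in this migration is adapting a PG repository row to the legacy Mongo dict shape the rest of the pipeline still expects. A minimal sketch of the adaptation `_get_data_from_api_key` performs; the key names come from the hunk above, while the helper name and sample values are hypothetical.

```python
def adapt_agent_row(agent: dict) -> dict:
    """Map a PG agents row (user_id / source_id columns) onto the legacy shape."""
    data = dict(agent)
    data["user"] = agent.get("user_id")   # identity propagation expects "user"
    data.setdefault("source", None)       # filled in from the sources repo
    data.setdefault("sources", [])        # extra_source_ids resolved separately
    data["default_model_id"] = data.get("default_model_id", "")
    return data


row = {"id": "uuid-1", "user_id": "auth0|42", "key": "dg-example-key"}
legacy = adapt_agent_row(row)
assert legacy["user"] == "auth0|42"
```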
@@ -387,11 +440,24 @@ class StreamProcessor:
             self.all_sources = []
             return
         if "active_docs" in self.data:
-            self.source = {"active_docs": self.data["active_docs"]}
+            active_docs = self.data["active_docs"]
+            if active_docs and active_docs != "default":
+                self.source = {"active_docs": active_docs}
+            else:
+                self.source = {}
             return
         self.source = {}
         self.all_sources = []

+    def _has_active_docs(self) -> bool:
+        """Return True if a real document source is configured for retrieval."""
+        active_docs = self.source.get("active_docs") if self.source else None
+        if not active_docs:
+            return False
+        if active_docs == "default":
+            return False
+        return True
+
     def _resolve_agent_id(self) -> Optional[str]:
         """Resolve agent_id from request, then fall back to conversation context."""
         request_agent_id = self.data.get("agent_id")
@@ -433,48 +499,45 @@ class StreamProcessor:
         effective_key = self.data.get("api_key") or self.agent_key

         if effective_key:
-            data_key = self._get_data_from_api_key(effective_key)
-            if data_key.get("_id"):
-                self.agent_id = str(data_key.get("_id"))
+            self._agent_data = self._get_data_from_api_key(effective_key)
+            if self._agent_data.get("_id"):
+                self.agent_id = str(self._agent_data.get("_id"))

             self.agent_config.update(
                 {
-                    "prompt_id": data_key.get("prompt_id", "default"),
-                    "agent_type": data_key.get("agent_type", settings.AGENT_NAME),
+                    "prompt_id": self._agent_data.get("prompt_id", "default"),
+                    "agent_type": self._agent_data.get("agent_type", settings.AGENT_NAME),
                     "user_api_key": effective_key,
-                    "json_schema": data_key.get("json_schema"),
-                    "default_model_id": data_key.get("default_model_id", ""),
-                    "models": data_key.get("models", []),
+                    "json_schema": self._agent_data.get("json_schema"),
+                    "default_model_id": self._agent_data.get("default_model_id", ""),
+                    "models": self._agent_data.get("models", []),
+                    "allow_system_prompt_override": self._agent_data.get(
+                        "allow_system_prompt_override", False
+                    ),
                 }
             )

             # Set identity context
             if self.data.get("api_key"):
                 # External API key: use the key owner's identity
-                self.initial_user_id = data_key.get("user")
-                self.decoded_token = {"sub": data_key.get("user")}
+                self.initial_user_id = self._agent_data.get("user")
+                self.decoded_token = {"sub": self._agent_data.get("user")}
             elif self.is_shared_usage:
                 # Shared agent: keep the caller's identity
                 pass
             else:
                 # Owner using their own agent
-                self.decoded_token = {"sub": data_key.get("user")}
+                self.decoded_token = {"sub": self._agent_data.get("user")}

-            if data_key.get("source"):
-                self.source = {"active_docs": data_key["source"]}
-            if data_key.get("workflow"):
-                self.agent_config["workflow"] = data_key["workflow"]
-                self.agent_config["workflow_owner"] = data_key.get("user")
-            if data_key.get("retriever"):
-                self.retriever_config["retriever_name"] = data_key["retriever"]
-            if data_key.get("chunks") is not None:
-                try:
-                    self.retriever_config["chunks"] = int(data_key["chunks"])
-                except (ValueError, TypeError):
-                    logger.warning(
-                        f"Invalid chunks value: {data_key['chunks']}, using default value 2"
-                    )
-                    self.retriever_config["chunks"] = 2
+            # PG row exposes the workflow as ``workflow_id`` (UUID column);
+            # legacy Mongo shape used the key ``workflow``. Accept either so
+            # API-key-invoked workflow agents bind correctly downstream.
+            wf_ref = self._agent_data.get("workflow") or self._agent_data.get(
+                "workflow_id"
+            )
+            if wf_ref:
+                self.agent_config["workflow"] = str(wf_ref)
+                self.agent_config["workflow_owner"] = self._agent_data.get("user")
         else:
             # No API key — default/workflow configuration
             agent_type = settings.AGENT_NAME
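The `wf_ref` dance exists because the same agent dict can come from either store during the rollout. A tiny sketch of the key-precedence rule, with hypothetical values and a hypothetical helper name:

```python
def resolve_workflow_ref(agent_data: dict):
    """Prefer the legacy 'workflow' key, fall back to the PG 'workflow_id' column."""
    wf_ref = agent_data.get("workflow") or agent_data.get("workflow_id")
    return str(wf_ref) if wf_ref else None


assert resolve_workflow_ref({"workflow": "wf-legacy"}) == "wf-legacy"
assert resolve_workflow_ref(
    {"workflow_id": "0b8c2f74-1111-2222-3333-444455556666"}
).startswith("0b8c2f74")
assert resolve_workflow_ref({}) is None
```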
@@ -497,14 +560,45 @@ class StreamProcessor:
             )

     def _configure_retriever(self):
+        """Assemble retriever config with precedence: request > agent > default."""
         doc_token_limit = calculate_doc_token_budget(model_id=self.model_id)

+        # Start with defaults
+        retriever_name = "classic"
+        chunks = 2
+
+        # Layer agent-level config (if present)
+        if self._agent_data:
+            if self._agent_data.get("retriever"):
+                retriever_name = self._agent_data["retriever"]
+            if self._agent_data.get("chunks") is not None:
+                try:
+                    chunks = int(self._agent_data["chunks"])
+                except (ValueError, TypeError):
+                    logger.warning(
+                        f"Invalid agent chunks value: {self._agent_data['chunks']}, "
+                        "using default value 2"
+                    )
+
+        # Explicit request values win over agent config
+        if "retriever" in self.data:
+            retriever_name = self.data["retriever"]
+        if "chunks" in self.data:
+            try:
+                chunks = int(self.data["chunks"])
+            except (ValueError, TypeError):
+                logger.warning(
+                    f"Invalid request chunks value: {self.data['chunks']}, "
+                    "using default value 2"
+                )
+
         self.retriever_config = {
-            "retriever_name": self.data.get("retriever", "classic"),
-            "chunks": int(self.data.get("chunks", 2)),
+            "retriever_name": retriever_name,
+            "chunks": chunks,
             "doc_token_limit": doc_token_limit,
         }

+        # isNoneDoc without an API key forces no retrieval
         api_key = self.data.get("api_key") or self.agent_key
         if not api_key and "isNoneDoc" in self.data and self.data["isNoneDoc"]:
             self.retriever_config["chunks"] = 0
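The new `_configure_retriever` resolves `retriever` and `chunks` with request > agent > default precedence. A compact sketch of the same resolution as a pure function; the function name and the `"some_retriever"` value are placeholders.

```python
from typing import Optional


def resolve_retriever_config(request_data: dict, agent_data: Optional[dict]) -> dict:
    retriever, chunks = "classic", 2                 # defaults
    if agent_data:                                   # agent-level config
        retriever = agent_data.get("retriever") or retriever
        try:
            chunks = int(agent_data["chunks"])
        except (KeyError, TypeError, ValueError):
            pass
    if "retriever" in request_data:                  # explicit request values win
        retriever = request_data["retriever"]
    if "chunks" in request_data:
        try:
            chunks = int(request_data["chunks"])
        except (TypeError, ValueError):
            pass
    return {"retriever_name": retriever, "chunks": chunks}


assert resolve_retriever_config({"chunks": "4"}, {"retriever": "some_retriever"}) == {
    "retriever_name": "some_retriever",
    "chunks": 4,
}
```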
@@ -528,6 +622,9 @@ class StreamProcessor:
         if self.data.get("isNoneDoc", False) and not self.agent_id:
             logger.info("Pre-fetch skipped: isNoneDoc=True")
             return None, None
+        if not self._has_active_docs():
+            logger.info("Pre-fetch skipped: no active docs configured")
+            return None, None
         try:
             retriever = self.create_retriever()
             logger.info(
@@ -574,12 +671,9 @@ class StreamProcessor:
         filtering_enabled = required_tool_actions is not None

         try:
-            user_tools_collection = self.db["user_tools"]
             user_id = self.initial_user_id or "local"
-            user_tools = list(
-                user_tools_collection.find({"user": user_id, "status": True})
-            )
+            with db_readonly() as conn:
+                user_tools = UserToolsRepository(conn).list_active_for_user(user_id)

             if not user_tools:
                 return None
@@ -910,15 +1004,23 @@ class StreamProcessor:
         raw_prompt = get_prompt(prompt_id, self.prompts_collection)
         self._prompt_content = raw_prompt

-        rendered_prompt = self.prompt_renderer.render_prompt(
-            prompt_content=raw_prompt,
-            user_id=self.initial_user_id,
-            request_id=self.data.get("request_id"),
-            passthrough_data=self.data.get("passthrough"),
-            docs=docs,
-            docs_together=docs_together,
-            tools_data=tools_data,
-        )
+        # Allow API callers to override the system prompt when the agent
+        # has opted in via allow_system_prompt_override.
+        if (
+            self.agent_config.get("allow_system_prompt_override", False)
+            and self.data.get("system_prompt_override")
+        ):
+            rendered_prompt = self.data["system_prompt_override"]
+        else:
+            rendered_prompt = self.prompt_renderer.render_prompt(
+                prompt_content=raw_prompt,
+                user_id=self.initial_user_id,
+                request_id=self.data.get("request_id"),
+                passthrough_data=self.data.get("passthrough"),
+                docs=docs,
+                docs_together=docs_together,
+                tools_data=tools_data,
+            )

         provider = (
             get_provider_from_model_id(self.model_id)
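The override gate means an API caller can replace the rendered system prompt only when the agent row has opted in. A hedged sketch of the decision; only the two key names (`allow_system_prompt_override`, `system_prompt_override`) come from the hunk above, the payload values are illustrative and not a documented request schema.

```python
agent_config = {"allow_system_prompt_override": True}

# Illustrative request body fields as consumed by the gate above.
request_data = {
    "question": "Summarise the release notes",
    "api_key": "dg-example-key",  # hypothetical agent key
    "system_prompt_override": "You are a terse release-notes assistant.",
}

use_override = bool(
    agent_config.get("allow_system_prompt_override", False)
    and request_data.get("system_prompt_override")
)
assert use_override is True
```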
@@ -932,8 +1034,10 @@ class StreamProcessor:
         from application.llm.handlers.handler_creator import LLMHandlerCreator
         from application.agents.tool_executor import ToolExecutor

-        # Compute backup models: agent's configured models minus the active one
-        agent_models = self.agent_config.get("models", [])
+        # Compute backup models: agent's configured models minus the active one.
+        # PG agents may carry an explicit ``models: NULL`` (not absent), so
+        # ``.get("models", [])`` isn't enough — coerce None → [].
+        agent_models = self.agent_config.get("models") or []
         backup_models = [m for m in agent_models if m != self.model_id]

         llm = LLMCreator.create_llm(
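The `or []` matters because `.get()` with a default only covers a missing key, not an explicit `None` value, which is exactly what a nullable PG column round-trips as:

```python
agent_config = {"models": None}  # nullable column came back as NULL

assert agent_config.get("models", []) is None    # default is NOT applied
assert (agent_config.get("models") or []) == []  # coerces None to []
```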
@@ -1,12 +1,10 @@
 import base64
-import datetime
 import html
 import json
 import uuid
 from urllib.parse import urlencode

-
-from bson.objectid import ObjectId
 from flask import (
     Blueprint,
     current_app,
@@ -17,22 +15,18 @@ from flask import (
 from flask_restx import fields, Namespace, Resource


+from application.api import api
 from application.api.user.tasks import (
     ingest_connector_task,
 )
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
-from application.api import api
-
-
 from application.parser.connectors.connector_creator import ConnectorCreator
+from application.storage.db.repositories.connector_sessions import (
+    ConnectorSessionsRepository,
+)
+from application.storage.db.repositories.sources import SourcesRepository
+from application.storage.db.session import db_readonly, db_session


-mongo = MongoDB.get_client()
-db = mongo[settings.MONGO_DB_NAME]
-sources_collection = db["sources"]
-sessions_collection = db["connector_sessions"]
-
 connector = Blueprint("connector", __name__)
 connectors_ns = Namespace("connectors", description="Connector operations", path="/")
 api.add_namespace(connectors_ns)
@@ -68,16 +62,14 @@ class ConnectorAuth(Resource):
             return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
         user_id = decoded_token.get('sub')

-        now = datetime.datetime.now(datetime.timezone.utc)
-        result = sessions_collection.insert_one({
-            "provider": provider,
-            "user": user_id,
-            "status": "pending",
-            "created_at": now
-        })
+        with db_session() as conn:
+            session_row = ConnectorSessionsRepository(conn).upsert(
+                user_id, provider, status="pending",
+            )
+            session_pg_id = str(session_row["id"])

         state_dict = {
             "provider": provider,
-            "object_id": str(result.inserted_id)
+            "object_id": session_pg_id,
         }
         state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
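The OAuth `state` round-trip is just base64-encoded JSON carrying the provider and the session row id. A sketch of both directions, mirroring the encode shown above; the provider name and UUID are illustrative.

```python
import base64
import json

# Encode (as in ConnectorAuth above).
state = base64.urlsafe_b64encode(
    json.dumps(
        {"provider": "google_drive", "object_id": "7c9e6679-7425-40de-944b-e07fc1f90ae7"}
    ).encode()
).decode()

# Decode (as the callback presumably does before updating the session row).
state_dict = json.loads(base64.urlsafe_b64decode(state))
assert state_dict["provider"] == "google_drive"
```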
@@ -160,17 +152,25 @@ class ConnectorsCallback(Resource):

         sanitized_token_info = auth.sanitize_token_info(token_info)

-        sessions_collection.find_one_and_update(
-            {"_id": ObjectId(state_object_id), "provider": provider},
-            {
-                "$set": {
-                    "session_token": session_token,
-                    "token_info": sanitized_token_info,
-                    "user_email": user_email,
-                    "status": "authorized"
-                }
-            }
-        )
+        # ``object_id`` in the OAuth state is the PG session row
+        # UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
+        # issued state). Try UUID update first; fall back to
+        # legacy id path.
+        patch = {
+            "session_token": session_token,
+            "token_info": sanitized_token_info,
+            "user_email": user_email,
+            "status": "authorized",
+        }
+        with db_session() as conn:
+            repo = ConnectorSessionsRepository(conn)
+            if state_object_id:
+                value = str(state_object_id)
+                updated = False
+                if len(value) == 36 and "-" in value:
+                    updated = repo.update(value, patch)
+                if not updated:
+                    repo.update_by_legacy_id(value, patch)

         # Redirect to success page with session token and user email
         return redirect(build_callback_redirect({
@@ -222,8 +222,11 @@ class ConnectorFiles(Resource):
         if not decoded_token:
             return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
         user = decoded_token.get('sub')
-        session = sessions_collection.find_one({"session_token": session_token, "user": user})
-        if not session:
+        with db_readonly() as conn:
+            session = ConnectorSessionsRepository(conn).get_by_session_token(
+                session_token,
+            )
+        if not session or session.get("user_id") != user:
             return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)

         loader = ConnectorCreator.create_connector(provider, session_token)
@@ -288,8 +291,11 @@ class ConnectorValidateSession(Resource):
             return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
         user = decoded_token.get('sub')

-        session = sessions_collection.find_one({"session_token": session_token, "user": user})
-        if not session or "token_info" not in session:
+        with db_readonly() as conn:
+            session = ConnectorSessionsRepository(conn).get_by_session_token(
+                session_token,
+            )
+        if not session or session.get("user_id") != user or not session.get("token_info"):
             return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)

         token_info = session["token_info"]
@@ -300,10 +306,11 @@ class ConnectorValidateSession(Resource):
         try:
             refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
             sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
-            sessions_collection.update_one(
-                {"session_token": session_token},
-                {"$set": {"token_info": sanitized_token_info}}
-            )
+            with db_session() as conn:
+                repo = ConnectorSessionsRepository(conn)
+                row = repo.get_by_session_token(session_token)
+                if row:
+                    repo.update(str(row["id"]), {"token_info": sanitized_token_info})
             token_info = sanitized_token_info
             is_expired = False
         except Exception as refresh_error:
@@ -347,7 +354,10 @@ class ConnectorDisconnect(Resource):

             if session_token:
-                sessions_collection.delete_one({"session_token": session_token})
+                with db_session() as conn:
+                    ConnectorSessionsRepository(conn).delete_by_session_token(
+                        session_token,
+                    )

             return make_response(jsonify({"success": True}), 200)
         except Exception as e:
@@ -385,7 +395,9 @@ class ConnectorSync(Resource):
                 }),
                 400
             )
-        source = sources_collection.find_one({"_id": ObjectId(source_id)})
+        user_id = decoded_token.get('sub')
+        with db_readonly() as conn:
+            source = SourcesRepository(conn).get_any(source_id, user_id)
         if not source:
             return make_response(
                 jsonify({
@@ -395,22 +407,16 @@ class ConnectorSync(Resource):
                 404
             )

-        if source.get('user') != decoded_token.get('sub'):
-            return make_response(
-                jsonify({
-                    "success": False,
-                    "error": "Unauthorized access to source"
-                }),
-                403
-            )
-
-        remote_data = {}
-        try:
-            if source.get('remote_data'):
-                remote_data = json.loads(source.get('remote_data'))
-        except json.JSONDecodeError:
-            current_app.logger.error(f"Invalid remote_data format for source {source_id}")
-            remote_data = {}
+        # ``get_any`` already scopes by ``user_id``; an extra guard
+        # here would be dead code.
+
+        remote_data = source.get('remote_data') or {}
+        if isinstance(remote_data, str):
+            try:
+                remote_data = json.loads(remote_data)
+            except json.JSONDecodeError:
+                current_app.logger.error(f"Invalid remote_data format for source {source_id}")
+                remote_data = {}

         source_type = remote_data.get('provider')
         if not source_type:
@@ -438,7 +444,7 @@ class ConnectorSync(Resource):
             recursive=recursive,
             retriever=source.get('retriever', 'classic'),
             operation_mode="sync",
-            doc_id=source_id,
+            doc_id=str(source.get('id') or source_id),
             sync_frequency=source.get('sync_frequency', 'never')
         )
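`remote_data` can now arrive either as a JSONB dict (PG) or as the legacy JSON string (Mongo backfill), so the sync handler normalises before reading `provider`. The same normalisation as a standalone sketch with a hypothetical helper name:

```python
import json


def normalise_remote_data(value):
    """Accept a dict (JSONB), a JSON string (legacy), or nothing."""
    remote_data = value or {}
    if isinstance(remote_data, str):
        try:
            remote_data = json.loads(remote_data)
        except json.JSONDecodeError:
            remote_data = {}
    return remote_data


assert normalise_remote_data('{"provider": "github"}') == {"provider": "github"}
assert normalise_remote_data({"provider": "github"}) == {"provider": "github"}
assert normalise_remote_data(None) == {}
assert normalise_remote_data("not json") == {}
```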
@@ -3,18 +3,16 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
from flask import Blueprint, request, send_from_directory, jsonify
|
from flask import Blueprint, request, send_from_directory, jsonify
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
from bson.objectid import ObjectId
|
|
||||||
import logging
|
import logging
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.sources import SourcesRepository
|
||||||
|
from application.storage.db.session import db_session
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
conversations_collection = db["conversations"]
|
|
||||||
sources_collection = db["sources"]
|
|
||||||
|
|
||||||
current_dir = os.path.dirname(
|
current_dir = os.path.dirname(
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -62,7 +60,7 @@ def upload_index_files():
 job_name = request.form["name"]
 tokens = request.form["tokens"]
 retriever = request.form["retriever"]
-id = request.form["id"]
+source_id = request.form["id"]
 type = request.form["type"]
 remote_data = request.form["remote_data"] if "remote_data" in request.form else None
 sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
@@ -89,7 +87,7 @@ def upload_index_files():
 file_name_map = None

 storage = StorageCreator.get_storage()
-index_base_path = f"indexes/{id}"
+index_base_path = f"indexes/{source_id}"

 if settings.VECTOR_STORE == "faiss":
 if "file_faiss" not in request.files:
@@ -111,46 +109,48 @@ def upload_index_files():
 storage.save_file(file_faiss, faiss_storage_path)
 storage.save_file(file_pkl, pkl_storage_path)

+now = datetime.datetime.now(datetime.timezone.utc)
+update_fields = {
+"name": job_name,
+"type": type,
+"language": job_name,
+"date": now,
+"model": settings.EMBEDDINGS_NAME,
+"tokens": tokens,
+"retriever": retriever,
+"remote_data": remote_data,
+"sync_frequency": sync_frequency,
+"file_path": file_path,
+"directory_structure": directory_structure,
+}
+if file_name_map is not None:
+update_fields["file_name_map"] = file_name_map

-existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
-if existing_entry:
-update_fields = {
-"user": user,
-"name": job_name,
-"language": job_name,
-"date": datetime.datetime.now(),
-"model": settings.EMBEDDINGS_NAME,
-"type": type,
-"tokens": tokens,
-"retriever": retriever,
-"remote_data": remote_data,
-"sync_frequency": sync_frequency,
-"file_path": file_path,
-"directory_structure": directory_structure,
-}
-if file_name_map is not None:
-update_fields["file_name_map"] = file_name_map
-sources_collection.update_one(
-{"_id": ObjectId(id)},
-{"$set": update_fields},
-)
-else:
-insert_doc = {
-"_id": ObjectId(id),
-"user": user,
-"name": job_name,
-"language": job_name,
-"date": datetime.datetime.now(),
-"model": settings.EMBEDDINGS_NAME,
-"type": type,
-"tokens": tokens,
-"retriever": retriever,
-"remote_data": remote_data,
-"sync_frequency": sync_frequency,
-"file_path": file_path,
-"directory_structure": directory_structure,
-}
-if file_name_map is not None:
-insert_doc["file_name_map"] = file_name_map
-sources_collection.insert_one(insert_doc)
+with db_session() as conn:
+repo = SourcesRepository(conn)
+existing = None
+if looks_like_uuid(source_id):
+existing = repo.get(source_id, user)
+if existing is None:
+existing = repo.get_by_legacy_id(source_id, user)
+if existing is not None:
+repo.update(str(existing["id"]), user, update_fields)
+else:
+repo.create(
+job_name,
+source_id=source_id if looks_like_uuid(source_id) else None,
+user_id=user,
+type=type,
+tokens=tokens,
+retriever=retriever,
+remote_data=remote_data,
+sync_frequency=sync_frequency,
+file_path=file_path,
+directory_structure=directory_structure,
+file_name_map=file_name_map,
+language=job_name,
+model=settings.EMBEDDINGS_NAME,
+date=now,
+legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
+)
 return {"status": "ok"}
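The upload handler above resolves the incoming form `id` against both the new UUID primary key and the legacy Mongo id before choosing between update and create. A condensed sketch of that upsert flow, assuming the `SourcesRepository` methods used in the hunk (`get`, `get_by_legacy_id`, `update`, `create`); the helper name is illustrative:

```python
from application.storage.db.base_repository import looks_like_uuid


def upsert_source(repo, source_id: str, user: str, name: str, update_fields: dict):
    """Update the matching row whether the caller sent a UUID or a legacy Mongo id."""
    existing = None
    if looks_like_uuid(source_id):
        existing = repo.get(source_id, user)
    if existing is None:
        existing = repo.get_by_legacy_id(source_id, user)
    if existing is not None:
        repo.update(str(existing["id"]), user, update_fields)
    else:
        repo.create(
            name,
            source_id=source_id if looks_like_uuid(source_id) else None,
            user_id=user,
            legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
            # remaining metadata (type, tokens, retriever, ...) passed as in the hunk
        )
```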
@@ -3,27 +3,50 @@ Agent folders management routes.
 Provides virtual folder organization for agents (Google Drive-like structure).
 """

-import datetime
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import Namespace, Resource, fields
+from sqlalchemy import text as _sql_text

 from application.api import api
-from application.api.user.base import (
-agent_folders_collection,
-agents_collection,
-)
+from application.storage.db.base_repository import looks_like_uuid
+from application.storage.db.repositories.agent_folders import AgentFoldersRepository
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.session import db_readonly, db_session


 agents_folders_ns = Namespace(
 "agents_folders", description="Agent folder management", path="/api/agents/folders"
 )


+def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
+"""Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
+if not folder_id:
+return None
+if looks_like_uuid(folder_id):
+row = repo.get(folder_id, user)
+if row is not None:
+return row
+return repo.get_by_legacy_id(folder_id, user)
+
+
 def _folder_error_response(message: str, err: Exception):
 current_app.logger.error(f"{message}: {err}", exc_info=True)
 return make_response(jsonify({"success": False, "message": message}), 400)


+def _serialize_folder(f: dict) -> dict:
+created_at = f.get("created_at")
+updated_at = f.get("updated_at")
+return {
+"id": str(f["id"]),
+"name": f.get("name"),
+"parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
+"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
+"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
+}
+
+
 @agents_folders_ns.route("/")
 class AgentFolders(Resource):
 @api.doc(description="Get all folders for the user")
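A hypothetical caller of the `_resolve_folder_id` helper added above, showing the intended lookup order (UUID first, then legacy Mongo id). `AgentFoldersRepository` and `db_readonly` are the imports introduced in this hunk; the wrapper function itself is illustrative:

```python
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
from application.storage.db.session import db_readonly


def folder_exists(folder_id: str, user: str) -> bool:
    # _resolve_folder_id is the module-level helper defined in the hunk above.
    with db_readonly() as conn:
        repo = AgentFoldersRepository(conn)
        return _resolve_folder_id(repo, folder_id, user) is not None
```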
@@ -33,17 +56,9 @@ class AgentFolders(Resource):
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
 try:
-folders = list(agent_folders_collection.find({"user": user}))
-result = [
-{
-"id": str(f["_id"]),
-"name": f["name"],
-"parent_id": f.get("parent_id"),
-"created_at": f.get("created_at", "").isoformat() if f.get("created_at") else None,
-"updated_at": f.get("updated_at", "").isoformat() if f.get("updated_at") else None,
-}
-for f in folders
-]
+with db_readonly() as conn:
+folders = AgentFoldersRepository(conn).list_for_user(user)
+result = [_serialize_folder(f) for f in folders]
 return make_response(jsonify({"folders": result}), 200)
 except Exception as err:
 return _folder_error_response("Failed to fetch folders", err)
@@ -67,24 +82,34 @@ class AgentFolders(Resource):
 if not data or not data.get("name"):
 return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)

-parent_id = data.get("parent_id")
-if parent_id:
-parent = agent_folders_collection.find_one({"_id": ObjectId(parent_id), "user": user})
-if not parent:
-return make_response(jsonify({"success": False, "message": "Parent folder not found"}), 404)
+parent_id_input = data.get("parent_id")
+description = data.get("description")

 try:
-now = datetime.datetime.now(datetime.timezone.utc)
-folder = {
-"user": user,
-"name": data["name"],
-"parent_id": parent_id,
-"created_at": now,
-"updated_at": now,
-}
-result = agent_folders_collection.insert_one(folder)
+with db_session() as conn:
+repo = AgentFoldersRepository(conn)
+pg_parent_id = None
+if parent_id_input:
+parent = _resolve_folder_id(repo, parent_id_input, user)
+if not parent:
+return make_response(
+jsonify({"success": False, "message": "Parent folder not found"}),
+404,
+)
+pg_parent_id = str(parent["id"])
+folder = repo.create(
+user, data["name"],
+description=description,
+parent_id=pg_parent_id,
+)
 return make_response(
-jsonify({"id": str(result.inserted_id), "name": data["name"], "parent_id": parent_id}),
+jsonify(
+{
+"id": str(folder["id"]),
+"name": folder["name"],
+"parent_id": pg_parent_id,
+}
+),
 201,
 )
 except Exception as err:
@@ -100,26 +125,51 @@ class AgentFolder(Resource):
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
 try:
-folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
-if not folder:
-return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
+with db_readonly() as conn:
+folders_repo = AgentFoldersRepository(conn)
+folder = _resolve_folder_id(folders_repo, folder_id, user)
+if not folder:
+return make_response(
+jsonify({"success": False, "message": "Folder not found"}),
+404,
+)
+pg_folder_id = str(folder["id"])

-agents = list(agents_collection.find({"user": user, "folder_id": folder_id}))
-agents_list = [
-{"id": str(a["_id"]), "name": a["name"], "description": a.get("description", "")}
-for a in agents
-]
-subfolders = list(agent_folders_collection.find({"user": user, "parent_id": folder_id}))
-subfolders_list = [{"id": str(sf["_id"]), "name": sf["name"]} for sf in subfolders]
+agents_rows = conn.execute(
+_sql_text(
+"SELECT id, name, description FROM agents "
+"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
+"ORDER BY created_at DESC"
+),
+{"user_id": user, "fid": pg_folder_id},
+).fetchall()
+agents_list = [
+{
+"id": str(row._mapping["id"]),
+"name": row._mapping["name"],
+"description": row._mapping.get("description", "") or "",
+}
+for row in agents_rows
+]
+
+subfolders = folders_repo.list_children(pg_folder_id, user)
+subfolders_list = [
+{"id": str(sf["id"]), "name": sf["name"]}
+for sf in subfolders
+]
+
 return make_response(
-jsonify({
-"id": str(folder["_id"]),
-"name": folder["name"],
-"parent_id": folder.get("parent_id"),
-"agents": agents_list,
-"subfolders": subfolders_list,
-}),
+jsonify(
+{
+"id": pg_folder_id,
+"name": folder["name"],
+"parent_id": (
+str(folder["parent_id"]) if folder.get("parent_id") else None
+),
+"agents": agents_list,
+"subfolders": subfolders_list,
+}
+),
 200,
 )
 except Exception as err:
@@ -136,19 +186,57 @@ class AgentFolder(Resource):
 return make_response(jsonify({"success": False, "message": "No data provided"}), 400)

 try:
-update_fields = {"updated_at": datetime.datetime.now(datetime.timezone.utc)}
-if "name" in data:
-update_fields["name"] = data["name"]
-if "parent_id" in data:
-if data["parent_id"] == folder_id:
-return make_response(jsonify({"success": False, "message": "Cannot set folder as its own parent"}), 400)
-update_fields["parent_id"] = data["parent_id"]
-result = agent_folders_collection.update_one(
-{"_id": ObjectId(folder_id), "user": user}, {"$set": update_fields}
-)
-if result.matched_count == 0:
-return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
+with db_session() as conn:
+repo = AgentFoldersRepository(conn)
+folder = _resolve_folder_id(repo, folder_id, user)
+if not folder:
+return make_response(
+jsonify({"success": False, "message": "Folder not found"}),
+404,
+)
+pg_folder_id = str(folder["id"])
+
+update_fields: dict = {}
+if "name" in data:
+update_fields["name"] = data["name"]
+if "description" in data:
+update_fields["description"] = data["description"]
+if "parent_id" in data:
+parent_input = data.get("parent_id")
+if parent_input:
+if parent_input == folder_id or parent_input == pg_folder_id:
+return make_response(
+jsonify(
+{
+"success": False,
+"message": "Cannot set folder as its own parent",
+}
+),
+400,
+)
+parent = _resolve_folder_id(repo, parent_input, user)
+if not parent:
+return make_response(
+jsonify({"success": False, "message": "Parent folder not found"}),
+404,
+)
+if str(parent["id"]) == pg_folder_id:
+return make_response(
+jsonify(
+{
+"success": False,
+"message": "Cannot set folder as its own parent",
+}
+),
+400,
+)
+update_fields["parent_id"] = str(parent["id"])
+else:
+update_fields["parent_id"] = None
+
+if update_fields:
+repo.update(pg_folder_id, user, update_fields)
+
 return make_response(jsonify({"success": True}), 200)
 except Exception as err:
 return _folder_error_response("Failed to update folder", err)
@@ -160,15 +248,24 @@ class AgentFolder(Resource):
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
 try:
-agents_collection.update_many(
-{"user": user, "folder_id": folder_id}, {"$unset": {"folder_id": ""}}
-)
-agent_folders_collection.update_many(
-{"user": user, "parent_id": folder_id}, {"$unset": {"parent_id": ""}}
-)
-result = agent_folders_collection.delete_one({"_id": ObjectId(folder_id), "user": user})
-if result.deleted_count == 0:
-return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
+with db_session() as conn:
+repo = AgentFoldersRepository(conn)
+folder = _resolve_folder_id(repo, folder_id, user)
+if not folder:
+return make_response(
+jsonify({"success": False, "message": "Folder not found"}),
+404,
+)
+pg_folder_id = str(folder["id"])
+# Clear folder assignments from agents; self-FK
+# ``ON DELETE SET NULL`` handles child folders.
+AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
+deleted = repo.delete(pg_folder_id, user)
+if not deleted:
+return make_response(
+jsonify({"success": False, "message": "Folder not found"}),
+404,
+)
 return make_response(jsonify({"success": True}), 200)
 except Exception as err:
 return _folder_error_response("Failed to delete folder", err)
@@ -195,26 +292,29 @@ class MoveAgentToFolder(Resource):
 if not data or not data.get("agent_id"):
 return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)

-agent_id = data["agent_id"]
-folder_id = data.get("folder_id")
+agent_id_input = data["agent_id"]
+folder_id_input = data.get("folder_id")

 try:
-agent = agents_collection.find_one({"_id": ObjectId(agent_id), "user": user})
-if not agent:
-return make_response(jsonify({"success": False, "message": "Agent not found"}), 404)
-if folder_id:
-folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
-if not folder:
-return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
-agents_collection.update_one(
-{"_id": ObjectId(agent_id)}, {"$set": {"folder_id": folder_id}}
-)
-else:
-agents_collection.update_one(
-{"_id": ObjectId(agent_id)}, {"$unset": {"folder_id": ""}}
-)
+with db_session() as conn:
+agents_repo = AgentsRepository(conn)
+agent = agents_repo.get_any(agent_id_input, user)
+if not agent:
+return make_response(
+jsonify({"success": False, "message": "Agent not found"}),
+404,
+)
+pg_folder_id = None
+if folder_id_input:
+folders_repo = AgentFoldersRepository(conn)
+folder = _resolve_folder_id(folders_repo, folder_id_input, user)
+if not folder:
+return make_response(
+jsonify({"success": False, "message": "Folder not found"}),
+404,
+)
+pg_folder_id = str(folder["id"])
+agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
 return make_response(jsonify({"success": True}), 200)
 except Exception as err:
 return _folder_error_response("Failed to move agent", err)
@@ -242,25 +342,25 @@ class BulkMoveAgents(Resource):
 return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)

 agent_ids = data["agent_ids"]
-folder_id = data.get("folder_id")
+folder_id_input = data.get("folder_id")

 try:
-if folder_id:
-folder = agent_folders_collection.find_one({"_id": ObjectId(folder_id), "user": user})
-if not folder:
-return make_response(jsonify({"success": False, "message": "Folder not found"}), 404)
-object_ids = [ObjectId(aid) for aid in agent_ids]
-if folder_id:
-agents_collection.update_many(
-{"_id": {"$in": object_ids}, "user": user},
-{"$set": {"folder_id": folder_id}},
-)
-else:
-agents_collection.update_many(
-{"_id": {"$in": object_ids}, "user": user},
-{"$unset": {"folder_id": ""}},
-)
+with db_session() as conn:
+agents_repo = AgentsRepository(conn)
+pg_folder_id = None
+if folder_id_input:
+folders_repo = AgentFoldersRepository(conn)
+folder = _resolve_folder_id(folders_repo, folder_id_input, user)
+if not folder:
+return make_response(
+jsonify({"success": False, "message": "Folder not found"}),
+404,
+)
+pg_folder_id = str(folder["id"])
+for agent_id_input in agent_ids:
+agent = agents_repo.get_any(agent_id_input, user)
+if agent is not None:
+agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
 return make_response(jsonify({"success": True}), 200)
 except Exception as err:
 return _folder_error_response("Failed to move agents", err)
File diff suppressed because it is too large
@@ -3,21 +3,17 @@
 import datetime
 import secrets

-from bson import DBRef
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
+from sqlalchemy import text as _sql_text

 from application.api import api
 from application.core.settings import settings
-from application.api.user.base import (
-agents_collection,
-db,
-ensure_user_doc,
-resolve_tool_details,
-user_tools_collection,
-users_collection,
-)
+from application.api.user.base import resolve_tool_details
+from application.storage.db.base_repository import looks_like_uuid
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.users import UsersRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import generate_image_url

 agents_sharing_ns = Namespace(
@@ -25,6 +21,38 @@ agents_sharing_ns = Namespace(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_agent_basic(agent: dict) -> dict:
|
||||||
|
"""Shape a PG agent row into the API response dict."""
|
||||||
|
source_id = agent.get("source_id")
|
||||||
|
return {
|
||||||
|
"id": str(agent["id"]),
|
||||||
|
"user": agent.get("user_id", ""),
|
||||||
|
"name": agent.get("name", ""),
|
||||||
|
"image": (
|
||||||
|
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||||
|
),
|
||||||
|
"description": agent.get("description", ""),
|
||||||
|
"source": str(source_id) if source_id else "",
|
||||||
|
"chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
|
||||||
|
"retriever": agent.get("retriever", "classic") or "classic",
|
||||||
|
"prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
|
||||||
|
"tools": agent.get("tools", []) or [],
|
||||||
|
"tool_details": resolve_tool_details(agent.get("tools", []) or []),
|
||||||
|
"agent_type": agent.get("agent_type", "") or "",
|
||||||
|
"status": agent.get("status", "") or "",
|
||||||
|
"json_schema": agent.get("json_schema"),
|
||||||
|
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||||
|
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||||
|
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||||
|
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||||
|
"created_at": agent.get("created_at", ""),
|
||||||
|
"updated_at": agent.get("updated_at", ""),
|
||||||
|
"shared": bool(agent.get("shared", False)),
|
||||||
|
"shared_token": agent.get("shared_token", "") or "",
|
||||||
|
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@agents_sharing_ns.route("/shared_agent")
|
@agents_sharing_ns.route("/shared_agent")
|
||||||
class SharedAgent(Resource):
|
class SharedAgent(Resource):
|
||||||
@api.doc(
|
@api.doc(
|
||||||
@@ -41,70 +69,33 @@ class SharedAgent(Resource):
|
|||||||
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
jsonify({"success": False, "message": "Token or ID is required"}), 400
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
query = {
|
with db_readonly() as conn:
|
||||||
"shared_publicly": True,
|
shared_agent = AgentsRepository(conn).find_by_shared_token(
|
||||||
"shared_token": shared_token,
|
shared_token,
|
||||||
}
|
)
|
||||||
shared_agent = agents_collection.find_one(query)
|
|
||||||
if not shared_agent:
|
if not shared_agent:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Shared agent not found"}),
|
jsonify({"success": False, "message": "Shared agent not found"}),
|
||||||
404,
|
404,
|
||||||
)
|
)
|
||||||
agent_id = str(shared_agent["_id"])
|
agent_id = str(shared_agent["id"])
|
||||||
data = {
|
data = _serialize_agent_basic(shared_agent)
|
||||||
"id": agent_id,
|
|
||||||
"user": shared_agent.get("user", ""),
|
|
||||||
"name": shared_agent.get("name", ""),
|
|
||||||
"image": (
|
|
||||||
generate_image_url(shared_agent["image"])
|
|
||||||
if shared_agent.get("image")
|
|
||||||
else ""
|
|
||||||
),
|
|
||||||
"description": shared_agent.get("description", ""),
|
|
||||||
"source": (
|
|
||||||
str(source_doc["_id"])
|
|
||||||
if isinstance(shared_agent.get("source"), DBRef)
|
|
||||||
and (source_doc := db.dereference(shared_agent.get("source")))
|
|
||||||
else ""
|
|
||||||
),
|
|
||||||
"chunks": shared_agent.get("chunks", "0"),
|
|
||||||
"retriever": shared_agent.get("retriever", "classic"),
|
|
||||||
"prompt_id": shared_agent.get("prompt_id", "default"),
|
|
||||||
"tools": shared_agent.get("tools", []),
|
|
||||||
"tool_details": resolve_tool_details(shared_agent.get("tools", [])),
|
|
||||||
"agent_type": shared_agent.get("agent_type", ""),
|
|
||||||
"status": shared_agent.get("status", ""),
|
|
||||||
"json_schema": shared_agent.get("json_schema"),
|
|
||||||
"limited_token_mode": shared_agent.get("limited_token_mode", False),
|
|
||||||
"token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
|
||||||
"limited_request_mode": shared_agent.get("limited_request_mode", False),
|
|
||||||
"request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
|
||||||
"created_at": shared_agent.get("createdAt", ""),
|
|
||||||
"updated_at": shared_agent.get("updatedAt", ""),
|
|
||||||
"shared": shared_agent.get("shared_publicly", False),
|
|
||||||
"shared_token": shared_agent.get("shared_token", ""),
|
|
||||||
"shared_metadata": shared_agent.get("shared_metadata", {}),
|
|
||||||
}
|
|
||||||
|
|
||||||
if data["tools"]:
|
if data["tools"]:
|
||||||
enriched_tools = []
|
enriched_tools = []
|
||||||
for tool in data["tools"]:
|
for detail in data["tool_details"]:
|
||||||
tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
|
enriched_tools.append(detail.get("name", ""))
|
||||||
if tool_data:
|
|
||||||
enriched_tools.append(tool_data.get("name", ""))
|
|
||||||
data["tools"] = enriched_tools
|
data["tools"] = enriched_tools
|
||||||
decoded_token = getattr(request, "decoded_token", None)
|
decoded_token = getattr(request, "decoded_token", None)
|
||||||
if decoded_token:
|
if decoded_token:
|
||||||
user_id = decoded_token.get("sub")
|
user_id = decoded_token.get("sub")
|
||||||
owner_id = shared_agent.get("user")
|
owner_id = shared_agent.get("user_id")
|
||||||
|
|
||||||
if user_id != owner_id:
|
if user_id != owner_id:
|
||||||
ensure_user_doc(user_id)
|
with db_session() as conn:
|
||||||
users_collection.update_one(
|
users_repo = UsersRepository(conn)
|
||||||
{"user_id": user_id},
|
users_repo.upsert(user_id)
|
||||||
{"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
|
users_repo.add_shared(user_id, agent_id)
|
||||||
)
|
|
||||||
return make_response(jsonify(data), 200)
|
return make_response(jsonify(data), 200)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
current_app.logger.error(f"Error retrieving shared agent: {err}")
|
||||||
@@ -121,52 +112,73 @@ class SharedAgents(Resource):
|
|||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
user_id = decoded_token.get("sub")
|
user_id = decoded_token.get("sub")
|
||||||
|
|
||||||
user_doc = ensure_user_doc(user_id)
|
with db_session() as conn:
|
||||||
shared_with_ids = user_doc.get("agent_preferences", {}).get(
|
users_repo = UsersRepository(conn)
|
||||||
"shared_with_me", []
|
user_doc = users_repo.upsert(user_id)
|
||||||
)
|
shared_with_ids = (
|
||||||
shared_object_ids = [ObjectId(id) for id in shared_with_ids]
|
user_doc.get("agent_preferences", {}).get("shared_with_me", [])
|
||||||
|
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||||
shared_agents_cursor = agents_collection.find(
|
else []
|
||||||
{"_id": {"$in": shared_object_ids}, "shared_publicly": True}
|
|
||||||
)
|
|
||||||
shared_agents = list(shared_agents_cursor)
|
|
||||||
|
|
||||||
found_ids_set = {str(agent["_id"]) for agent in shared_agents}
|
|
||||||
stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
|
|
||||||
if stale_ids:
|
|
||||||
users_collection.update_one(
|
|
||||||
{"user_id": user_id},
|
|
||||||
{"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
|
|
||||||
)
|
)
|
||||||
pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
|
# Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
|
||||||
|
uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
|
||||||
|
non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
|
||||||
|
|
||||||
list_shared_agents = [
|
if uuid_ids:
|
||||||
{
|
result = conn.execute(
|
||||||
"id": str(agent["_id"]),
|
_sql_text(
|
||||||
"name": agent.get("name", ""),
|
"SELECT * FROM agents "
|
||||||
"description": agent.get("description", ""),
|
"WHERE id = ANY(CAST(:ids AS uuid[])) "
|
||||||
"image": (
|
"AND shared = true"
|
||||||
generate_image_url(agent["image"]) if agent.get("image") else ""
|
),
|
||||||
),
|
{"ids": uuid_ids},
|
||||||
"tools": agent.get("tools", []),
|
)
|
||||||
"tool_details": resolve_tool_details(agent.get("tools", [])),
|
shared_agents = [dict(row._mapping) for row in result.fetchall()]
|
||||||
"agent_type": agent.get("agent_type", ""),
|
else:
|
||||||
"status": agent.get("status", ""),
|
shared_agents = []
|
||||||
"json_schema": agent.get("json_schema"),
|
|
||||||
"limited_token_mode": agent.get("limited_token_mode", False),
|
found_ids_set = {str(agent["id"]) for agent in shared_agents}
|
||||||
"token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
|
stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
|
||||||
"limited_request_mode": agent.get("limited_request_mode", False),
|
stale_ids.extend(non_uuid_ids)
|
||||||
"request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
|
if stale_ids:
|
||||||
"created_at": agent.get("createdAt", ""),
|
users_repo.remove_shared_bulk(user_id, stale_ids)
|
||||||
"updated_at": agent.get("updatedAt", ""),
|
|
||||||
"pinned": str(agent["_id"]) in pinned_ids,
|
pinned_ids = set(
|
||||||
"shared": agent.get("shared_publicly", False),
|
user_doc.get("agent_preferences", {}).get("pinned", [])
|
||||||
"shared_token": agent.get("shared_token", ""),
|
if isinstance(user_doc.get("agent_preferences"), dict)
|
||||||
"shared_metadata": agent.get("shared_metadata", {}),
|
else []
|
||||||
}
|
)
|
||||||
for agent in shared_agents
|
|
||||||
]
|
list_shared_agents = []
|
||||||
|
for agent in shared_agents:
|
||||||
|
agent_id_str = str(agent["id"])
|
||||||
|
list_shared_agents.append(
|
||||||
|
{
|
||||||
|
"id": agent_id_str,
|
||||||
|
"name": agent.get("name", ""),
|
||||||
|
"description": agent.get("description", ""),
|
||||||
|
"image": (
|
||||||
|
generate_image_url(agent["image"]) if agent.get("image") else ""
|
||||||
|
),
|
||||||
|
"tools": agent.get("tools", []) or [],
|
||||||
|
"tool_details": resolve_tool_details(
|
||||||
|
agent.get("tools", []) or []
|
||||||
|
),
|
||||||
|
"agent_type": agent.get("agent_type", "") or "",
|
||||||
|
"status": agent.get("status", "") or "",
|
||||||
|
"json_schema": agent.get("json_schema"),
|
||||||
|
"limited_token_mode": agent.get("limited_token_mode", False),
|
||||||
|
"token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
|
||||||
|
"limited_request_mode": agent.get("limited_request_mode", False),
|
||||||
|
"request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
|
||||||
|
"created_at": agent.get("created_at", ""),
|
||||||
|
"updated_at": agent.get("updated_at", ""),
|
||||||
|
"pinned": agent_id_str in pinned_ids,
|
||||||
|
"shared": bool(agent.get("shared", False)),
|
||||||
|
"shared_token": agent.get("shared_token", "") or "",
|
||||||
|
"shared_metadata": agent.get("shared_metadata", {}) or {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return make_response(jsonify(list_shared_agents), 200)
|
return make_response(jsonify(list_shared_agents), 200)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@@ -220,44 +232,43 @@ class ShareAgent(Resource):
|
|||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
|
shared_token = None
|
||||||
try:
|
try:
|
||||||
try:
|
with db_session() as conn:
|
||||||
agent_oid = ObjectId(agent_id)
|
repo = AgentsRepository(conn)
|
||||||
except Exception:
|
agent = repo.get_any(agent_id, user)
|
||||||
return make_response(
|
if not agent:
|
||||||
jsonify({"success": False, "message": "Invalid agent ID"}), 400
|
return make_response(
|
||||||
)
|
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||||
agent = agents_collection.find_one({"_id": agent_oid, "user": user})
|
)
|
||||||
if not agent:
|
if shared:
|
||||||
return make_response(
|
shared_metadata = {
|
||||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
"shared_by": username,
|
||||||
)
|
"shared_at": datetime.datetime.now(
|
||||||
if shared:
|
datetime.timezone.utc
|
||||||
shared_metadata = {
|
).isoformat(),
|
||||||
"shared_by": username,
|
}
|
||||||
"shared_at": datetime.datetime.now(datetime.timezone.utc),
|
shared_token = secrets.token_urlsafe(32)
|
||||||
}
|
repo.update(
|
||||||
shared_token = secrets.token_urlsafe(32)
|
str(agent["id"]), user,
|
||||||
agents_collection.update_one(
|
{
|
||||||
{"_id": agent_oid, "user": user},
|
"shared": True,
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
"shared_publicly": shared,
|
|
||||||
"shared_metadata": shared_metadata,
|
|
||||||
"shared_token": shared_token,
|
"shared_token": shared_token,
|
||||||
}
|
"shared_metadata": shared_metadata,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
agents_collection.update_one(
|
repo.update(
|
||||||
{"_id": agent_oid, "user": user},
|
str(agent["id"]), user,
|
||||||
{"$set": {"shared_publicly": shared, "shared_token": None}},
|
{
|
||||||
{"$unset": {"shared_metadata": ""}},
|
"shared": False,
|
||||||
)
|
"shared_token": None,
|
||||||
|
"shared_metadata": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
|
current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
|
return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
|
||||||
shared_token = shared_token if shared else None
|
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": True, "shared_token": shared_token}), 200
|
jsonify({"success": True, "shared_token": shared_token}), 200
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,14 +2,15 @@

 import secrets

-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import Namespace, Resource

 from application.api import api
-from application.api.user.base import agents_collection, require_agent
+from application.api.user.base import require_agent
 from application.api.user.tasks import process_agent_webhook
 from application.core.settings import settings
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.session import db_readonly, db_session


 agents_webhooks_ns = Namespace(
@@ -34,9 +35,8 @@ class AgentWebhook(Resource):
 jsonify({"success": False, "message": "ID is required"}), 400
 )
 try:
-agent = agents_collection.find_one(
-{"_id": ObjectId(agent_id), "user": user}
-)
+with db_readonly() as conn:
+agent = AgentsRepository(conn).get_any(agent_id, user)
 if not agent:
 return make_response(
 jsonify({"success": False, "message": "Agent not found"}), 404
@@ -44,10 +44,11 @@ class AgentWebhook(Resource):
 webhook_token = agent.get("incoming_webhook_token")
 if not webhook_token:
 webhook_token = secrets.token_urlsafe(32)
-agents_collection.update_one(
-{"_id": ObjectId(agent_id), "user": user},
-{"$set": {"incoming_webhook_token": webhook_token}},
-)
+with db_session() as conn:
+AgentsRepository(conn).update(
+str(agent["id"]), user,
+{"incoming_webhook_token": webhook_token},
+)
 base_url = settings.API_URL.rstrip("/")
 full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
 except Exception as err:
|
|||||||
@@ -2,26 +2,84 @@
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import current_app, jsonify, make_response, request
|
from flask import current_app, jsonify, make_response, request
|
||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
from sqlalchemy import text as _sql_text
|
||||||
|
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import (
|
from application.api.user.base import (
|
||||||
agents_collection,
|
|
||||||
conversations_collection,
|
|
||||||
generate_date_range,
|
generate_date_range,
|
||||||
generate_hourly_range,
|
generate_hourly_range,
|
||||||
generate_minute_range,
|
generate_minute_range,
|
||||||
token_usage_collection,
|
|
||||||
user_logs_collection,
|
|
||||||
)
|
)
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.repositories.token_usage import TokenUsageRepository
|
||||||
|
from application.storage.db.repositories.user_logs import UserLogsRepository
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
|
|
||||||
|
|
||||||
analytics_ns = Namespace(
|
analytics_ns = Namespace(
|
||||||
"analytics", description="Analytics and reporting operations", path="/api"
|
"analytics", description="Analytics and reporting operations", path="/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_FILTER_BUCKETS = {
|
||||||
|
"last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
|
||||||
|
"last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
|
||||||
|
"last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
|
||||||
|
"last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
|
||||||
|
"last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _range_for_filter(filter_option: str):
|
||||||
|
"""Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.
|
||||||
|
|
||||||
|
Returns ``None`` on invalid filter.
|
||||||
|
"""
|
||||||
|
if filter_option not in _FILTER_BUCKETS:
|
||||||
|
return None
|
||||||
|
end_date = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]
|
||||||
|
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=1)
|
||||||
|
elif filter_option == "last_24_hour":
|
||||||
|
start_date = end_date - datetime.timedelta(hours=24)
|
||||||
|
else:
|
||||||
|
days = {
|
||||||
|
"last_7_days": 6,
|
||||||
|
"last_15_days": 14,
|
||||||
|
"last_30_days": 29,
|
||||||
|
}[filter_option]
|
||||||
|
start_date = end_date - datetime.timedelta(days=days)
|
||||||
|
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
end_date = end_date.replace(
|
||||||
|
hour=23, minute=59, second=59, microsecond=999999
|
||||||
|
)
|
||||||
|
return start_date, end_date, bucket_unit, pg_fmt
|
||||||
|
|
||||||
|
|
||||||
|
def _intervals_for_filter(filter_option, start_date, end_date):
|
||||||
|
if filter_option == "last_hour":
|
||||||
|
return generate_minute_range(start_date, end_date)
|
||||||
|
if filter_option == "last_24_hour":
|
||||||
|
return generate_hourly_range(start_date, end_date)
|
||||||
|
return generate_date_range(start_date, end_date)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_key(conn, api_key_id, user_id):
|
||||||
|
"""Look up the ``agents.key`` value for a given agent id.
|
||||||
|
|
||||||
|
Scoped by ``user_id`` so an authenticated caller can't probe another
|
||||||
|
user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
|
||||||
|
"""
|
||||||
|
if not api_key_id:
|
||||||
|
return None
|
||||||
|
agent = AgentsRepository(conn).get_any(api_key_id, user_id)
|
||||||
|
return (agent or {}).get("key") if agent else None
|
||||||
|
|
||||||
|
|
||||||
@analytics_ns.route("/get_message_analytics")
|
@analytics_ns.route("/get_message_analytics")
|
||||||
class GetMessageAnalytics(Resource):
|
class GetMessageAnalytics(Resource):
|
||||||
get_message_analytics_model = api.model(
|
get_message_analytics_model = api.model(
|
||||||
@@ -32,13 +90,7 @@ class GetMessageAnalytics(Resource):
|
|||||||
required=False,
|
required=False,
|
||||||
description="Filter option for analytics",
|
description="Filter option for analytics",
|
||||||
default="last_30_days",
|
default="last_30_days",
|
||||||
enum=[
|
enum=list(_FILTER_BUCKETS.keys()),
|
||||||
"last_hour",
|
|
||||||
"last_24_hour",
|
|
||||||
"last_7_days",
|
|
||||||
"last_15_days",
|
|
||||||
"last_30_days",
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -50,88 +102,54 @@ class GetMessageAnalytics(Resource):
|
|||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
data = request.get_json()
|
data = request.get_json() or {}
|
||||||
api_key_id = data.get("api_key_id")
|
api_key_id = data.get("api_key_id")
|
||||||
filter_option = data.get("filter_option", "last_30_days")
|
filter_option = data.get("filter_option", "last_30_days")
|
||||||
|
|
||||||
|
window = _range_for_filter(filter_option)
|
||||||
|
if window is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||||
|
)
|
||||||
|
start_date, end_date, _bucket_unit, pg_fmt = window
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_key = (
|
with db_readonly() as conn:
|
||||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
api_key = _resolve_api_key(conn, api_key_id, user)
|
||||||
"key"
|
|
||||||
|
# Count messages per bucket, filtered by the conversation's
|
||||||
|
# owner (user_id) and optionally the agent api_key. The
|
||||||
|
# ``user_id`` filter is always applied post-cutover to
|
||||||
|
# prevent cross-tenant leakage on admin dashboards.
|
||||||
|
clauses = [
|
||||||
|
"c.user_id = :user_id",
|
||||||
|
"m.timestamp >= :start",
|
||||||
|
"m.timestamp <= :end",
|
||||||
]
|
]
|
||||||
if api_key_id
|
params: dict = {
|
||||||
else None
|
"user_id": user,
|
||||||
)
|
"start": start_date,
|
||||||
except Exception as err:
|
"end": end_date,
|
||||||
current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
|
"fmt": pg_fmt,
|
||||||
return make_response(jsonify({"success": False}), 400)
|
|
||||||
end_date = datetime.datetime.now(datetime.timezone.utc)
|
|
||||||
|
|
||||||
if filter_option == "last_hour":
|
|
||||||
start_date = end_date - datetime.timedelta(hours=1)
|
|
||||||
group_format = "%Y-%m-%d %H:%M:00"
|
|
||||||
elif filter_option == "last_24_hour":
|
|
||||||
start_date = end_date - datetime.timedelta(hours=24)
|
|
||||||
group_format = "%Y-%m-%d %H:00"
|
|
||||||
else:
|
|
||||||
if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
|
|
||||||
filter_days = (
|
|
||||||
6
|
|
||||||
if filter_option == "last_7_days"
|
|
||||||
else 14 if filter_option == "last_15_days" else 29
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return make_response(
|
|
||||||
jsonify({"success": False, "message": "Invalid option"}), 400
|
|
||||||
)
|
|
||||||
start_date = end_date - datetime.timedelta(days=filter_days)
|
|
||||||
start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
||||||
end_date = end_date.replace(
|
|
||||||
hour=23, minute=59, second=59, microsecond=999999
|
|
||||||
)
|
|
||||||
group_format = "%Y-%m-%d"
|
|
||||||
try:
|
|
||||||
match_stage = {
|
|
||||||
"$match": {
|
|
||||||
"user": user,
|
|
||||||
}
|
}
|
||||||
}
|
if api_key:
|
||||||
if api_key:
|
clauses.append("c.api_key = :api_key")
|
||||||
match_stage["$match"]["api_key"] = api_key
|
params["api_key"] = api_key
|
||||||
pipeline = [
|
where = " AND ".join(clauses)
|
||||||
match_stage,
|
sql = (
|
||||||
{"$unwind": "$queries"},
|
"SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
|
||||||
{
|
"COUNT(*) AS count "
|
||||||
"$match": {
|
"FROM conversation_messages m "
|
||||||
"queries.timestamp": {"$gte": start_date, "$lte": end_date}
|
"JOIN conversations c ON c.id = m.conversation_id "
|
||||||
}
|
f"WHERE {where} "
|
||||||
},
|
"GROUP BY bucket ORDER BY bucket ASC"
|
||||||
{
|
)
|
||||||
"$group": {
|
rows = conn.execute(_sql_text(sql), params).fetchall()
|
||||||
"_id": {
|
|
||||||
"$dateToString": {
|
|
||||||
"format": group_format,
|
|
||||||
"date": "$queries.timestamp",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"count": {"$sum": 1},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{"$sort": {"_id": 1}},
|
|
||||||
]
|
|
||||||
|
|
||||||
message_data = conversations_collection.aggregate(pipeline)
|
intervals = _intervals_for_filter(filter_option, start_date, end_date)
|
||||||
|
|
||||||
if filter_option == "last_hour":
|
|
||||||
intervals = generate_minute_range(start_date, end_date)
|
|
||||||
elif filter_option == "last_24_hour":
|
|
||||||
intervals = generate_hourly_range(start_date, end_date)
|
|
||||||
else:
|
|
||||||
intervals = generate_date_range(start_date, end_date)
|
|
||||||
daily_messages = {interval: 0 for interval in intervals}
|
daily_messages = {interval: 0 for interval in intervals}
|
||||||
|
for row in rows:
|
||||||
for entry in message_data:
|
daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
|
||||||
daily_messages[entry["_id"]] = entry["count"]
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error getting message analytics: {err}", exc_info=True
|
f"Error getting message analytics: {err}", exc_info=True
|
||||||
@@ -152,13 +170,7 @@ class GetTokenAnalytics(Resource):
|
|||||||
required=False,
|
required=False,
|
||||||
description="Filter option for analytics",
|
description="Filter option for analytics",
|
||||||
default="last_30_days",
|
default="last_30_days",
|
||||||
enum=[
|
enum=list(_FILTER_BUCKETS.keys()),
|
||||||
"last_hour",
|
|
||||||
"last_24_hour",
|
|
||||||
"last_7_days",
|
|
||||||
"last_15_days",
|
|
||||||
"last_30_days",
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -170,123 +182,36 @@ class GetTokenAnalytics(Resource):
|
|||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
data = request.get_json()
|
data = request.get_json() or {}
|
||||||
api_key_id = data.get("api_key_id")
|
api_key_id = data.get("api_key_id")
|
||||||
filter_option = data.get("filter_option", "last_30_days")
|
filter_option = data.get("filter_option", "last_30_days")
|
||||||
|
|
||||||
try:
|
window = _range_for_filter(filter_option)
|
||||||
api_key = (
|
if window is None:
|
||||||
agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
|
return make_response(
|
||||||
"key"
|
jsonify({"success": False, "message": "Invalid option"}), 400
|
||||||
]
|
|
||||||
if api_key_id
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
except Exception as err:
|
start_date, end_date, bucket_unit, _pg_fmt = window
|
||||||
-current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
-return make_response(jsonify({"success": False}), 400)
-end_date = datetime.datetime.now(datetime.timezone.utc)
-
-if filter_option == "last_hour":
-start_date = end_date - datetime.timedelta(hours=1)
-group_format = "%Y-%m-%d %H:%M:00"
-group_stage = {
-"$group": {
-"_id": {
-"minute": {
-"$dateToString": {
-"format": group_format,
-"date": "$timestamp",
-}
-}
-},
-"total_tokens": {
-"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
-},
-}
-}
-elif filter_option == "last_24_hour":
-start_date = end_date - datetime.timedelta(hours=24)
-group_format = "%Y-%m-%d %H:00"
-group_stage = {
-"$group": {
-"_id": {
-"hour": {
-"$dateToString": {
-"format": group_format,
-"date": "$timestamp",
-}
-}
-},
-"total_tokens": {
-"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
-},
-}
-}
-else:
-if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
-filter_days = (
-6
-if filter_option == "last_7_days"
-else (14 if filter_option == "last_15_days" else 29)
-)
-else:
-return make_response(
-jsonify({"success": False, "message": "Invalid option"}), 400
-)
-start_date = end_date - datetime.timedelta(days=filter_days)
-start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
-end_date = end_date.replace(
-hour=23, minute=59, second=59, microsecond=999999
-)
-group_format = "%Y-%m-%d"
-group_stage = {
-"$group": {
-"_id": {
-"day": {
-"$dateToString": {
-"format": group_format,
-"date": "$timestamp",
-}
-}
-},
-"total_tokens": {
-"$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
-},
-}
-}
 try:
-match_stage = {
-"$match": {
-"user_id": user,
-"timestamp": {"$gte": start_date, "$lte": end_date},
-}
-}
-if api_key:
-match_stage["$match"]["api_key"] = api_key
-token_usage_data = token_usage_collection.aggregate(
-[
-match_stage,
-group_stage,
-{"$sort": {"_id": 1}},
-]
-)
-
-if filter_option == "last_hour":
-intervals = generate_minute_range(start_date, end_date)
-elif filter_option == "last_24_hour":
-intervals = generate_hourly_range(start_date, end_date)
-else:
-intervals = generate_date_range(start_date, end_date)
+with db_readonly() as conn:
+api_key = _resolve_api_key(conn, api_key_id, user)
+# ``bucketed_totals`` applies user_id / api_key filters
+# directly — no need to reshape a Mongo pipeline.
+rows = TokenUsageRepository(conn).bucketed_totals(
+bucket_unit=bucket_unit,
+user_id=user,
+api_key=api_key,
+timestamp_gte=start_date,
+timestamp_lt=end_date,
+)
+intervals = _intervals_for_filter(filter_option, start_date, end_date)
 daily_token_usage = {interval: 0 for interval in intervals}
-for entry in token_usage_data:
-if filter_option == "last_hour":
-daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
-elif filter_option == "last_24_hour":
-daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
-else:
-daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
+for entry in rows:
+daily_token_usage[entry["bucket"]] = int(
+entry["prompt_tokens"] + entry["generated_tokens"]
+)
 except Exception as err:
 current_app.logger.error(
 f"Error getting token analytics: {err}", exc_info=True
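The hunk above relies on two helpers that are not shown in this excerpt, _FILTER_BUCKETS and _intervals_for_filter. The following is only a sketch of what they could look like, inferred from the bucket formats of the removed Mongo pipeline; the names match the calls above, but the bodies are assumptions.

import datetime

# Hypothetical sketch — bucket units and label formats are inferred from the
# legacy group_format strings ("%Y-%m-%d %H:%M:00", "%Y-%m-%d %H:00", "%Y-%m-%d").
_FILTER_BUCKETS = {
    "last_hour": ("minute", "%Y-%m-%d %H:%M:00"),
    "last_24_hour": ("hour", "%Y-%m-%d %H:00"),
    "last_7_days": ("day", "%Y-%m-%d"),
    "last_15_days": ("day", "%Y-%m-%d"),
    "last_30_days": ("day", "%Y-%m-%d"),
}

def _intervals_for_filter(filter_option, start_date, end_date):
    """Return every bucket label between start_date and end_date, so the
    response contains zero-filled entries for buckets with no usage."""
    unit, fmt = _FILTER_BUCKETS[filter_option]
    step = {
        "minute": datetime.timedelta(minutes=1),
        "hour": datetime.timedelta(hours=1),
        "day": datetime.timedelta(days=1),
    }[unit]
    labels, current = [], start_date
    while current <= end_date:
        labels.append(current.strftime(fmt))
        current += step
    return labels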
@@ -307,13 +232,7 @@ class GetFeedbackAnalytics(Resource):
 required=False,
 description="Filter option for analytics",
 default="last_30_days",
-enum=[
-"last_hour",
-"last_24_hour",
-"last_7_days",
-"last_15_days",
-"last_30_days",
-],
+enum=list(_FILTER_BUCKETS.keys()),
 ),
 },
 )
@@ -325,128 +244,64 @@ class GetFeedbackAnalytics(Resource):
 if not decoded_token:
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
-data = request.get_json()
+data = request.get_json() or {}
 api_key_id = data.get("api_key_id")
 filter_option = data.get("filter_option", "last_30_days")

+window = _range_for_filter(filter_option)
+if window is None:
+return make_response(
+jsonify({"success": False, "message": "Invalid option"}), 400
+)
+start_date, end_date, _bucket_unit, pg_fmt = window

 try:
-api_key = (
-agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
-"key"
-]
-if api_key_id
-else None
-)
-except Exception as err:
-current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
-return make_response(jsonify({"success": False}), 400)
-end_date = datetime.datetime.now(datetime.timezone.utc)
-
-if filter_option == "last_hour":
-start_date = end_date - datetime.timedelta(hours=1)
-group_format = "%Y-%m-%d %H:%M:00"
-date_field = {
-"$dateToString": {
-"format": group_format,
-"date": "$queries.feedback_timestamp",
-}
-}
-elif filter_option == "last_24_hour":
-start_date = end_date - datetime.timedelta(hours=24)
-group_format = "%Y-%m-%d %H:00"
-date_field = {
-"$dateToString": {
-"format": group_format,
-"date": "$queries.feedback_timestamp",
-}
-}
-else:
-if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
-filter_days = (
-6
-if filter_option == "last_7_days"
-else (14 if filter_option == "last_15_days" else 29)
-)
-else:
-return make_response(
-jsonify({"success": False, "message": "Invalid option"}), 400
-)
-start_date = end_date - datetime.timedelta(days=filter_days)
-start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
-end_date = end_date.replace(
-hour=23, minute=59, second=59, microsecond=999999
-)
-group_format = "%Y-%m-%d"
-date_field = {
-"$dateToString": {
-"format": group_format,
-"date": "$queries.feedback_timestamp",
-}
-}
-try:
-match_stage = {
-"$match": {
-"queries.feedback_timestamp": {
-"$gte": start_date,
-"$lte": end_date,
-},
-"queries.feedback": {"$exists": True},
-}
-}
-if api_key:
-match_stage["$match"]["api_key"] = api_key
-pipeline = [
-match_stage,
-{"$unwind": "$queries"},
-{"$match": {"queries.feedback": {"$exists": True}}},
-{
-"$group": {
-"_id": {"time": date_field, "feedback": "$queries.feedback"},
-"count": {"$sum": 1},
-}
-},
-{
-"$group": {
-"_id": "$_id.time",
-"positive": {
-"$sum": {
-"$cond": [
-{"$eq": ["$_id.feedback", "LIKE"]},
-"$count",
-0,
-]
-}
-},
-"negative": {
-"$sum": {
-"$cond": [
-{"$eq": ["$_id.feedback", "DISLIKE"]},
-"$count",
-0,
-]
-}
-},
-}
-},
-{"$sort": {"_id": 1}},
-]
-
-feedback_data = conversations_collection.aggregate(pipeline)
-
-if filter_option == "last_hour":
-intervals = generate_minute_range(start_date, end_date)
-elif filter_option == "last_24_hour":
-intervals = generate_hourly_range(start_date, end_date)
-else:
-intervals = generate_date_range(start_date, end_date)
+with db_readonly() as conn:
+api_key = _resolve_api_key(conn, api_key_id, user)
+# Feedback lives inside the ``conversation_messages.feedback``
+# JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
+# There is no scalar ``feedback_timestamp`` column — extract
+# the timestamp from the JSONB and cast it to timestamptz for
+# the range filter + bucket grouping.
+clauses = [
+"c.user_id = :user_id",
+"m.feedback IS NOT NULL",
+"(m.feedback->>'timestamp')::timestamptz >= :start",
+"(m.feedback->>'timestamp')::timestamptz <= :end",
+]
+params: dict = {
+"user_id": user,
+"start": start_date,
+"end": end_date,
+"fmt": pg_fmt,
+}
+if api_key:
+clauses.append("c.api_key = :api_key")
+params["api_key"] = api_key
+where = " AND ".join(clauses)
+sql = (
+"SELECT to_char("
+"(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
+") AS bucket, "
+"SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
+"SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
+"FROM conversation_messages m "
+"JOIN conversations c ON c.id = m.conversation_id "
+f"WHERE {where} "
+"GROUP BY bucket ORDER BY bucket ASC"
+)
+rows = conn.execute(_sql_text(sql), params).fetchall()
+
+intervals = _intervals_for_filter(filter_option, start_date, end_date)
 daily_feedback = {
 interval: {"positive": 0, "negative": 0} for interval in intervals
 }
-for entry in feedback_data:
-daily_feedback[entry["_id"]] = {
-"positive": entry["positive"],
-"negative": entry["negative"],
+for row in rows:
+bucket = row._mapping["bucket"]
+daily_feedback[bucket] = {
+"positive": int(row._mapping["positive"] or 0),
+"negative": int(row._mapping["negative"] or 0),
 }
 except Exception as err:
 current_app.logger.error(
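The _range_for_filter helper referenced above, which returns (start_date, end_date, bucket_unit, pg_fmt) or None for an unknown option, is likewise not part of this excerpt. A minimal sketch, assuming the same windows as the removed Mongo branches and Postgres to_char formats; only the contract is taken from the diff, the body is an assumption.

import datetime

def _range_for_filter(filter_option):
    """Return (start, end, bucket_unit, pg_to_char_format) or None."""
    now = datetime.datetime.now(datetime.timezone.utc)
    if filter_option == "last_hour":
        return now - datetime.timedelta(hours=1), now, "minute", "YYYY-MM-DD HH24:MI:00"
    if filter_option == "last_24_hour":
        return now - datetime.timedelta(hours=24), now, "hour", "YYYY-MM-DD HH24:00"
    days = {"last_7_days": 6, "last_15_days": 14, "last_30_days": 29}.get(filter_option)
    if days is None:
        return None
    start = (now - datetime.timedelta(days=days)).replace(
        hour=0, minute=0, second=0, microsecond=0
    )
    end = now.replace(hour=23, minute=59, second=59, microsecond=999999)
    return start, end, "day", "YYYY-MM-DD"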
@@ -484,47 +339,89 @@ class GetUserLogs(Resource):
 if not decoded_token:
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
-data = request.get_json()
+data = request.get_json() or {}
 page = int(data.get("page", 1))
 api_key_id = data.get("api_key_id")
 page_size = int(data.get("page_size", 10))
-skip = (page - 1) * page_size

 try:
-api_key = (
-agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
-if api_key_id
-else None
-)
+with db_readonly() as conn:
+api_key = _resolve_api_key(conn, api_key_id, user)
+logs_repo = UserLogsRepository(conn)
+if api_key:
+# ``find_by_api_key`` filters on ``data->>'api_key'``
+# — the PG shape of the legacy top-level ``api_key``
+# filter. Paginate client-side using offset/limit.
+all_rows = logs_repo.find_by_api_key(api_key)
+offset = (page - 1) * page_size
+window = all_rows[offset: offset + page_size + 1]
+items = window
+else:
+items, has_more_flag = logs_repo.list_paginated(
+user_id=user,
+page=page,
+page_size=page_size,
+)
+# list_paginated already trims to page_size and
+# returns has_more separately.
+results = [
+{
+"id": str(item.get("id") or item.get("_id")),
+"action": (item.get("data") or {}).get("action"),
+"level": (item.get("data") or {}).get("level"),
+"user": item.get("user_id"),
+"question": (item.get("data") or {}).get("question"),
+"sources": (item.get("data") or {}).get("sources"),
+"retriever_params": (item.get("data") or {}).get(
+"retriever_params"
+),
+"timestamp": (
+item["timestamp"].isoformat()
+if hasattr(item.get("timestamp"), "isoformat")
+else item.get("timestamp")
+),
+}
+for item in items
+]
+return make_response(
+jsonify(
+{
+"success": True,
+"logs": results,
+"page": page,
+"page_size": page_size,
+"has_more": has_more_flag,
+}
+),
+200,
+)
+
+has_more = len(items) > page_size
+items = items[:page_size]
+results = [
+{
+"id": str(item.get("id") or item.get("_id")),
+"action": (item.get("data") or {}).get("action"),
+"level": (item.get("data") or {}).get("level"),
+"user": item.get("user_id"),
+"question": (item.get("data") or {}).get("question"),
+"sources": (item.get("data") or {}).get("sources"),
+"retriever_params": (item.get("data") or {}).get(
+"retriever_params"
+),
+"timestamp": (
+item["timestamp"].isoformat()
+if hasattr(item.get("timestamp"), "isoformat")
+else item.get("timestamp")
+),
+}
+for item in items
+]
 except Exception as err:
-current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
+current_app.logger.error(
+f"Error getting user logs: {err}", exc_info=True
+)
 return make_response(jsonify({"success": False}), 400)
-query = {"user": user}
-if api_key:
-query = {"api_key": api_key}
-items_cursor = (
-user_logs_collection.find(query)
-.sort("timestamp", -1)
-.skip(skip)
-.limit(page_size + 1)
-)
-items = list(items_cursor)
-
-results = [
-{
-"id": str(item.get("_id")),
-"action": item.get("action"),
-"level": item.get("level"),
-"user": item.get("user"),
-"question": item.get("question"),
-"sources": item.get("sources"),
-"retriever_params": item.get("retriever_params"),
-"timestamp": item.get("timestamp"),
-}
-for item in items[:page_size]
-]
-
-has_more = len(items) > page_size
-
 return make_response(
 jsonify(
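A sketch of the UserLogsRepository.list_paginated contract relied on above (trim to page_size, report has_more separately), assuming a user_logs table with a timestamp column and SQLAlchemy Core; the real repository may differ.

from sqlalchemy import text

class UserLogsRepository:
    """Sketch only — shows the pagination contract used by GetUserLogs."""

    def __init__(self, conn):
        self.conn = conn

    def list_paginated(self, *, user_id, page, page_size):
        # Fetch one extra row so has_more can be derived without a COUNT(*).
        offset = (page - 1) * page_size
        rows = self.conn.execute(
            text(
                "SELECT * FROM user_logs WHERE user_id = :uid "
                "ORDER BY timestamp DESC LIMIT :lim OFFSET :off"
            ),
            {"uid": user_id, "lim": page_size + 1, "off": offset},
        ).fetchall()
        items = [dict(r._mapping) for r in rows]
        return items[:page_size], len(items) > page_size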
@@ -4,13 +4,16 @@ import os
 import tempfile
 from pathlib import Path

-from bson.objectid import ObjectId
+import uuid

 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource

 from application.api import api
 from application.cache import get_redis_instance
 from application.core.settings import settings
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.session import db_readonly
 from application.stt.constants import (
 SUPPORTED_AUDIO_EXTENSIONS,
 SUPPORTED_AUDIO_MIME_TYPES,
@@ -48,14 +51,13 @@ def _resolve_authenticated_user():
 return safe_filename(decoded_token.get("sub"))

 if api_key:
-from application.api.user.base import agents_collection
-
-agent = agents_collection.find_one({"key": api_key})
+with db_readonly() as conn:
+agent = AgentsRepository(conn).find_by_key(api_key)
 if not agent:
 return make_response(
 jsonify({"success": False, "message": "Invalid API key"}), 401
 )
-return safe_filename(agent.get("user"))
+return safe_filename(agent.get("user_id"))

 return None
@@ -157,7 +159,7 @@ class StoreAttachment(Resource):

 for idx, file in enumerate(files):
 try:
-attachment_id = ObjectId()
+attachment_id = uuid.uuid4()
 original_filename = safe_filename(os.path.basename(file.filename))
 _enforce_uploaded_audio_size_limit(file, original_filename)
 relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
@@ -8,13 +8,15 @@ import uuid
 from functools import wraps
 from typing import Optional, Tuple

-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, Response
-from pymongo import ReturnDocument
 from werkzeug.utils import secure_filename

-from application.core.mongo_db import MongoDB
+from sqlalchemy import text as _sql_text

 from application.core.settings import settings
+from application.storage.db.base_repository import looks_like_uuid, row_to_dict
+from application.storage.db.repositories.users import UsersRepository
+from application.storage.db.session import db_readonly, db_session
 from application.storage.storage_creator import StorageCreator
 from application.vectorstore.vector_creator import VectorCreator
@@ -22,56 +24,6 @@ from application.vectorstore.vector_creator import VectorCreator
 storage = StorageCreator.get_storage()


-mongo = MongoDB.get_client()
-db = mongo[settings.MONGO_DB_NAME]
-
-
-conversations_collection = db["conversations"]
-sources_collection = db["sources"]
-prompts_collection = db["prompts"]
-feedback_collection = db["feedback"]
-agents_collection = db["agents"]
-agent_folders_collection = db["agent_folders"]
-token_usage_collection = db["token_usage"]
-shared_conversations_collections = db["shared_conversations"]
-users_collection = db["users"]
-user_logs_collection = db["user_logs"]
-user_tools_collection = db["user_tools"]
-attachments_collection = db["attachments"]
-workflow_runs_collection = db["workflow_runs"]
-workflows_collection = db["workflows"]
-workflow_nodes_collection = db["workflow_nodes"]
-workflow_edges_collection = db["workflow_edges"]
-
-
-try:
-agents_collection.create_index(
-[("shared", 1)],
-name="shared_index",
-background=True,
-)
-users_collection.create_index("user_id", unique=True)
-workflows_collection.create_index(
-[("user", 1)], name="workflow_user_index", background=True
-)
-workflow_nodes_collection.create_index(
-[("workflow_id", 1)], name="node_workflow_index", background=True
-)
-workflow_nodes_collection.create_index(
-[("workflow_id", 1), ("graph_version", 1)],
-name="node_workflow_graph_version_index",
-background=True,
-)
-workflow_edges_collection.create_index(
-[("workflow_id", 1)], name="edge_workflow_index", background=True
-)
-workflow_edges_collection.create_index(
-[("workflow_id", 1), ("graph_version", 1)],
-name="edge_workflow_graph_version_index",
-background=True,
-)
-except Exception as e:
-print("Error creating indexes:", e)
 current_dir = os.path.dirname(
 os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 )
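With the module-level Mongo client and collection handles removed, every call site now opens a short-lived connection through db_session() / db_readonly(). A minimal sketch of such helpers, assuming SQLAlchemy and a configured Postgres URI; the engine URL below is a placeholder and the real module's construction may differ.

from contextlib import contextmanager
from sqlalchemy import create_engine

# Placeholder URL — the real module presumably reads the Postgres URI from settings.
_engine = create_engine("postgresql://user:pass@localhost:5432/docsgpt")

@contextmanager
def db_session():
    """Yield a connection inside a transaction; commit on success, roll back on error."""
    with _engine.begin() as conn:
        yield conn

@contextmanager
def db_readonly():
    """Yield a connection for read-only work; nothing is committed."""
    with _engine.connect() as conn:
        yield conn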
@@ -103,66 +55,95 @@ def generate_date_range(start_date, end_date):

 def ensure_user_doc(user_id):
 """
-Ensure user document exists with proper agent preferences structure.
+Ensure a Postgres ``users`` row exists for ``user_id``.

+Returns the row as a dict with the shape legacy callers expect — in
+particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
+``shared_with_me`` list keys always present).

 Args:
 user_id: The user ID to ensure

 Returns:
-The user document
+The user document as a dict.
 """
-default_prefs = {
-"pinned": [],
-"shared_with_me": [],
-}
-
-user_doc = users_collection.find_one_and_update(
-{"user_id": user_id},
-{"$setOnInsert": {"agent_preferences": default_prefs}},
-upsert=True,
-return_document=ReturnDocument.AFTER,
-)
-
-prefs = user_doc.get("agent_preferences", {})
-updates = {}
-if "pinned" not in prefs:
-updates["agent_preferences.pinned"] = []
-if "shared_with_me" not in prefs:
-updates["agent_preferences.shared_with_me"] = []
-if updates:
-users_collection.update_one({"user_id": user_id}, {"$set": updates})
-user_doc = users_collection.find_one({"user_id": user_id})
+with db_session() as conn:
+user_doc = UsersRepository(conn).upsert(user_id)
+
+prefs = user_doc.get("agent_preferences") or {}
+if not isinstance(prefs, dict):
+prefs = {}
+prefs.setdefault("pinned", [])
+prefs.setdefault("shared_with_me", [])
+user_doc["agent_preferences"] = prefs
 return user_doc


 def resolve_tool_details(tool_ids):
 """
-Resolve tool IDs to their details.
+Resolve tool IDs to their display details.

+Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
+lists are supported — each id is looked up via ``get_any``, which
+resolves to whichever column matches). Unknown ids are silently
+skipped.

 Args:
-tool_ids: List of tool IDs
+tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).

 Returns:
-List of tool details with id, name, and display_name
+List of tool details with ``id``, ``name``, and ``display_name``.
 """
-valid_ids = []
+if not tool_ids:
+return []
+
+uuid_ids: list[str] = []
+legacy_ids: list[str] = []
 for tid in tool_ids:
-try:
-valid_ids.append(ObjectId(tid))
-except Exception:
+if not tid:
 continue
-tools = user_tools_collection.find(
-{"_id": {"$in": valid_ids}}
-) if valid_ids else []
+tid_str = str(tid)
+if looks_like_uuid(tid_str):
+uuid_ids.append(tid_str)
+else:
+legacy_ids.append(tid_str)
+
+if not uuid_ids and not legacy_ids:
+return []
+
+rows: list[dict] = []
+with db_readonly() as conn:
+if uuid_ids:
+result = conn.execute(
+_sql_text(
+"SELECT * FROM user_tools "
+"WHERE id = ANY(CAST(:ids AS uuid[]))"
+),
+{"ids": uuid_ids},
+)
+rows.extend(row_to_dict(r) for r in result.fetchall())
+if legacy_ids:
+result = conn.execute(
+_sql_text(
+"SELECT * FROM user_tools "
+"WHERE legacy_mongo_id = ANY(:ids)"
+),
+{"ids": legacy_ids},
+)
+rows.extend(row_to_dict(r) for r in result.fetchall())

 return [
 {
-"id": str(tool["_id"]),
-"name": tool.get("name", ""),
-"display_name": tool.get("customName")
-or tool.get("displayName")
-or tool.get("name", ""),
+"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
+"name": tool.get("name", "") or "",
+"display_name": (
+tool.get("custom_name")
+or tool.get("display_name")
+or tool.get("name", "")
+or ""
+),
 }
-for tool in tools
+for tool in rows
 ]
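looks_like_uuid and row_to_dict from base_repository are used throughout the new code but are not shown in this excerpt. Plausible sketches, assuming the obvious semantics (a UUID shape check and a Row-to-dict conversion):

import uuid

def looks_like_uuid(value) -> bool:
    """Sketch of the shape-gate: True only if the string parses as a UUID,
    keeping legacy ObjectId strings out of CAST(... AS uuid) expressions."""
    try:
        uuid.UUID(str(value))
        return True
    except (ValueError, TypeError, AttributeError):
        return False

def row_to_dict(row) -> dict:
    """Convert a SQLAlchemy Row into a plain dict (sketch)."""
    return dict(row._mapping)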
@@ -232,14 +213,15 @@ def require_agent(func):

 @wraps(func)
 def wrapper(*args, **kwargs):
+from application.storage.db.repositories.agents import AgentsRepository
+
 webhook_token = kwargs.get("webhook_token")
 if not webhook_token:
 return make_response(
 jsonify({"success": False, "message": "Webhook token missing"}), 400
 )
-agent = agents_collection.find_one(
-{"incoming_webhook_token": webhook_token}, {"_id": 1}
-)
+with db_readonly() as conn:
+agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
 if not agent:
 current_app.logger.warning(
 f"Webhook attempt with invalid token: {webhook_token}"
@@ -248,7 +230,7 @@ def require_agent(func):
 jsonify({"success": False, "message": "Agent not found"}), 404
 )
 kwargs["agent"] = agent
-kwargs["agent_id_str"] = str(agent["_id"])
+kwargs["agent_id_str"] = str(agent["id"])
 return func(*args, **kwargs)

 return wrapper
@@ -2,12 +2,13 @@

 import datetime

-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource

 from application.api import api
-from application.api.user.base import attachments_collection, conversations_collection
+from application.storage.db.repositories.attachments import AttachmentsRepository
+from application.storage.db.repositories.conversations import ConversationsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields

 conversations_ns = Namespace(
@@ -30,10 +31,13 @@ class DeleteConversation(Resource):
 return make_response(
 jsonify({"success": False, "message": "ID is required"}), 400
 )
+user_id = decoded_token["sub"]
 try:
-conversations_collection.delete_one(
-{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
-)
+with db_session() as conn:
+repo = ConversationsRepository(conn)
+conv = repo.get_any(conversation_id, user_id)
+if conv is not None:
+repo.delete(str(conv["id"]), user_id)
 except Exception as err:
 current_app.logger.error(
 f"Error deleting conversation: {err}", exc_info=True
@@ -53,7 +57,8 @@ class DeleteAllConversations(Resource):
 return make_response(jsonify({"success": False}), 401)
 user_id = decoded_token.get("sub")
 try:
-conversations_collection.delete_many({"user": user_id})
+with db_session() as conn:
+ConversationsRepository(conn).delete_all_for_user(user_id)
 except Exception as err:
 current_app.logger.error(
 f"Error deleting all conversations: {err}", exc_info=True
@@ -71,26 +76,21 @@ class GetConversations(Resource):
 decoded_token = request.decoded_token
 if not decoded_token:
 return make_response(jsonify({"success": False}), 401)
+user_id = decoded_token.get("sub")
 try:
-conversations = (
-conversations_collection.find(
-{
-"$or": [
-{"api_key": {"$exists": False}},
-{"agent_id": {"$exists": True}},
-],
-"user": decoded_token.get("sub"),
-}
-)
-.sort("date", -1)
-.limit(30)
-)
+with db_readonly() as conn:
+conversations = ConversationsRepository(conn).list_for_user(
+user_id, limit=30
+)

 list_conversations = [
 {
-"id": str(conversation["_id"]),
+"id": str(conversation["id"]),
 "name": conversation["name"],
-"agent_id": conversation.get("agent_id", None),
+"agent_id": (
+str(conversation["agent_id"])
+if conversation.get("agent_id")
+else None
+),
 "is_shared_usage": conversation.get("is_shared_usage", False),
 "shared_token": conversation.get("shared_token", None),
 }
@@ -119,38 +119,67 @@ class GetSingleConversation(Resource):
 return make_response(
 jsonify({"success": False, "message": "ID is required"}), 400
 )
+user_id = decoded_token.get("sub")
 try:
-conversation = conversations_collection.find_one(
-{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
-)
+with db_readonly() as conn:
+repo = ConversationsRepository(conn)
+conversation = repo.get_any(conversation_id, user_id)
 if not conversation:
 return make_response(jsonify({"status": "not found"}), 404)
-# Process queries to include attachment names
-
-queries = conversation["queries"]
-for query in queries:
-if "attachments" in query and query["attachments"]:
-attachment_details = []
-for attachment_id in query["attachments"]:
-try:
-attachment = attachments_collection.find_one(
-{"_id": ObjectId(attachment_id)}
-)
-if attachment:
-attachment_details.append(
-{
-"id": str(attachment["_id"]),
-"fileName": attachment.get(
-"filename", "Unknown file"
-),
-}
-)
-except Exception as e:
-current_app.logger.error(
-f"Error retrieving attachment {attachment_id}: {e}",
-exc_info=True,
-)
-query["attachments"] = attachment_details
+conv_pg_id = str(conversation["id"])
+messages = repo.get_messages(conv_pg_id)
+
+# Resolve attachment details (id, fileName) for each message.
+attachments_repo = AttachmentsRepository(conn)
+queries = []
+for msg in messages:
+query = {
+"prompt": msg.get("prompt"),
+"response": msg.get("response"),
+"thought": msg.get("thought"),
+"sources": msg.get("sources") or [],
+"tool_calls": msg.get("tool_calls") or [],
+"timestamp": msg.get("timestamp"),
+"model_id": msg.get("model_id"),
+}
+if msg.get("metadata"):
+query["metadata"] = msg["metadata"]
+# Feedback on conversation_messages is a JSONB blob with
+# shape {"text": <str>, "timestamp": <iso>}. The legacy
+# frontend consumed a flat scalar feedback string, so
+# unwrap the ``text`` field for compat.
+feedback = msg.get("feedback")
+if feedback is not None:
+if isinstance(feedback, dict):
+query["feedback"] = feedback.get("text")
+if feedback.get("timestamp"):
+query["feedback_timestamp"] = feedback["timestamp"]
+else:
+query["feedback"] = feedback
+attachments = msg.get("attachments") or []
+if attachments:
+attachment_details = []
+for attachment_id in attachments:
+try:
+att = attachments_repo.get_any(
+str(attachment_id), user_id
+)
+if att:
+attachment_details.append(
+{
+"id": str(att["id"]),
+"fileName": att.get(
+"filename", "Unknown file"
+),
+}
+)
+except Exception as e:
+current_app.logger.error(
+f"Error retrieving attachment {attachment_id}: {e}",
+exc_info=True,
+)
+query["attachments"] = attachment_details
+queries.append(query)
 except Exception as err:
 current_app.logger.error(
 f"Error retrieving conversation: {err}", exc_info=True
@@ -158,7 +187,9 @@ class GetSingleConversation(Resource):
 return make_response(jsonify({"success": False}), 400)
 data = {
 "queries": queries,
-"agent_id": conversation.get("agent_id"),
+"agent_id": (
+str(conversation["agent_id"]) if conversation.get("agent_id") else None
+),
 "is_shared_usage": conversation.get("is_shared_usage", False),
 "shared_token": conversation.get("shared_token", None),
 }
@@ -190,11 +221,13 @@ class UpdateConversationName(Resource):
 missing_fields = check_required_fields(data, required_fields)
 if missing_fields:
 return missing_fields
+user_id = decoded_token.get("sub")
 try:
-conversations_collection.update_one(
-{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
-{"$set": {"name": data["name"]}},
-)
+with db_session() as conn:
+repo = ConversationsRepository(conn)
+conv = repo.get_any(data["id"], user_id)
+if conv is not None:
+repo.rename(str(conv["id"]), user_id, data["name"])
 except Exception as err:
 current_app.logger.error(
 f"Error updating conversation name: {err}", exc_info=True
@@ -237,43 +270,33 @@ class SubmitFeedback(Resource):
 missing_fields = check_required_fields(data, required_fields)
 if missing_fields:
 return missing_fields
+user_id = decoded_token.get("sub")
+feedback_value = data["feedback"]
+question_index = int(data["question_index"])
+# Normalize string feedback to lowercase so analytics queries
+# (which match 'like'/'dislike') count rows correctly. Tolerate
+# legacy uppercase clients on ingest. Non-string values pass through.
+if isinstance(feedback_value, str):
+feedback_value = feedback_value.lower()
+feedback_payload = (
+None
+if feedback_value is None
+else {
+"text": feedback_value,
+"timestamp": datetime.datetime.now(
+datetime.timezone.utc
+).isoformat(),
+}
+)
 try:
-if data["feedback"] is None:
-# Remove feedback and feedback_timestamp if feedback is null
-
-conversations_collection.update_one(
-{
-"_id": ObjectId(data["conversation_id"]),
-"user": decoded_token.get("sub"),
-f"queries.{data['question_index']}": {"$exists": True},
-},
-{
-"$unset": {
-f"queries.{data['question_index']}.feedback": "",
-f"queries.{data['question_index']}.feedback_timestamp": "",
-}
-},
-)
-else:
-# Set feedback and feedback_timestamp if feedback has a value
-
-conversations_collection.update_one(
-{
-"_id": ObjectId(data["conversation_id"]),
-"user": decoded_token.get("sub"),
-f"queries.{data['question_index']}": {"$exists": True},
-},
-{
-"$set": {
-f"queries.{data['question_index']}.feedback": data[
-"feedback"
-],
-f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
-datetime.timezone.utc
-),
-}
-},
-)
+with db_session() as conn:
+repo = ConversationsRepository(conn)
+conv = repo.get_any(data["conversation_id"], user_id)
+if conv is None:
+return make_response(
+jsonify({"success": False, "message": "Not found"}), 404
+)
+repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
 except Exception as err:
 current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
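The repo.set_feedback(...) call above stores the payload in the message's feedback JSONB column. A sketch of how such an update could be written, assuming conversation_messages rows are ordered by timestamp and addressed by question_index; the table and column names come from the SQL elsewhere in this diff, but the statement itself is an assumption.

import json
from sqlalchemy import text

def set_feedback(conn, conversation_id, question_index, payload):
    """Set (or clear, when payload is None) the feedback JSONB of the
    question_index-th message of a conversation. Sketch only."""
    row = conn.execute(
        text(
            "SELECT id FROM conversation_messages "
            "WHERE conversation_id = CAST(:cid AS uuid) "
            "ORDER BY timestamp ASC LIMIT 1 OFFSET :idx"
        ),
        {"cid": conversation_id, "idx": question_index},
    ).fetchone()
    if row is None:
        return
    conn.execute(
        text(
            "UPDATE conversation_messages "
            "SET feedback = CAST(:fb AS jsonb) WHERE id = :mid"
        ),
        {"fb": json.dumps(payload) if payload is not None else None, "mid": row[0]},
    )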
@@ -2,12 +2,13 @@

 import os

-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource

 from application.api import api
-from application.api.user.base import current_dir, prompts_collection
+from application.api.user.base import current_dir
+from application.storage.db.repositories.prompts import PromptsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields

 prompts_ns = Namespace(
@@ -40,15 +41,9 @@ class CreatePrompt(Resource):
 return missing_fields
 user = decoded_token.get("sub")
 try:
-resp = prompts_collection.insert_one(
-{
-"name": data["name"],
-"content": data["content"],
-"user": user,
-}
-)
-new_id = str(resp.inserted_id)
+with db_session() as conn:
+prompt = PromptsRepository(conn).create(user, data["name"], data["content"])
+new_id = str(prompt["id"])
 except Exception as err:
 current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -64,17 +59,17 @@ class GetPrompts(Resource):
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
 try:
-prompts = prompts_collection.find({"user": user})
+with db_readonly() as conn:
+prompts = PromptsRepository(conn).list_for_user(user)
 list_prompts = [
 {"id": "default", "name": "default", "type": "public"},
 {"id": "creative", "name": "creative", "type": "public"},
 {"id": "strict", "name": "strict", "type": "public"},
 ]

 for prompt in prompts:
 list_prompts.append(
 {
-"id": str(prompt["_id"]),
+"id": str(prompt["id"]),
 "name": prompt["name"],
 "type": "private",
 }
@@ -119,9 +114,12 @@ class GetSinglePrompt(Resource):
 ) as f:
 chat_reduce_strict = f.read()
 return make_response(jsonify({"content": chat_reduce_strict}), 200)
-prompt = prompts_collection.find_one(
-{"_id": ObjectId(prompt_id), "user": user}
-)
+with db_readonly() as conn:
+prompt = PromptsRepository(conn).get_any(prompt_id, user)
+if not prompt:
+return make_response(
+jsonify({"success": False, "message": "Prompt not found"}), 404
+)
 except Exception as err:
 current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -148,7 +146,15 @@ class DeletePrompt(Resource):
 if missing_fields:
 return missing_fields
 try:
-prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
+with db_session() as conn:
+repo = PromptsRepository(conn)
+prompt = repo.get_any(data["id"], user)
+if not prompt:
+return make_response(
+jsonify({"success": False, "message": "Prompt not found"}),
+404,
+)
+repo.delete(str(prompt["id"]), user)
 except Exception as err:
 current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -181,10 +187,15 @@ class UpdatePrompt(Resource):
 if missing_fields:
 return missing_fields
 try:
-prompts_collection.update_one(
-{"_id": ObjectId(data["id"]), "user": user},
-{"$set": {"name": data["name"], "content": data["content"]}},
-)
+with db_session() as conn:
+repo = PromptsRepository(conn)
+prompt = repo.get_any(data["id"], user)
+if not prompt:
+return make_response(
+jsonify({"success": False, "message": "Prompt not found"}),
+404,
+)
+repo.update(str(prompt["id"]), user, data["name"], data["content"])
 except Exception as err:
 current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -2,26 +2,126 @@

 import uuid

-from bson.binary import Binary, UuidRepresentation
-from bson.dbref import DBRef
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, inputs, Namespace, Resource
+from sqlalchemy import text as _sql_text

 from application.api import api
-from application.api.user.base import (
-agents_collection,
-attachments_collection,
-conversations_collection,
-shared_conversations_collections,
-)
+from application.storage.db.base_repository import looks_like_uuid
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.attachments import AttachmentsRepository
+from application.storage.db.repositories.conversations import ConversationsRepository
+from application.storage.db.repositories.shared_conversations import (
+SharedConversationsRepository,
+)
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields


 sharing_ns = Namespace(
 "sharing", description="Conversation sharing operations", path="/api"
 )


+def _resolve_prompt_pg_id(conn, prompt_id_raw, user_id):
+"""Translate an incoming prompt id (UUID or legacy Mongo ObjectId) to a PG UUID.
+
+Scoped by ``user_id`` so a caller can't link another user's prompt
+into their share record. Returns ``None`` for sentinel values
+(``"default"``) or unresolved ids.
+"""
+if not prompt_id_raw or prompt_id_raw == "default":
+return None
+value = str(prompt_id_raw)
+# Already UUID — trust it but still require ownership. A shape-gate
+# (rather than a loose ``len == 36 and '-' in value`` check) keeps
+# non-UUID input out of ``CAST(:pid AS uuid)``; the cast would raise
+# and poison the readonly transaction otherwise.
+if looks_like_uuid(value):
+row = conn.execute(
+_sql_text(
+"SELECT id FROM prompts WHERE id = CAST(:pid AS uuid) "
+"AND user_id = :uid"
+),
+{"pid": value, "uid": user_id},
+).fetchone()
+return str(row[0]) if row else None
+# Legacy Mongo ObjectId fallback.
+row = conn.execute(
+_sql_text(
+"SELECT id FROM prompts WHERE legacy_mongo_id = :pid "
+"AND user_id = :uid"
+),
+{"pid": value, "uid": user_id},
+).fetchone()
+return str(row[0]) if row else None
+
+
+def _resolve_source_pg_id(conn, source_raw):
+"""Translate a source id (UUID or legacy Mongo ObjectId) to a PG UUID."""
+if not source_raw:
+return None
+value = str(source_raw)
+# See ``_resolve_prompt_pg_id`` for the shape-gate rationale.
+if looks_like_uuid(value):
+row = conn.execute(
+_sql_text(
+"SELECT id FROM sources WHERE id = CAST(:sid AS uuid)"
+),
+{"sid": value},
+).fetchone()
+return str(row[0]) if row else None
+row = conn.execute(
+_sql_text("SELECT id FROM sources WHERE legacy_mongo_id = :sid"),
+{"sid": value},
+).fetchone()
+return str(row[0]) if row else None
+
+
+def _find_reusable_share_agent(
+conn, user_id, *, prompt_pg_id, chunks, source_pg_id, retriever,
+):
+"""Find an existing share-as-agent key row matching these parameters.
+
+Mirrors the legacy Mongo ``agents_collection.find_one`` pre-existence
+check. Used to reuse an api key across repeated shares of the same
+conversation with the same prompt/chunks/source/retriever.
+"""
+clauses = ["user_id = :uid", "key IS NOT NULL"]
+params: dict = {"uid": user_id}
+if prompt_pg_id is None:
+clauses.append("prompt_id IS NULL")
+else:
+clauses.append("prompt_id = CAST(:pid AS uuid)")
+params["pid"] = prompt_pg_id
+if chunks is None:
+clauses.append("chunks IS NULL")
+else:
+clauses.append("chunks = :chunks")
+params["chunks"] = int(chunks)
+if source_pg_id is None:
+clauses.append("source_id IS NULL")
+else:
+clauses.append("source_id = CAST(:sid AS uuid)")
+params["sid"] = source_pg_id
+if retriever is None:
+clauses.append("retriever IS NULL")
+else:
+clauses.append("retriever = :retr")
+params["retr"] = retriever
+sql = (
+"SELECT * FROM agents WHERE "
++ " AND ".join(clauses)
++ " LIMIT 1"
+)
+row = conn.execute(_sql_text(sql), params).fetchone()
+if row is None:
+return None
+mapping = dict(row._mapping)
+mapping["id"] = str(mapping["id"]) if mapping.get("id") else None
+return mapping
+
+
 @sharing_ns.route("/share")
 class ShareConversation(Resource):
 share_conversation_model = api.model(
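Usage of the two resolvers is straightforward; the snippet below is illustrative only and the ids in it are made up.

with db_readonly() as conn:
    # ObjectId-shaped string: falls through to the legacy_mongo_id lookup.
    prompt_pg_id = _resolve_prompt_pg_id(conn, "65f0c0ffee65f0c0ffee65f0", "alice")
    # UUID-shaped string: resolved directly against the id column.
    source_pg_id = _resolve_source_pg_id(conn, "4bd5c6e4-8f5e-4a3f-9c1d-2d9a62f0b111")
    # Either value is None when the id is unknown, belongs to another user,
    # or (for prompts) is the "default" sentinel.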
@@ -56,146 +156,94 @@ class ShareConversation(Resource):
 conversation_id = data["conversation_id"]

 try:
-conversation = conversations_collection.find_one(
-{"_id": ObjectId(conversation_id), "user": user}
-)
-if conversation is None:
-return make_response(
-jsonify(
-{
-"status": "error",
-"message": "Conversation does not exist",
-}
-),
-404,
-)
-current_n_queries = len(conversation["queries"])
-explicit_binary = Binary.from_uuid(
-uuid.uuid4(), UuidRepresentation.STANDARD
-)
-
-if is_promptable:
-prompt_id = data.get("prompt_id", "default")
-chunks = data.get("chunks", "2")
-
-name = conversation["name"] + "(shared)"
-new_api_key_data = {
-"prompt_id": prompt_id,
-"chunks": chunks,
-"user": user,
-}
-
-if "source" in data and ObjectId.is_valid(data["source"]):
-new_api_key_data["source"] = DBRef(
-"sources", ObjectId(data["source"])
-)
-if "retriever" in data:
-new_api_key_data["retriever"] = data["retriever"]
-pre_existing_api_document = agents_collection.find_one(new_api_key_data)
-if pre_existing_api_document:
-api_uuid = pre_existing_api_document["key"]
-pre_existing = shared_conversations_collections.find_one(
-{
-"conversation_id": ObjectId(conversation_id),
-"isPromptable": is_promptable,
-"first_n_queries": current_n_queries,
-"user": user,
-"api_key": api_uuid,
-}
-)
-if pre_existing is not None:
-return make_response(
-jsonify(
-{
-"success": True,
-"identifier": str(pre_existing["uuid"].as_uuid()),
-}
-),
-200,
-)
-else:
-shared_conversations_collections.insert_one(
-{
-"uuid": explicit_binary,
-"conversation_id": ObjectId(conversation_id),
-"isPromptable": is_promptable,
-"first_n_queries": current_n_queries,
-"user": user,
-"api_key": api_uuid,
-}
-)
-return make_response(
-jsonify(
-{
-"success": True,
-"identifier": str(explicit_binary.as_uuid()),
-}
-),
-201,
-)
-else:
-api_uuid = str(uuid.uuid4())
-new_api_key_data["key"] = api_uuid
-new_api_key_data["name"] = name
-
-if "source" in data and ObjectId.is_valid(data["source"]):
-new_api_key_data["source"] = DBRef(
-"sources", ObjectId(data["source"])
-)
-if "retriever" in data:
-new_api_key_data["retriever"] = data["retriever"]
-agents_collection.insert_one(new_api_key_data)
-shared_conversations_collections.insert_one(
-{
-"uuid": explicit_binary,
-"conversation_id": ObjectId(conversation_id),
-"isPromptable": is_promptable,
-"first_n_queries": current_n_queries,
-"user": user,
-"api_key": api_uuid,
-}
-)
+with db_session() as conn:
+conv_repo = ConversationsRepository(conn)
+shared_repo = SharedConversationsRepository(conn)
+agents_repo = AgentsRepository(conn)
+
+conversation = conv_repo.get_any(conversation_id, user)
+if conversation is None:
+return make_response(
+jsonify(
+{
+"status": "error",
+"message": "Conversation does not exist",
+}
+),
+404,
+)
+conv_pg_id = str(conversation["id"])
+current_n_queries = conv_repo.message_count(conv_pg_id)
+
+if is_promptable:
+prompt_id_raw = data.get("prompt_id", "default")
+chunks_raw = data.get("chunks", "2")
+try:
+chunks_int = int(chunks_raw) if chunks_raw not in (None, "") else None
+except (TypeError, ValueError):
+chunks_int = None
+
+prompt_pg_id = _resolve_prompt_pg_id(conn, prompt_id_raw, user)
+source_pg_id = _resolve_source_pg_id(conn, data.get("source"))
+retriever = data.get("retriever")
+
+reusable = _find_reusable_share_agent(
+conn, user,
+prompt_pg_id=prompt_pg_id,
+chunks=chunks_int,
+source_pg_id=source_pg_id,
+retriever=retriever,
+)
+if reusable:
+api_uuid = reusable.get("key")
+else:
+api_uuid = str(uuid.uuid4())
+name = (conversation.get("name") or "") + "(shared)"
+agents_repo.create(
+user,
+name,
+"published",
+key=api_uuid,
+retriever=retriever,
+chunks=chunks_int,
+prompt_id=prompt_pg_id,
+source_id=source_pg_id,
+)
+share = shared_repo.get_or_create(
+conv_pg_id,
+user,
+is_promptable=True,
+first_n_queries=current_n_queries,
+api_key=api_uuid,
+prompt_id=prompt_pg_id,
+chunks=chunks_int,
+)
 return make_response(
 jsonify(
 {
 "success": True,
-"identifier": str(explicit_binary.as_uuid()),
+"identifier": str(share["uuid"]),
 }
 ),
-201,
+201 if reusable is None else 200,
 )
-pre_existing = shared_conversations_collections.find_one(
-{
-"conversation_id": ObjectId(conversation_id),
-"isPromptable": is_promptable,
-"first_n_queries": current_n_queries,
-"user": user,
-}
-)
-if pre_existing is not None:
+
+# Non-promptable share path.
+share = shared_repo.get_or_create(
+conv_pg_id,
+user,
+is_promptable=False,
+first_n_queries=current_n_queries,
+api_key=None,
+)
 return make_response(
 jsonify(
 {
 "success": True,
-"identifier": str(pre_existing["uuid"].as_uuid()),
+"identifier": str(share["uuid"]),
 }
 ),
-200,
-)
-else:
-shared_conversations_collections.insert_one(
-{
-"uuid": explicit_binary,
-"conversation_id": ObjectId(conversation_id),
-"isPromptable": is_promptable,
-"first_n_queries": current_n_queries,
-"user": user,
-}
-)
-return make_response(
-jsonify(
-{"success": True, "identifier": str(explicit_binary.as_uuid())}
-),
 201,
 )
 except Exception as err:
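A sketch of the idempotent SharedConversationsRepository.get_or_create behaviour relied on above, simplified to the columns used in this hunk (prompt/chunks handling omitted); the table and column names are assumptions and the real repository may differ.

import uuid
from sqlalchemy import text

def get_or_create(conn, conversation_id, user_id, *, is_promptable,
                  first_n_queries, api_key=None):
    """Reuse an existing share row with identical parameters, otherwise
    insert one with a fresh UUID identifier. Sketch only."""
    row = conn.execute(
        text(
            "SELECT uuid FROM shared_conversations "
            "WHERE conversation_id = CAST(:cid AS uuid) AND user_id = :uid "
            "AND is_promptable = :promptable AND first_n_queries = :first_n "
            "AND api_key IS NOT DISTINCT FROM :key"
        ),
        {"cid": conversation_id, "uid": user_id, "promptable": is_promptable,
         "first_n": first_n_queries, "key": api_key},
    ).fetchone()
    if row:
        return {"uuid": row[0]}
    share_uuid = str(uuid.uuid4())
    conn.execute(
        text(
            "INSERT INTO shared_conversations "
            "(uuid, conversation_id, user_id, is_promptable, first_n_queries, api_key) "
            "VALUES (CAST(:u AS uuid), CAST(:cid AS uuid), :uid, :promptable, :first_n, :key)"
        ),
        {"u": share_uuid, "cid": conversation_id, "uid": user_id,
         "promptable": is_promptable, "first_n": first_n_queries, "key": api_key},
    )
    return {"uuid": share_uuid}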
@@ -210,37 +258,13 @@ class GetPubliclySharedConversations(Resource):
 @api.doc(description="Get publicly shared conversations by identifier")
 def get(self, identifier: str):
 try:
-query_uuid = Binary.from_uuid(
-uuid.UUID(identifier), UuidRepresentation.STANDARD
-)
-shared = shared_conversations_collections.find_one({"uuid": query_uuid})
-conversation_queries = []
-
-if (
-shared
-and "conversation_id" in shared
-):
-# Handle DBRef (legacy), ObjectId, dict, and string formats for conversation_id
-conversation_id = shared["conversation_id"]
-if isinstance(conversation_id, DBRef):
-conversation_id = conversation_id.id
-elif isinstance(conversation_id, dict):
-# Handle dict representation of DBRef (e.g., {"$ref": "...", "$id": "..."})
-if "$id" in conversation_id:
-conv_id = conversation_id["$id"]
-# $id might be a dict like {"$oid": "..."} or a string
-if isinstance(conv_id, dict) and "$oid" in conv_id:
-conversation_id = ObjectId(conv_id["$oid"])
-else:
-conversation_id = ObjectId(conv_id)
-elif "_id" in conversation_id:
-conversation_id = ObjectId(conversation_id["_id"])
-elif isinstance(conversation_id, str):
-conversation_id = ObjectId(conversation_id)
-conversation = conversations_collection.find_one(
-{"_id": conversation_id}
-)
-if conversation is None:
+with db_readonly() as conn:
+shared_repo = SharedConversationsRepository(conn)
+conv_repo = ConversationsRepository(conn)
+attach_repo = AttachmentsRepository(conn)
+
+shared = shared_repo.find_by_uuid(identifier)
+if not shared or not shared.get("conversation_id"):
 return make_response(
 jsonify(
 {
@@ -250,22 +274,60 @@ class GetPubliclySharedConversations(Resource):
|
|||||||
),
|
),
|
||||||
404,
|
404,
|
||||||
)
|
)
|
||||||
conversation_queries = conversation["queries"][
|
conv_pg_id = str(shared["conversation_id"])
|
||||||
: (shared["first_n_queries"])
|
owner_user = shared.get("user_id")
|
||||||
]
|
|
||||||
|
|
||||||
for query in conversation_queries:
|
conversation = conv_repo.get_owned(conv_pg_id, owner_user) if owner_user else None
|
||||||
if "attachments" in query and query["attachments"]:
|
if conversation is None:
|
||||||
|
# Fall back to any-user lookup in case shared row's
|
||||||
|
# user_id is missing — still keyed by PG UUID.
|
||||||
|
row = conn.execute(
|
||||||
|
_sql_text(
|
||||||
|
"SELECT * FROM conversations WHERE id = CAST(:id AS uuid)"
|
||||||
|
),
|
||||||
|
{"id": conv_pg_id},
|
||||||
|
).fetchone()
|
||||||
|
if row is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "might have broken url or the conversation does not exist",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
conversation = dict(row._mapping)
|
||||||
|
|
||||||
|
messages = conv_repo.get_messages(conv_pg_id)
|
||||||
|
first_n = shared.get("first_n_queries") or 0
|
||||||
|
conversation_queries = []
|
||||||
|
for msg in messages[:first_n]:
|
||||||
|
query = {
|
||||||
|
"prompt": msg.get("prompt"),
|
||||||
|
"response": msg.get("response"),
|
||||||
|
"thought": msg.get("thought"),
|
||||||
|
"sources": msg.get("sources") or [],
|
||||||
|
"tool_calls": msg.get("tool_calls") or [],
|
||||||
|
"timestamp": (
|
||||||
|
msg["timestamp"].isoformat()
|
||||||
|
if hasattr(msg.get("timestamp"), "isoformat")
|
||||||
|
else msg.get("timestamp")
|
||||||
|
),
|
||||||
|
"feedback": msg.get("feedback"),
|
||||||
|
}
|
||||||
|
attachments = msg.get("attachments") or []
|
||||||
|
if attachments:
|
||||||
attachment_details = []
|
attachment_details = []
|
||||||
for attachment_id in query["attachments"]:
|
for attachment_id in attachments:
|
||||||
try:
|
try:
|
||||||
attachment = attachments_collection.find_one(
|
attachment = attach_repo.get_any(
|
||||||
{"_id": ObjectId(attachment_id)}
|
str(attachment_id), owner_user,
|
||||||
)
|
) if owner_user else None
|
||||||
if attachment:
|
if attachment:
|
||||||
attachment_details.append(
|
attachment_details.append(
|
||||||
{
|
{
|
||||||
"id": str(attachment["_id"]),
|
"id": str(attachment["id"]),
|
||||||
"fileName": attachment.get(
|
"fileName": attachment.get(
|
||||||
"filename", "Unknown file"
|
"filename", "Unknown file"
|
||||||
),
|
),
|
||||||
@@ -277,26 +339,23 @@ class GetPubliclySharedConversations(Resource):
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
query["attachments"] = attachment_details
|
query["attachments"] = attachment_details
|
||||||
else:
|
conversation_queries.append(query)
|
||||||
return make_response(
|
|
||||||
jsonify(
|
created = conversation.get("created_at") or conversation.get("date")
|
||||||
{
|
date_iso = (
|
||||||
"success": False,
|
created.isoformat()
|
||||||
"error": "might have broken url or the conversation does not exist",
|
if hasattr(created, "isoformat")
|
||||||
}
|
else (str(created) if created is not None else None)
|
||||||
),
|
|
||||||
404,
|
|
||||||
)
|
)
|
||||||
date = conversation["_id"].generation_time.isoformat()
|
res = {
|
||||||
res = {
|
"success": True,
|
||||||
"success": True,
|
"queries": conversation_queries,
|
||||||
"queries": conversation_queries,
|
"title": conversation.get("name"),
|
||||||
"title": conversation["name"],
|
"timestamp": date_iso,
|
||||||
"timestamp": date,
|
}
|
||||||
}
|
if shared.get("is_promptable") and shared.get("api_key"):
|
||||||
if shared["isPromptable"] and "api_key" in shared:
|
res["api_key"] = shared["api_key"]
|
||||||
res["api_key"] = shared["api_key"]
|
return make_response(jsonify(res), 200)
|
||||||
return make_response(jsonify(res), 200)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error getting shared conversation: {err}", exc_info=True
|
f"Error getting shared conversation: {err}", exc_info=True
|
||||||
|
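Handlers in this migration open a read-only unit of work with `db_readonly()` and pass the connection into small repository objects; writes go through `db_session()`. Those helpers are not part of this diff. A minimal sketch of the assumed shape, built on SQLAlchemy and the `POSTGRES_URI` setting the tasks module checks:

```python
# Sketch only — the real application.storage.db.session module may add engine
# caching, retries, or request-scoped sessions.
from contextlib import contextmanager
from sqlalchemy import create_engine

from application.core.settings import settings

_engine = create_engine(settings.POSTGRES_URI)

@contextmanager
def db_readonly():
    """Yield a plain connection for SELECTs; nothing is committed."""
    with _engine.connect() as conn:
        yield conn

@contextmanager
def db_session():
    """Yield a connection inside a transaction; commit on success, roll back on error."""
    with _engine.begin() as conn:
        yield conn
```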
@@ -1,11 +1,12 @@
 """Source document management chunk management."""
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
 
 from application.api import api
-from application.api.user.base import get_vector_store, sources_collection
+from application.api.user.base import get_vector_store
+from application.storage.db.repositories.sources import SourcesRepository
+from application.storage.db.session import db_readonly
 from application.utils import check_required_fields, num_tokens_from_string
 
 sources_chunks_ns = Namespace(
@@ -13,6 +14,15 @@ sources_chunks_ns = Namespace(
 )
 
 
+def _resolve_source(doc_id: str, user: str):
+    """Resolve a source (UUID or legacy ObjectId) for the caller.
+
+    Returns the row dict (with PG UUID in ``id``) or ``None`` if missing.
+    """
+    with db_readonly() as conn:
+        return SourcesRepository(conn).get_any(doc_id, user)
+
+
 @sources_chunks_ns.route("/get_chunks")
 class GetChunks(Resource):
     @api.doc(
@@ -36,36 +46,34 @@ class GetChunks(Resource):
         path = request.args.get("path")
         search_term = request.args.get("search", "").strip().lower()
 
-        if not ObjectId.is_valid(doc_id):
+        if not doc_id:
+            return make_response(jsonify({"error": "Invalid doc_id"}), 400)
+        try:
+            doc = _resolve_source(doc_id, user)
+        except Exception as e:
+            current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
             return make_response(jsonify({"error": "Invalid doc_id"}), 400)
-        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
         if not doc:
             return make_response(
                 jsonify({"error": "Document not found or access denied"}), 404
             )
+        resolved_id = str(doc["id"])
         try:
-            store = get_vector_store(doc_id)
+            store = get_vector_store(resolved_id)
            chunks = store.get_chunks()
 
             filtered_chunks = []
             for chunk in chunks:
                 metadata = chunk.get("metadata", {})
 
-                # Filter by path if provided
                 if path:
                     chunk_source = metadata.get("source", "")
                     chunk_file_path = metadata.get("file_path", "")
-                    # Check if the chunk matches the requested path
-                    # For file uploads: source ends with path (e.g., "inputs/.../file.pdf" ends with "file.pdf")
-                    # For crawlers: file_path ends with path (e.g., "guides/setup.md" ends with "setup.md")
                     source_match = chunk_source and chunk_source.endswith(path)
                     file_path_match = chunk_file_path and chunk_file_path.endswith(path)
 
                     if not (source_match or file_path_match):
                         continue
-                # Filter by search term if provided
                 if search_term:
                     text_match = search_term in chunk.get("text", "").lower()
                     title_match = search_term in metadata.get("title", "").lower()
@@ -132,15 +140,17 @@ class AddChunk(Resource):
         token_count = num_tokens_from_string(text)
         metadata["token_count"] = token_count
 
-        if not ObjectId.is_valid(doc_id):
+        try:
+            doc = _resolve_source(doc_id, user)
+        except Exception as e:
+            current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
             return make_response(jsonify({"error": "Invalid doc_id"}), 400)
-        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
         if not doc:
             return make_response(
                 jsonify({"error": "Document not found or access denied"}), 404
             )
         try:
-            store = get_vector_store(doc_id)
+            store = get_vector_store(str(doc["id"]))
             chunk_id = store.add_chunk(text, metadata)
             return make_response(
                 jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
@@ -165,15 +175,17 @@ class DeleteChunk(Resource):
         doc_id = request.args.get("id")
         chunk_id = request.args.get("chunk_id")
 
-        if not ObjectId.is_valid(doc_id):
+        try:
+            doc = _resolve_source(doc_id, user)
+        except Exception as e:
+            current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
             return make_response(jsonify({"error": "Invalid doc_id"}), 400)
-        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
         if not doc:
             return make_response(
                 jsonify({"error": "Document not found or access denied"}), 404
             )
         try:
-            store = get_vector_store(doc_id)
+            store = get_vector_store(str(doc["id"]))
             deleted = store.delete_chunk(chunk_id)
             if deleted:
                 return make_response(
@@ -232,15 +244,17 @@ class UpdateChunk(Resource):
         if metadata is None:
             metadata = {}
         metadata["token_count"] = token_count
-        if not ObjectId.is_valid(doc_id):
+        try:
+            doc = _resolve_source(doc_id, user)
+        except Exception as e:
+            current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
             return make_response(jsonify({"error": "Invalid doc_id"}), 400)
-        doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
         if not doc:
             return make_response(
                 jsonify({"error": "Document not found or access denied"}), 404
             )
         try:
-            store = get_vector_store(doc_id)
+            store = get_vector_store(str(doc["id"]))
 
             chunks = store.get_chunks()
             existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
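`SourcesRepository.get_any(doc_id, user)` above accepts both new Postgres UUIDs and legacy Mongo ObjectId strings, but its body is not part of this hunk. An illustrative sketch of one way to branch on the identifier format; the `legacy_id` column and the SQL shape are assumptions, not taken from the PR:

```python
# Illustrative only — column names and the legacy-id strategy are guesses.
import uuid
from sqlalchemy import text

def get_any(conn, doc_id: str, user: str):
    try:
        uuid.UUID(doc_id)
        where = "id = CAST(:val AS uuid)"
    except ValueError:
        where = "legacy_id = :val"  # 24-hex Mongo ObjectId kept for back-compat
    row = conn.execute(
        text(f"SELECT * FROM sources WHERE {where} AND user_id = :user"),
        {"val": doc_id, "user": user},
    ).fetchone()
    return dict(row._mapping) if row else None
```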
@@ -3,14 +3,14 @@
 import json
 import math
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, redirect, request
 from flask_restx import fields, Namespace, Resource
 
 from application.api import api
-from application.api.user.base import sources_collection
 from application.api.user.tasks import sync_source
 from application.core.settings import settings
+from application.storage.db.repositories.sources import SourcesRepository
+from application.storage.db.session import db_readonly, db_session
 from application.storage.storage_creator import StorageCreator
 from application.utils import check_required_fields
 from application.vectorstore.vector_creator import VectorCreator
@@ -56,11 +56,20 @@ class CombinedJson(Resource):
         ]
 
         try:
-            for index in sources_collection.find({"user": user}).sort("date", -1):
+            with db_readonly() as conn:
+                indexes = SourcesRepository(conn).list_for_user(user)
+            # list_for_user sorts by created_at DESC; legacy shape sorted by
+            # "date" DESC. Both are monotonic on creation so the ordering is
+            # equivalent for dev; re-sort defensively.
+            indexes = sorted(
+                indexes, key=lambda r: r.get("date") or r.get("created_at") or "",
+                reverse=True,
+            )
+            for index in indexes:
                 provider = _get_provider_from_remote_data(index.get("remote_data"))
                 data.append(
                     {
-                        "id": str(index["_id"]),
+                        "id": str(index["id"]),
                         "name": index.get("name"),
                         "date": index.get("date"),
                         "model": settings.EMBEDDINGS_NAME,
@@ -70,9 +79,7 @@ class CombinedJson(Resource):
                         "syncFrequency": index.get("sync_frequency", ""),
                         "provider": provider,
                         "is_nested": bool(index.get("directory_structure")),
-                        "type": index.get(
-                            "type", "file"
-                        ),  # Add type field with default "file"
+                        "type": index.get("type", "file"),
                     }
                 )
         except Exception as err:
@@ -89,61 +96,55 @@ class PaginatedSources(Resource):
         if not decoded_token:
             return make_response(jsonify({"success": False}), 401)
         user = decoded_token.get("sub")
-        sort_field = request.args.get("sort", "date")  # Default to 'date'
-        sort_order = request.args.get("order", "desc")  # Default to 'desc'
-        page = int(request.args.get("page", 1))  # Default to 1
-        rows_per_page = int(request.args.get("rows", 10))  # Default to 10
-        # add .strip() to remove leading and trailing whitespaces
-        search_term = request.args.get(
-            "search", ""
-        ).strip()  # add search for filter documents
-
-        # Prepare query for filtering
-        query = {"user": user}
-        if search_term:
-            query["name"] = {
-                "$regex": search_term,
-                "$options": "i",  # using case-insensitive search
-            }
-        total_documents = sources_collection.count_documents(query)
-        total_pages = max(1, math.ceil(total_documents / rows_per_page))
-        page = min(
-            max(1, page), total_pages
-        )  # add this to make sure page inbound is within the range
-        sort_order = 1 if sort_order == "asc" else -1
-        skip = (page - 1) * rows_per_page
+        sort_field = request.args.get("sort", "date")
+        sort_order = request.args.get("order", "desc")
+        page = max(1, int(request.args.get("page", 1)))
+        rows_per_page = max(1, int(request.args.get("rows", 10)))
+        search_term = request.args.get("search", "").strip() or None
 
         try:
-            documents = (
-                sources_collection.find(query)
-                .sort(sort_field, sort_order)
-                .skip(skip)
-                .limit(rows_per_page)
-            )
+            with db_readonly() as conn:
+                repo = SourcesRepository(conn)
+                total_documents = repo.count_for_user(
+                    user, search_term=search_term,
+                )
+                # Prior in-Python implementation returned ``totalPages = 1``
+                # for empty result sets (``max(1, ceil(0/rows))``); we
+                # preserve that contract so the frontend pager stays stable.
+                total_pages = max(1, math.ceil(total_documents / rows_per_page))
+                effective_page = min(page, total_pages)
+                offset = (effective_page - 1) * rows_per_page
+                window = repo.list_for_user(
+                    user,
+                    limit=rows_per_page,
+                    offset=offset,
+                    search_term=search_term,
+                    sort_field=sort_field,
+                    sort_order=sort_order,
+                )
 
             paginated_docs = []
-            for doc in documents:
+            for doc in window:
                 provider = _get_provider_from_remote_data(doc.get("remote_data"))
-                doc_data = {
-                    "id": str(doc["_id"]),
-                    "name": doc.get("name", ""),
-                    "date": doc.get("date", ""),
-                    "model": settings.EMBEDDINGS_NAME,
-                    "location": "local",
-                    "tokens": doc.get("tokens", ""),
-                    "retriever": doc.get("retriever", "classic"),
-                    "syncFrequency": doc.get("sync_frequency", ""),
-                    "provider": provider,
-                    "isNested": bool(doc.get("directory_structure")),
-                    "type": doc.get("type", "file"),
-                }
-                paginated_docs.append(doc_data)
+                paginated_docs.append(
+                    {
+                        "id": str(doc["id"]),
+                        "name": doc.get("name", ""),
+                        "date": doc.get("date", ""),
+                        "model": settings.EMBEDDINGS_NAME,
+                        "location": "local",
+                        "tokens": doc.get("tokens", ""),
+                        "retriever": doc.get("retriever", "classic"),
+                        "syncFrequency": doc.get("sync_frequency", ""),
+                        "provider": provider,
+                        "isNested": bool(doc.get("directory_structure")),
+                        "type": doc.get("type", "file"),
+                    }
+                )
             response = {
                 "total": total_documents,
                 "totalPages": total_pages,
-                "currentPage": page,
+                "currentPage": effective_page,
                 "paginated": paginated_docs,
             }
             return make_response(jsonify(response), 200)
@@ -154,28 +155,6 @@ class PaginatedSources(Resource):
             return make_response(jsonify({"success": False}), 400)
 
 
-@sources_ns.route("/delete_by_ids")
-class DeleteByIds(Resource):
-    @api.doc(
-        description="Deletes documents from the vector store by IDs",
-        params={"path": "Comma-separated list of IDs"},
-    )
-    def get(self):
-        ids = request.args.get("path")
-        if not ids:
-            return make_response(
-                jsonify({"success": False, "message": "Missing required fields"}), 400
-            )
-        try:
-            result = sources_collection.delete_index(ids=ids)
-            if result:
-                return make_response(jsonify({"success": True}), 200)
-        except Exception as err:
-            current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
-            return make_response(jsonify({"success": False}), 400)
-        return make_response(jsonify({"success": False}), 400)
-
-
 @sources_ns.route("/delete_old")
 class DeleteOldIndexes(Resource):
     @api.doc(
@@ -186,30 +165,33 @@ class DeleteOldIndexes(Resource):
         decoded_token = request.decoded_token
         if not decoded_token:
             return make_response(jsonify({"success": False}), 401)
+        user = decoded_token.get("sub")
         source_id = request.args.get("source_id")
         if not source_id:
             return make_response(
                 jsonify({"success": False, "message": "Missing required fields"}), 400
             )
-        doc = sources_collection.find_one(
-            {"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
-        )
+        try:
+            with db_readonly() as conn:
+                doc = SourcesRepository(conn).get_any(source_id, user)
+        except Exception as err:
+            current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
+            return make_response(jsonify({"success": False}), 400)
         if not doc:
             return make_response(jsonify({"status": "not found"}), 404)
         storage = StorageCreator.get_storage()
+        resolved_id = str(doc["id"])
 
         try:
-            # Delete vector index
             if settings.VECTOR_STORE == "faiss":
-                index_path = f"indexes/{str(doc['_id'])}"
+                index_path = f"indexes/{resolved_id}"
                 if storage.file_exists(f"{index_path}/index.faiss"):
                     storage.delete_file(f"{index_path}/index.faiss")
                 if storage.file_exists(f"{index_path}/index.pkl"):
                     storage.delete_file(f"{index_path}/index.pkl")
             else:
                 vectorstore = VectorCreator.create_vectorstore(
-                    settings.VECTOR_STORE, source_id=str(doc["_id"])
+                    settings.VECTOR_STORE, source_id=resolved_id
                 )
                 vectorstore.delete_index()
             if "file_path" in doc and doc["file_path"]:
@@ -227,7 +209,14 @@ class DeleteOldIndexes(Resource):
                 f"Error deleting files and indexes: {err}", exc_info=True
             )
             return make_response(jsonify({"success": False}), 400)
-        sources_collection.delete_one({"_id": ObjectId(source_id)})
+        try:
+            with db_session() as conn:
+                SourcesRepository(conn).delete(resolved_id, user)
+        except Exception as err:
+            current_app.logger.error(
+                f"Error deleting source row: {err}", exc_info=True
+            )
+            return make_response(jsonify({"success": False}), 400)
         return make_response(jsonify({"success": True}), 200)
 
 
@@ -272,15 +261,16 @@ class ManageSync(Resource):
             return make_response(
                 jsonify({"success": False, "message": "Invalid frequency"}), 400
             )
-        update_data = {"$set": {"sync_frequency": sync_frequency}}
         try:
-            sources_collection.update_one(
-                {
-                    "_id": ObjectId(source_id),
-                    "user": user,
-                },
-                update_data,
-            )
+            with db_session() as conn:
+                repo = SourcesRepository(conn)
+                doc = repo.get_any(source_id, user)
+                if doc is None:
+                    return make_response(
+                        jsonify({"success": False, "message": "Source not found"}),
+                        404,
+                    )
+                repo.update(str(doc["id"]), user, {"sync_frequency": sync_frequency})
         except Exception as err:
             current_app.logger.error(
                 f"Error updating sync frequency: {err}", exc_info=True
@@ -309,19 +299,20 @@ class SyncSource(Resource):
         if missing_fields:
             return missing_fields
         source_id = data["source_id"]
-        if not ObjectId.is_valid(source_id):
+        try:
+            with db_readonly() as conn:
+                doc = SourcesRepository(conn).get_any(source_id, user)
+        except Exception as err:
+            current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
             return make_response(
                 jsonify({"success": False, "message": "Invalid source ID"}), 400
             )
-        doc = sources_collection.find_one(
-            {"_id": ObjectId(source_id), "user": user}
-        )
         if not doc:
             return make_response(
                 jsonify({"success": False, "message": "Source not found"}), 404
             )
         source_type = doc.get("type", "")
-        if source_type.startswith("connector"):
+        if source_type and source_type.startswith("connector"):
             return make_response(
                 jsonify(
                     {
@@ -344,7 +335,7 @@ class SyncSource(Resource):
                 loader=source_type,
                 sync_frequency=doc.get("sync_frequency", "never"),
                 retriever=doc.get("retriever", "classic"),
-                doc_id=source_id,
+                doc_id=str(doc["id"]),
             )
         except Exception as err:
             current_app.logger.error(
@@ -370,10 +361,9 @@ class DirectoryStructure(Resource):
 
         if not doc_id:
             return make_response(jsonify({"error": "Document ID is required"}), 400)
-        if not ObjectId.is_valid(doc_id):
-            return make_response(jsonify({"error": "Invalid document ID"}), 400)
         try:
-            doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
+            with db_readonly() as conn:
+                doc = SourcesRepository(conn).get_any(doc_id, user)
             if not doc:
                 return make_response(
                     jsonify({"error": "Document not found or access denied"}), 404
@@ -387,6 +377,8 @@ class DirectoryStructure(Resource):
                 if isinstance(remote_data, str) and remote_data:
                     remote_data_obj = json.loads(remote_data)
                     provider = remote_data_obj.get("provider")
+                elif isinstance(remote_data, dict):
+                    provider = remote_data.get("provider")
             except Exception as e:
                 current_app.logger.warning(
                     f"Failed to parse remote_data for doc {doc_id}: {e}"
@@ -406,4 +398,7 @@ class DirectoryStructure(Resource):
             current_app.logger.error(
                 f"Error retrieving directory structure: {e}", exc_info=True
             )
-            return make_response(jsonify({"success": False, "error": "Failed to retrieve directory structure"}), 500)
+            return make_response(
+                jsonify({"success": False, "error": "Failed to retrieve directory structure"}),
+                500,
+            )
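The pagination contract preserved in `PaginatedSources` can be checked in isolation: `totalPages` stays at 1 for an empty result set, and an out-of-range page clamps to the last page before the offset is computed. A quick standalone check of that arithmetic:

```python
import math

def page_window(total_documents: int, page: int, rows_per_page: int):
    total_pages = max(1, math.ceil(total_documents / rows_per_page))
    effective_page = min(max(1, page), total_pages)
    offset = (effective_page - 1) * rows_per_page
    return total_pages, effective_page, offset

assert page_window(0, 3, 10) == (1, 1, 0)     # empty set still reports one page
assert page_window(25, 99, 10) == (3, 3, 20)  # out-of-range page clamps to the last page
```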
@@ -5,16 +5,16 @@ import os
 import tempfile
 import zipfile
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
 
 from application.api import api
-from application.api.user.base import sources_collection
 from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
 from application.core.settings import settings
 from application.parser.connectors.connector_creator import ConnectorCreator
 from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
+from application.storage.db.repositories.sources import SourcesRepository
+from application.storage.db.session import db_readonly, db_session
 from application.storage.storage_creator import StorageCreator
 from application.stt.upload_limits import (
     AudioFileTooLargeError,
@@ -329,15 +329,8 @@ class ManageSourceFiles(Resource):
                 400,
             )
         try:
-            ObjectId(source_id)
-        except Exception:
-            return make_response(
-                jsonify({"success": False, "message": "Invalid source ID format"}), 400
-            )
-        try:
-            source = sources_collection.find_one(
-                {"_id": ObjectId(source_id), "user": user}
-            )
+            with db_readonly() as conn:
+                source = SourcesRepository(conn).get_any(source_id, user)
             if not source:
                 return make_response(
                     jsonify(
@@ -353,6 +346,7 @@ class ManageSourceFiles(Resource):
             return make_response(
                 jsonify({"success": False, "message": "Database error"}), 500
             )
+        resolved_source_id = str(source["id"])
         try:
             storage = StorageCreator.get_storage()
             source_file_path = source.get("file_path", "")
@@ -411,15 +405,18 @@ class ManageSourceFiles(Resource):
                     map_updated = True
 
             if map_updated:
-                sources_collection.update_one(
-                    {"_id": ObjectId(source_id)},
-                    {"$set": {"file_name_map": file_name_map}},
-                )
+                with db_session() as conn:
+                    SourcesRepository(conn).update(
+                        resolved_source_id, user,
+                        {"file_name_map": dict(file_name_map)},
+                    )
             # Trigger re-ingestion pipeline
 
             from application.api.user.tasks import reingest_source_task
 
-            task = reingest_source_task.delay(source_id=source_id, user=user)
+            task = reingest_source_task.delay(
+                source_id=resolved_source_id, user=user
+            )
 
             return make_response(
                 jsonify(
@@ -485,15 +482,18 @@ class ManageSourceFiles(Resource):
                     map_updated = True
 
             if map_updated and isinstance(file_name_map, dict):
-                sources_collection.update_one(
-                    {"_id": ObjectId(source_id)},
-                    {"$set": {"file_name_map": file_name_map}},
-                )
+                with db_session() as conn:
+                    SourcesRepository(conn).update(
+                        resolved_source_id, user,
+                        {"file_name_map": dict(file_name_map)},
+                    )
             # Trigger re-ingestion pipeline
 
             from application.api.user.tasks import reingest_source_task
 
-            task = reingest_source_task.delay(source_id=source_id, user=user)
+            task = reingest_source_task.delay(
+                source_id=resolved_source_id, user=user
+            )
 
             return make_response(
                 jsonify(
@@ -581,16 +581,19 @@ class ManageSourceFiles(Resource):
             if keys_to_remove:
                 for key in keys_to_remove:
                     file_name_map.pop(key, None)
-                sources_collection.update_one(
-                    {"_id": ObjectId(source_id)},
-                    {"$set": {"file_name_map": file_name_map}},
-                )
+                with db_session() as conn:
+                    SourcesRepository(conn).update(
+                        resolved_source_id, user,
+                        {"file_name_map": dict(file_name_map)},
+                    )
 
             # Trigger re-ingestion pipeline
 
             from application.api.user.tasks import reingest_source_task
 
-            task = reingest_source_task.delay(source_id=source_id, user=user)
+            task = reingest_source_task.delay(
+                source_id=resolved_source_id, user=user
+            )
 
             return make_response(
                 jsonify(
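All three branches of `ManageSourceFiles` above follow the same write-then-requeue shape: persist the updated `file_name_map` through the repository, then hand the source back to the re-ingestion pipeline. A condensed usage sketch (argument names are from the diff; the mapping payload is illustrative only):

```python
# Illustrative: persist the renamed-file map, then rebuild the vector index
# against the new names via the existing Celery task.
with db_session() as conn:
    SourcesRepository(conn).update(
        resolved_source_id,
        user,
        {"file_name_map": {"report_v1.pdf": "report_final.pdf"}},  # example mapping
    )

from application.api.user.tasks import reingest_source_task
task = reingest_source_task.delay(source_id=resolved_source_id, user=user)
```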
@@ -134,6 +134,12 @@ def setup_periodic_tasks(sender, **kwargs):
         timedelta(days=30),
         schedule_syncs.s("monthly"),
     )
+    # Replaces Mongo's TTL index on pending_tool_state.expires_at.
+    sender.add_periodic_task(
+        timedelta(seconds=60),
+        cleanup_pending_tool_state.s(),
+        name="cleanup-pending-tool-state",
+    )
 
 
 @celery.task(bind=True)
@@ -146,3 +152,27 @@ def mcp_oauth_task(self, config, user):
 def mcp_oauth_status_task(self, task_id):
     resp = mcp_oauth_status(self, task_id)
     return resp
+
+
+@celery.task(bind=True)
+def cleanup_pending_tool_state(self):
+    """Delete pending_tool_state rows past their TTL.
+
+    Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
+    no native TTL, so this task runs every 60 seconds to keep
+    ``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
+    configured (keeps the task runnable in Mongo-only environments).
+    """
+    from application.core.settings import settings
+
+    if not settings.POSTGRES_URI:
+        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
+
+    from application.storage.db.engine import get_engine
+    from application.storage.db.repositories.pending_tool_state import (
+        PendingToolStateRepository,
+    )
+
+    engine = get_engine()
+    with engine.begin() as conn:
+        deleted = PendingToolStateRepository(conn).cleanup_expired()
+    return {"deleted": deleted}
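The `cleanup_expired()` call above is the one piece of the TTL replacement not shown in this hunk. A plausible body, assuming the table keeps an `expires_at` timestamp column as the task docstring implies; the actual repository may differ:

```python
# Assumed sketch of PendingToolStateRepository.cleanup_expired(); table and
# column names mirror the task docstring, not a confirmed schema.
from sqlalchemy import text

class PendingToolStateRepository:
    def __init__(self, conn):
        self.conn = conn

    def cleanup_expired(self) -> int:
        result = self.conn.execute(
            text("DELETE FROM pending_tool_state WHERE expires_at <= NOW()")
        )
        return result.rowcount or 0
```

Running it every 60 seconds from Celery beat gives roughly the same bound on row lifetime that the Mongo TTL monitor provided.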
@@ -3,27 +3,24 @@
 import json
 from urllib.parse import urlencode, urlparse
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, redirect, request
 from flask_restx import Namespace, Resource, fields
 
 from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
 from application.api import api
-from application.api.user.base import user_tools_collection
 from application.api.user.tools.routes import transform_actions
 from application.cache import get_redis_instance
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
 from application.core.url_validation import SSRFError, validate_url
 from application.security.encryption import decrypt_credentials, encrypt_credentials
+from application.storage.db.repositories.connector_sessions import (
+    ConnectorSessionsRepository,
+)
+from application.storage.db.repositories.user_tools import UserToolsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields
 
 tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")
 
-_mongo = MongoDB.get_client()
-_db = _mongo[settings.MONGO_DB_NAME]
-_connector_sessions = _db["connector_sessions"]
-
 _ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
 
 
@@ -252,15 +249,18 @@ class MCPServerSave(Resource):
             storage_config = config.copy()
 
             tool_id = data.get("id")
+            existing_doc = None
             existing_encrypted = None
             if tool_id:
-                existing_doc = user_tools_collection.find_one(
-                    {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"}
-                )
-                if existing_doc:
-                    existing_encrypted = existing_doc.get("config", {}).get(
+                with db_readonly() as conn:
+                    repo = UserToolsRepository(conn)
+                    existing_doc = repo.get_any(tool_id, user)
+                if existing_doc and existing_doc.get("name") == "mcp_tool":
+                    existing_encrypted = (existing_doc.get("config") or {}).get(
                         "encrypted_credentials"
                     )
+                else:
+                    existing_doc = None
 
             if auth_credentials:
                 if existing_encrypted:
@@ -283,47 +283,88 @@ class MCPServerSave(Resource):
             ]:
                 storage_config.pop(field, None)
             transformed_actions = transform_actions(actions_metadata)
-            tool_data = {
-                "name": "mcp_tool",
-                "displayName": data["displayName"],
-                "customName": data["displayName"],
-                "description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
-                "config": storage_config,
-                "actions": transformed_actions,
-                "status": data.get("status", True),
-                "user": user,
-            }
-
-            if tool_id:
-                result = user_tools_collection.update_one(
-                    {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
-                    {"$set": {k: v for k, v in tool_data.items() if k != "user"}},
-                )
-                if result.matched_count == 0:
-                    return make_response(
-                        jsonify(
-                            {
-                                "success": False,
-                                "error": "Tool not found or access denied",
-                            }
-                        ),
-                        404,
-                    )
-                response_data = {
-                    "success": True,
-                    "id": tool_id,
-                    "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
-                    "tools_count": len(transformed_actions),
-                }
-            else:
-                result = user_tools_collection.insert_one(tool_data)
-                tool_id = str(result.inserted_id)
-                response_data = {
-                    "success": True,
-                    "id": tool_id,
-                    "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
-                    "tools_count": len(transformed_actions),
-                }
+            display_name = data["displayName"]
+            description = f"MCP Server: {storage_config.get('server_url', 'Unknown')}"
+            status_bool = bool(data.get("status", True))
+
+            with db_session() as conn:
+                repo = UserToolsRepository(conn)
+                if existing_doc:
+                    repo.update(
+                        str(existing_doc["id"]), user,
+                        {
+                            "display_name": display_name,
+                            "custom_name": display_name,
+                            "description": description,
+                            "config": storage_config,
+                            "actions": transformed_actions,
+                            "status": status_bool,
+                        },
+                    )
+                    saved_id = str(existing_doc["id"])
+                    response_data = {
+                        "success": True,
+                        "id": saved_id,
+                        "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
+                        "tools_count": len(transformed_actions),
+                    }
+                else:
+                    # Fall back to find_by_user_and_name — the original
+                    # dual-write path also ran an existence check before
+                    # deciding between insert and update.
+                    existing_by_name = repo.find_by_user_and_name(user, "mcp_tool")
+                    if tool_id is None and existing_by_name and (
+                        (existing_by_name.get("config") or {}).get("server_url")
+                        == storage_config.get("server_url")
+                    ):
+                        repo.update(
+                            str(existing_by_name["id"]), user,
+                            {
+                                "display_name": display_name,
+                                "custom_name": display_name,
+                                "description": description,
+                                "config": storage_config,
+                                "actions": transformed_actions,
+                                "status": status_bool,
+                            },
+                        )
+                        saved_id = str(existing_by_name["id"])
+                        response_data = {
+                            "success": True,
+                            "id": saved_id,
+                            "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
+                            "tools_count": len(transformed_actions),
+                        }
+                    else:
+                        created = repo.create(
+                            user, "mcp_tool",
+                            config=storage_config,
+                            custom_name=display_name,
+                            display_name=display_name,
+                            description=description,
+                            config_requirements={},
+                            actions=transformed_actions,
+                            status=status_bool,
+                        )
+                        saved_id = str(created["id"])
+                        response_data = {
+                            "success": True,
+                            "id": saved_id,
+                            "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
+                            "tools_count": len(transformed_actions),
+                        }
+            if tool_id and existing_doc is None:
+                # Client requested update on a non-existent tool id.
+                return make_response(
+                    jsonify(
+                        {
+                            "success": False,
+                            "error": "Tool not found or access denied",
+                        }
+                    ),
+                    404,
+                )
             return make_response(jsonify(response_data), 200)
         except ValueError as e:
             current_app.logger.warning(f"Invalid MCP server save request: {e}")
@@ -459,49 +500,59 @@ class MCPAuthStatus(Resource):
             return make_response(jsonify({"success": False}), 401)
         user = decoded_token.get("sub")
         try:
-            mcp_tools = list(
-                user_tools_collection.find(
-                    {"user": user, "name": "mcp_tool"},
-                    {"_id": 1, "config": 1},
-                )
-            )
-            if not mcp_tools:
-                return make_response(jsonify({"success": True, "statuses": {}}), 200)
-
-            oauth_server_urls = {}
-            statuses = {}
-            for tool in mcp_tools:
-                tool_id = str(tool["_id"])
-                config = tool.get("config", {})
-                auth_type = config.get("auth_type", "none")
-                if auth_type == "oauth":
-                    server_url = config.get("server_url", "")
-                    if server_url:
-                        parsed = urlparse(server_url)
-                        base_url = f"{parsed.scheme}://{parsed.netloc}"
-                        oauth_server_urls[tool_id] = base_url
-                    else:
-                        statuses[tool_id] = "needs_auth"
-                else:
-                    statuses[tool_id] = "configured"
-
-            if oauth_server_urls:
-                unique_urls = list(set(oauth_server_urls.values()))
-                sessions = list(
-                    _connector_sessions.find(
-                        {"user_id": user, "server_url": {"$in": unique_urls}},
-                        {"server_url": 1, "tokens": 1},
-                    )
-                )
-                url_has_tokens = {
-                    doc["server_url"]: bool(doc.get("tokens", {}).get("access_token"))
-                    for doc in sessions
-                }
-                for tool_id, base_url in oauth_server_urls.items():
-                    if url_has_tokens.get(base_url):
-                        statuses[tool_id] = "connected"
-                    else:
-                        statuses[tool_id] = "needs_auth"
+            with db_readonly() as conn:
+                tools_repo = UserToolsRepository(conn)
+                sessions_repo = ConnectorSessionsRepository(conn)
+                all_tools = tools_repo.list_for_user(user)
+                mcp_tools = [t for t in all_tools if t.get("name") == "mcp_tool"]
+                if not mcp_tools:
+                    return make_response(
+                        jsonify({"success": True, "statuses": {}}), 200
+                    )
+
+                oauth_server_urls: dict = {}
+                statuses: dict = {}
+                for tool in mcp_tools:
+                    tool_id = str(tool["id"])
+                    config = tool.get("config") or {}
+                    auth_type = config.get("auth_type", "none")
+                    if auth_type == "oauth":
+                        server_url = config.get("server_url", "")
+                        if server_url:
+                            parsed = urlparse(server_url)
+                            base_url = f"{parsed.scheme}://{parsed.netloc}"
+                            oauth_server_urls[tool_id] = base_url
+                        else:
+                            statuses[tool_id] = "needs_auth"
+                    else:
+                        statuses[tool_id] = "configured"
+
+                if oauth_server_urls:
+                    # Look up a session per distinct base URL. MCP sessions
+                    # are stored with ``provider = "mcp:<server_url>"``
+                    # and the URL in ``server_url``; reuse the repo's
+                    # per-URL accessor rather than an ad-hoc $in query.
+                    url_has_tokens: dict = {}
+                    for base_url in set(oauth_server_urls.values()):
+                        session = sessions_repo.get_by_user_and_server_url(
+                            user, base_url,
+                        )
+                        tokens = (
+                            (session or {}).get("session_data", {}) or {}
+                        ).get("tokens", {}) or {}
+                        # MCP code also stashes tokens into token_info on
+                        # the row; consider either present as "connected".
+                        token_info = (session or {}).get("token_info") or {}
+                        url_has_tokens[base_url] = bool(
+                            tokens.get("access_token")
+                            or token_info.get("access_token")
+                        )
+
+                    for tool_id, base_url in oauth_server_urls.items():
+                        if url_has_tokens.get(base_url):
+                            statuses[tool_id] = "connected"
+                        else:
+                            statuses[tool_id] = "needs_auth"
 
             return make_response(jsonify({"success": True, "statuses": statuses}), 200)
         except Exception as e:
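The connected/needs_auth decision above keys stored OAuth sessions by the server's scheme and host only. The same normalization in isolation, using nothing project-specific:

```python
from urllib.parse import urlparse

def base_url_of(server_url: str) -> str:
    parsed = urlparse(server_url)
    return f"{parsed.scheme}://{parsed.netloc}"

# Path, query, and fragment are dropped, so two tools pointing at different
# endpoints of the same MCP host share one OAuth session entry.
assert base_url_of("https://mcp.example.com/v1/sse?token=x") == "https://mcp.example.com"
```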
|||||||
@@ -1,21 +1,59 @@
|
|||||||
"""Tool management routes."""
|
"""Tool management routes."""
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import current_app, jsonify, make_response, request
|
from flask import current_app, jsonify, make_response, request
|
||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
from application.agents.tools.spec_parser import parse_spec
|
from application.agents.tools.spec_parser import parse_spec
|
||||||
from application.agents.tools.tool_manager import ToolManager
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import user_tools_collection
|
|
||||||
from application.core.url_validation import SSRFError, validate_url
|
from application.core.url_validation import SSRFError, validate_url
|
||||||
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
from application.security.encryption import decrypt_credentials, encrypt_credentials
|
||||||
|
from application.storage.db.repositories.notes import NotesRepository
|
||||||
|
from application.storage.db.repositories.todos import TodosRepository
|
||||||
|
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
from application.utils import check_required_fields, validate_function_name
|
from application.utils import check_required_fields, validate_function_name
|
||||||
|
|
||||||
tool_config = {}
|
tool_config = {}
|
||||||
tool_manager = ToolManager(config=tool_config)
|
tool_manager = ToolManager(config=tool_config)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shape translation helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# The frontend speaks camelCase (``displayName`` / ``customName`` /
|
||||||
|
# ``configRequirements``). The PG ``user_tools`` table stores snake_case
|
||||||
|
# (``display_name`` / ``custom_name`` / ``config_requirements``). Keep the
|
||||||
|
# translation localized to this module so repositories stay pure.
|
||||||
|
|
||||||
|
_CAMEL_TO_SNAKE = {
|
||||||
|
"displayName": "display_name",
|
||||||
|
"customName": "custom_name",
|
||||||
|
"configRequirements": "config_requirements",
|
||||||
|
}
|
||||||
|
_SNAKE_TO_CAMEL = {v: k for k, v in _CAMEL_TO_SNAKE.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_api(row: dict) -> dict:
|
||||||
|
"""Rename DB-native snake_case keys to the camelCase shape the frontend expects."""
|
||||||
|
out = dict(row)
|
||||||
|
for snake, camel in _SNAKE_TO_CAMEL.items():
|
||||||
|
if snake in out:
|
||||||
|
out[camel] = out.pop(snake)
|
||||||
|
# ``user_id`` is exposed as ``user`` in the legacy API shape.
|
||||||
|
if "user_id" in out:
|
||||||
|
out["user"] = out.pop("user_id")
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _api_to_update_fields(data: dict) -> dict:
|
||||||
|
"""Rename incoming camelCase update keys to the repo's snake_case columns."""
|
||||||
|
fields_out: dict = {}
|
||||||
|
for key, value in data.items():
|
||||||
|
fields_out[_CAMEL_TO_SNAKE.get(key, key)] = value
|
||||||
|
return fields_out
|
||||||
|
|
||||||
|
|
||||||
def _encrypt_secret_fields(config, config_requirements, user_id):
|
def _encrypt_secret_fields(config, config_requirements, user_id):
|
||||||
secret_keys = [
|
secret_keys = [
|
||||||
key for key, spec in config_requirements.items()
|
key for key, spec in config_requirements.items()
|
||||||
@@ -168,12 +206,11 @@ class GetTools(Resource):
|
|||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
tools = user_tools_collection.find({"user": user})
|
with db_readonly() as conn:
|
||||||
|
rows = UserToolsRepository(conn).list_for_user(user)
|
||||||
user_tools = []
|
user_tools = []
|
||||||
for tool in tools:
|
for row in rows:
|
||||||
tool_copy = {**tool}
|
tool_copy = _row_to_api(row)
|
||||||
tool_copy["id"] = str(tool["_id"])
|
|
||||||
tool_copy.pop("_id", None)
|
|
||||||
|
|
||||||
config_req = tool_copy.get("configRequirements", {})
|
config_req = tool_copy.get("configRequirements", {})
|
||||||
if not config_req:
|
if not config_req:
|
||||||
@@ -281,19 +318,19 @@ class CreateTool(Resource):
            storage_config = _encrypt_secret_fields(
                data["config"], config_requirements, user
            )
            with db_session() as conn:
                created = UserToolsRepository(conn).create(
                    user,
                    data["name"],
                    config=storage_config,
                    custom_name=data.get("customName", ""),
                    display_name=data["displayName"],
                    description=data["description"],
                    config_requirements=config_requirements,
                    actions=transformed_actions,
                    status=bool(data.get("status", True)),
                )
                new_id = str(created["id"])
        except Exception as err:
            current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
@@ -331,17 +368,10 @@ class UpdateTool(Resource):
        if missing_fields:
            return missing_fields
        try:
            update_data: dict = {}
            for key in ("name", "displayName", "customName", "description", "actions"):
                if key in data:
                    update_data[key] = data[key]
            if "config" in data:
                if "actions" in data["config"]:
                    for action_name in list(data["config"]["actions"].keys()):
@@ -356,46 +386,61 @@ class UpdateTool(Resource):
                                ),
                                400,
                            )
                with db_session() as conn:
                    repo = UserToolsRepository(conn)
                    tool_doc = repo.get_any(data["id"], user)
                    if not tool_doc:
                        return make_response(
                            jsonify({"success": False, "message": "Tool not found"}),
                            404,
                        )
                    tool_name = tool_doc.get("name", data.get("name"))
                    tool_instance = tool_manager.tools.get(tool_name)
                    config_requirements = (
                        tool_instance.get_config_requirements()
                        if tool_instance
                        else {}
                    )
                    existing_config = tool_doc.get("config", {}) or {}
                    has_existing_secrets = "encrypted_credentials" in existing_config

                    if config_requirements:
                        validation_errors = _validate_config(
                            data["config"], config_requirements,
                            has_existing_secrets=has_existing_secrets,
                        )
                        if validation_errors:
                            return make_response(
                                jsonify({
                                    "success": False,
                                    "message": "Validation failed",
                                    "errors": validation_errors,
                                }),
                                400,
                            )

                    update_data["config"] = _merge_secrets_on_update(
                        data["config"], existing_config, config_requirements, user
                    )
                    if "status" in data:
                        update_data["status"] = bool(data["status"])
                    repo.update(
                        str(tool_doc["id"]), user, _api_to_update_fields(update_data),
                    )
            else:
                if "status" in data:
                    update_data["status"] = bool(data["status"])
                with db_session() as conn:
                    repo = UserToolsRepository(conn)
                    tool_doc = repo.get_any(data["id"], user)
                    if not tool_doc:
                        return make_response(
                            jsonify({"success": False, "message": "Tool not found"}),
                            404,
                        )
                    repo.update(
                        str(tool_doc["id"]), user, _api_to_update_fields(update_data),
                    )
        except Exception as err:
            current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
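On the update path above, the camelCase keys collected in update_data are renamed only at the last moment, right before repo.update. A small sketch of that translation; the payload values are invented:

    # Hypothetical incoming update payload, already filtered to known keys.
    update_data = {"displayName": "API Tool v2", "customName": "mine", "status": True}
    assert _api_to_update_fields(update_data) == {
        "display_name": "API Tool v2",
        "custom_name": "mine",
        "status": True,  # not in the mapping, passed through unchanged
    }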
@@ -427,53 +472,50 @@ class UpdateToolConfig(Resource):
        if missing_fields:
            return missing_fields
        try:
            with db_session() as conn:
                repo = UserToolsRepository(conn)
                tool_doc = repo.get_any(data["id"], user)
                if not tool_doc:
                    return make_response(jsonify({"success": False}), 404)

                tool_name = tool_doc.get("name")
                if tool_name == "mcp_tool":
                    server_url = (data["config"].get("server_url") or "").strip()
                    if server_url:
                        try:
                            validate_url(server_url)
                        except SSRFError:
                            return make_response(
                                jsonify({"success": False, "message": "Invalid server URL"}),
                                400,
                            )
                tool_instance = tool_manager.tools.get(tool_name)
                config_requirements = (
                    tool_instance.get_config_requirements() if tool_instance else {}
                )
                existing_config = tool_doc.get("config", {}) or {}
                has_existing_secrets = "encrypted_credentials" in existing_config

                if config_requirements:
                    validation_errors = _validate_config(
                        data["config"], config_requirements,
                        has_existing_secrets=has_existing_secrets,
                    )
                    if validation_errors:
                        return make_response(
                            jsonify({
                                "success": False,
                                "message": "Validation failed",
                                "errors": validation_errors,
                            }),
                            400,
                        )

                final_config = _merge_secrets_on_update(
                    data["config"], existing_config, config_requirements, user
                )
                repo.update(str(tool_doc["id"]), user, {"config": final_config})
        except Exception as err:
            current_app.logger.error(
                f"Error updating tool config: {err}", exc_info=True
@@ -509,10 +551,17 @@ class UpdateToolActions(Resource):
        if missing_fields:
            return missing_fields
        try:
            with db_session() as conn:
                repo = UserToolsRepository(conn)
                tool_doc = repo.get_any(data["id"], user)
                if not tool_doc:
                    return make_response(
                        jsonify({"success": False, "message": "Tool not found"}),
                        404,
                    )
                repo.update(
                    str(tool_doc["id"]), user, {"actions": data["actions"]},
                )
        except Exception as err:
            current_app.logger.error(
                f"Error updating tool actions: {err}", exc_info=True
@@ -546,10 +595,17 @@ class UpdateToolStatus(Resource):
        if missing_fields:
            return missing_fields
        try:
            with db_session() as conn:
                repo = UserToolsRepository(conn)
                tool_doc = repo.get_any(data["id"], user)
                if not tool_doc:
                    return make_response(
                        jsonify({"success": False, "message": "Tool not found"}),
                        404,
                    )
                repo.update(
                    str(tool_doc["id"]), user, {"status": bool(data["status"])},
                )
        except Exception as err:
            current_app.logger.error(
                f"Error updating tool status: {err}", exc_info=True
@@ -578,13 +634,14 @@ class DeleteTool(Resource):
        if missing_fields:
            return missing_fields
        try:
            with db_session() as conn:
                repo = UserToolsRepository(conn)
                tool_doc = repo.get_any(data["id"], user)
                if not tool_doc:
                    return make_response(
                        jsonify({"success": False, "message": "Tool not found"}), 404
                    )
                repo.delete(str(tool_doc["id"]), user)
        except Exception as err:
            current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
            return make_response(jsonify({"success": False}), 400)
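UpdateToolActions, UpdateToolStatus and DeleteTool above all repeat the same fetch-then-act ownership check. A stripped-down sketch of that shared shape, assuming only the repository interface the diff already uses (get_any / update scoped by user); this helper is illustrative and not part of the codebase:

    def _update_owned_tool(tool_id, user, fields):
        # Hypothetical helper: look the row up under the caller's user id,
        # return 404 if it is not theirs, otherwise apply the update.
        with db_session() as conn:
            repo = UserToolsRepository(conn)
            tool_doc = repo.get_any(tool_id, user)
            if not tool_doc:
                return make_response(
                    jsonify({"success": False, "message": "Tool not found"}), 404
                )
            repo.update(str(tool_doc["id"]), user, fields)
            return make_response(jsonify({"success": True}), 200)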
@@ -653,70 +710,88 @@ class GetArtifact(Resource):
        user_id = decoded_token.get("sub")

        try:
            with db_readonly() as conn:
                notes_repo = NotesRepository(conn)
                todos_repo = TodosRepository(conn)

                # Artifact IDs may be PG UUIDs (post-cutover) or legacy
                # Mongo ObjectIds embedded in older conversation history.
                # Both repos' ``get_any`` handles the id-shape branching
                # internally so a non-UUID input never reaches
                # ``CAST(:id AS uuid)`` (which would poison the readonly
                # transaction and break the fallback below).
                note_doc = notes_repo.get_any(artifact_id, user_id)

                if note_doc:
                    content = note_doc.get("note", "") or note_doc.get("content", "")
                    line_count = len(content.split("\n")) if content else 0
                    updated = note_doc.get("updated_at")
                    artifact = {
                        "artifact_type": "note",
                        "data": {
                            "content": content,
                            "line_count": line_count,
                            "updated_at": (
                                updated.isoformat()
                                if hasattr(updated, "isoformat")
                                else updated
                            ),
                        },
                    }
                    return make_response(
                        jsonify({"success": True, "artifact": artifact}), 200
                    )

                todo_doc = todos_repo.get_any(artifact_id, user_id)
                if todo_doc:
                    tool_id = todo_doc.get("tool_id")
                    all_todos = todos_repo.list_for_tool(user_id, tool_id) if tool_id else []
                    items = []
                    open_count = 0
                    completed_count = 0
                    for t in all_todos:
                        # PG ``todos`` stores a ``completed BOOLEAN`` column;
                        # the legacy Mongo shape used a ``status`` string.
                        # Keep the response shape stable by translating here.
                        status = "completed" if t.get("completed") else "open"
                        if status == "open":
                            open_count += 1
                        else:
                            completed_count += 1
                        created = t.get("created_at")
                        updated = t.get("updated_at")
                        items.append({
                            "todo_id": t.get("todo_id"),
                            "title": t.get("title", ""),
                            "status": status,
                            "created_at": (
                                created.isoformat()
                                if hasattr(created, "isoformat")
                                else created
                            ),
                            "updated_at": (
                                updated.isoformat()
                                if hasattr(updated, "isoformat")
                                else updated
                            ),
                        })
                    artifact = {
                        "artifact_type": "todo_list",
                        "data": {
                            "items": items,
                            "total_count": len(items),
                            "open_count": open_count,
                            "completed_count": completed_count,
                        },
                    }
                    return make_response(
                        jsonify({"success": True, "artifact": artifact}), 200
                    )
        except Exception as err:
            current_app.logger.error(
                f"Error retrieving artifact: {err}", exc_info=True
            )
            return make_response(jsonify({"success": False}), 400)

        return make_response(
            jsonify({"success": False, "message": "Artifact not found"}), 404
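The id-shape comment in the hunk above is the key constraint: a raw CAST(:id AS uuid) on a legacy ObjectId string would abort the read-only transaction and break the note-then-todo fallback. The repository internals are not shown in this diff, so the sketch below is only an assumed shape of that branching, modelled on the _resolve_workflow helper that appears later in this compare:

    import uuid

    def looks_like_uuid_sketch(value: str) -> bool:
        # Assumed behaviour: anything parseable as a UUID counts.
        try:
            uuid.UUID(value)
            return True
        except (TypeError, ValueError, AttributeError):
            return False

    def get_any_sketch(repo, artifact_id, user_id):
        # Illustrative only: try the native UUID column first, then fall back
        # to a legacy-id column, so the uuid cast never sees a non-UUID string.
        if looks_like_uuid_sketch(artifact_id):
            row = repo.get(artifact_id, user_id)
            if row is not None:
                return row
        return repo.get_by_legacy_id(artifact_id, user_id)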
@@ -1,290 +1,61 @@
"""Centralized utilities for API routes.

Post-Mongo-cutover slim: the old Mongo-shaped helpers (``validate_object_id``,
``check_resource_ownership``, ``paginated_response``, ``serialize_object_id``,
``safe_db_operation``, ``validate_enum``, ``extract_sort_params``) have been
removed — they carried ``bson`` / ``pymongo`` imports and had zero callers.
"""

from functools import wraps
from typing import Callable, Optional

from flask import (
    Response,
    jsonify,
    make_response,
    request,
)


def get_user_id() -> Optional[str]:
    """Extract user ID from decoded JWT token, or None if unauthenticated."""
    decoded_token = getattr(request, "decoded_token", None)
    return decoded_token.get("sub") if decoded_token else None


def require_auth(func: Callable) -> Callable:
    """Decorator to require authentication. Returns 401 when absent."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        user_id = get_user_id()
        if not user_id:
            return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
        return func(*args, **kwargs)

    return wrapper


def success_response(
    data=None, message: Optional[str] = None, status: int = 200
) -> Response:
    """Shape a successful JSON response."""
    body = {"success": True}
    if data is not None:
        body["data"] = data
    if message is not None:
        body["message"] = message
    return make_response(jsonify(body), status)


def error_response(message: str, status: int = 400, **kwargs) -> Response:
    """Shape an error JSON response; any kwargs are merged into the body."""
    body = {"success": False, "error": message, **kwargs}
    return make_response(jsonify(body), status)


def require_fields(required: list) -> Callable:
    """Decorator: return 400 if any listed field is missing/falsy in the JSON body."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
@@ -294,94 +65,11 @@ def require_fields(required: List[str]) -> Callable:
                return error_response("Request body required")
            missing = [field for field in required if not data.get(field)]
            if missing:
                return error_response(
                    f"Missing required fields: {', '.join(missing)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
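The slimmed helpers above also change the response envelopes: success_response now nests its payload under "data" (and can carry an optional message) instead of merging keys into the top level, and error_response reports the message under "error" rather than "message". A short usage sketch, assuming a Flask request context and the decorators defined above:

    @require_auth
    @require_fields(["name", "description"])
    def post(self):
        data = request.get_json()
        # success_response({"id": "123"}, status=201) now yields:
        #   {"success": True, "data": {"id": "123"}}
        # error_response("Workflow not found", 404) yields:
        #   {"success": False, "error": "Workflow not found"}
        return success_response({"id": "123"}, status=201)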
@@ -1,30 +1,26 @@
"""Workflow management routes."""

from typing import Any, Dict, List, Optional, Set

from flask import current_app, request
from flask_restx import Namespace, Resource

from application.storage.db.base_repository import looks_like_uuid
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
from application.storage.db.repositories.workflows import WorkflowsRepository
from application.storage.db.session import db_readonly, db_session
from application.core.json_schema_utils import (
    JsonSchemaValidationError,
    normalize_json_schema_payload,
)
from application.core.model_utils import get_model_capabilities
from application.api.user.utils import (
    error_response,
    get_user_id,
    require_auth,
    require_fields,
    success_response,
)


workflows_ns = Namespace("workflows", path="/api")
@@ -35,33 +31,112 @@ def _workflow_error_response(message: str, err: Exception):
    return error_response(message)


def _resolve_workflow(repo: WorkflowsRepository, workflow_id: str, user_id: str):
    """Resolve a workflow by UUID or legacy Mongo id, scoped to user."""
    if not workflow_id:
        return None
    if looks_like_uuid(workflow_id):
        row = repo.get(workflow_id, user_id)
        if row is not None:
            return row
    return repo.get_by_legacy_id(workflow_id, user_id)


def _write_graph(
    conn,
    pg_workflow_id: str,
    graph_version: int,
    nodes_data: List[Dict],
    edges_data: List[Dict],
) -> List[Dict]:
    """Bulk-create nodes + edges for one graph version. Uses ON CONFLICT upsert.

    Edges arrive with source/target as user-provided node-id strings. We
    insert nodes first, capture their ``node_id → UUID`` map, then
    translate edges before insertion. Edges referencing missing nodes are
    dropped with a warning.
    """
    nodes_repo = WorkflowNodesRepository(conn)
    edges_repo = WorkflowEdgesRepository(conn)

    if nodes_data:
        created_nodes = nodes_repo.bulk_create(
            pg_workflow_id, graph_version,
            [
                {
                    "node_id": n["id"],
                    "node_type": n["type"],
                    "title": n.get("title", ""),
                    "description": n.get("description", ""),
                    "position": n.get("position", {"x": 0, "y": 0}),
                    "config": n.get("data", {}),
                }
                for n in nodes_data
            ],
        )
        node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
    else:
        created_nodes = []
        node_uuid_by_str = {}

    if edges_data:
        translated_edges: List[Dict] = []
        for e in edges_data:
            src = e.get("source")
            tgt = e.get("target")
            from_uuid = node_uuid_by_str.get(src)
            to_uuid = node_uuid_by_str.get(tgt)
            if not from_uuid or not to_uuid:
                current_app.logger.warning(
                    "Workflow graph write: dropping edge %s; node refs unresolved "
                    "(source=%s, target=%s)",
                    e.get("id"), src, tgt,
                )
                continue
            translated_edges.append({
                "edge_id": e["id"],
                "from_node_id": from_uuid,
                "to_node_id": to_uuid,
                "source_handle": e.get("sourceHandle"),
                "target_handle": e.get("targetHandle"),
            })
        if translated_edges:
            edges_repo.bulk_create(
                pg_workflow_id, graph_version, translated_edges,
            )

    return created_nodes


def serialize_workflow(w: Dict) -> Dict:
    """Serialize workflow row to API response format."""
    created_at = w.get("created_at")
    updated_at = w.get("updated_at")
    return {
        "id": str(w["id"]),
        "name": w.get("name"),
        "description": w.get("description"),
        "created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
        "updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
    }


def serialize_node(n: Dict) -> Dict:
    """Serialize workflow node row to API response format."""
    return {
        "id": n["node_id"],
        "type": n["node_type"],
        "title": n.get("title"),
        "description": n.get("description"),
        "position": n.get("position"),
        "data": n.get("config", {}) or {},
    }


def serialize_edge(e: Dict) -> Dict:
    """Serialize workflow edge row to API response format."""
    return {
        "id": e["edge_id"],
        "source": e.get("source_id"),
        "target": e.get("target_id"),
        "sourceHandle": e.get("source_handle"),

@@ -70,7 +145,7 @@ def serialize_edge(e: Dict) -> Dict:


def get_workflow_graph_version(workflow: Dict) -> int:
    """Get current graph version with fallback."""
    raw_version = workflow.get("current_graph_version", 1)
    try:
        version = int(raw_version)

@@ -79,22 +154,6 @@ def get_workflow_graph_version(workflow: Dict) -> int:
        return 1


def validate_json_schema_payload(
    json_schema: Any,
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:

@@ -315,49 +374,6 @@ def _can_reach_end(
    return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)


@workflows_ns.route("/workflows")
class WorkflowList(Resource):

@@ -369,6 +385,7 @@ class WorkflowList(Resource):
        data = request.get_json()

        name = data.get("name", "").strip()
        description = data.get("description", "")
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])

@@ -379,35 +396,16 @@ class WorkflowList(Resource):
            )
        nodes_data = normalize_agent_node_json_schemas(nodes_data)

        try:
            with db_session() as conn:
                repo = WorkflowsRepository(conn)
                workflow = repo.create(user_id, name, description=description)
                pg_workflow_id = str(workflow["id"])
                _write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
        except Exception as err:
            return _workflow_error_response("Failed to create workflow", err)

        return success_response({"id": pg_workflow_id}, 201)


@workflows_ns.route("/workflows/<string:workflow_id>")

@@ -417,23 +415,22 @@ class WorkflowDetail(Resource):
    def get(self, workflow_id: str):
        """Get workflow details with nodes and edges."""
        user_id = get_user_id()
        try:
            with db_readonly() as conn:
                repo = WorkflowsRepository(conn)
                workflow = _resolve_workflow(repo, workflow_id, user_id)
                if workflow is None:
                    return error_response("Workflow not found", 404)
                pg_workflow_id = str(workflow["id"])
                graph_version = get_workflow_graph_version(workflow)
                nodes = WorkflowNodesRepository(conn).find_by_version(
                    pg_workflow_id, graph_version,
                )
                edges = WorkflowEdgesRepository(conn).find_by_version(
                    pg_workflow_id, graph_version,
                )
        except Exception as err:
            return _workflow_error_response("Failed to fetch workflow", err)

        return success_response(
            {

@@ -448,18 +445,9 @@ class WorkflowDetail(Resource):
    def put(self, workflow_id: str):
        """Update workflow and replace nodes/edges."""
        user_id = get_user_id()
        data = request.get_json()
        name = data.get("name", "").strip()
        description = data.get("description", "")
        nodes_data = data.get("nodes", [])
        edges_data = data.get("edges", [])

@@ -470,55 +458,36 @@ class WorkflowDetail(Resource):
            )
        nodes_data = normalize_agent_node_json_schemas(nodes_data)

        try:
            with db_session() as conn:
                repo = WorkflowsRepository(conn)
                workflow = _resolve_workflow(repo, workflow_id, user_id)
                if workflow is None:
                    return error_response("Workflow not found", 404)
                pg_workflow_id = str(workflow["id"])
                current_graph_version = get_workflow_graph_version(workflow)
                next_graph_version = current_graph_version + 1

                _write_graph(
                    conn, pg_workflow_id, next_graph_version,
                    nodes_data, edges_data,
                )
                repo.update(
                    pg_workflow_id, user_id,
                    {
                        "name": name,
                        "description": description,
                        "current_graph_version": next_graph_version,
                    },
                )
                WorkflowNodesRepository(conn).delete_other_versions(
                    pg_workflow_id, next_graph_version,
                )
                WorkflowEdgesRepository(conn).delete_other_versions(
                    pg_workflow_id, next_graph_version,
                )
        except Exception as err:
            return _workflow_error_response("Failed to update workflow", err)

        return success_response()

@@ -526,20 +495,14 @@ class WorkflowDetail(Resource):
    def delete(self, workflow_id: str):
        """Delete workflow and its graph."""
        user_id = get_user_id()
        try:
            with db_session() as conn:
                repo = WorkflowsRepository(conn)
                workflow = _resolve_workflow(repo, workflow_id, user_id)
                if workflow is None:
                    return error_response("Workflow not found", 404)
                # ON DELETE CASCADE on workflow_nodes/edges cleans children.
                repo.delete(str(workflow["id"]), user_id)
        except Exception as err:
            return _workflow_error_response("Failed to delete workflow", err)
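The edge translation in _write_graph is the subtle step: the frontend references nodes by its own string ids, while the PG edges table stores the generated node UUIDs. A toy illustration of that mapping, with invented values:

    # node_uuid_by_str as built from the bulk_create return value (sample values).
    node_uuid_by_str = {
        "start-1": "aaaaaaaa-1111-1111-1111-111111111111",
        "agent-1": "bbbbbbbb-2222-2222-2222-222222222222",
    }
    edge = {"id": "e1", "source": "start-1", "target": "agent-1",
            "sourceHandle": None, "targetHandle": None}
    translated = {
        "edge_id": edge["id"],
        "from_node_id": node_uuid_by_str[edge["source"]],
        "to_node_id": node_uuid_by_str[edge["target"]],
        "source_handle": edge.get("sourceHandle"),
        "target_handle": edge.get("targetHandle"),
    }
    # An edge whose source or target is missing from the map is dropped with a warning.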
@@ -20,8 +20,8 @@ from application.api.v1.translator import (
    translate_response,
    translate_stream_event,
)
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly

logger = logging.getLogger(__name__)

@@ -39,9 +39,8 @@ def _extract_bearer_token() -> Optional[str]:
def _lookup_agent(api_key: str) -> Optional[Dict]:
    """Look up the agent document for this API key."""
    try:
        with db_readonly() as conn:
            return AgentsRepository(conn).find_by_key(api_key)
    except Exception:
        logger.warning("Failed to look up agent for API key", exc_info=True)
        return None

@@ -90,8 +89,14 @@ def chat_completions():
        )

    # Link decoded_token to the agent's owner so continuation state,
    # logs, and tool execution use the correct user identity. The PG
    # ``agents`` row exposes the owner via ``user_id`` (``user`` is the
    # legacy Mongo field name kept in ``row_to_dict`` only for the
    # mapping ``id``/``_id``).
    agent_user = (
        (agent_doc.get("user_id") or agent_doc.get("user"))
        if agent_doc else None
    )
    decoded_token = {"sub": agent_user or "api_key_user"}

    try:

@@ -138,10 +143,18 @@ def chat_completions():
        if usage_error:
            return usage_error

        should_save_conversation = bool(internal_data.get("save_conversation", False))

        if is_stream:
            return Response(
                _stream_response(
                    helper,
                    question,
                    agent,
                    processor,
                    model_name,
                    continuation,
                    should_save_conversation,
                ),
                mimetype="text/event-stream",
                headers={

@@ -151,7 +164,13 @@ def chat_completions():
            )
        else:
            return _non_stream_response(
                helper,
                question,
                agent,
                processor,
                model_name,
                continuation,
                should_save_conversation,
            )

    except ValueError as e:

@@ -181,6 +200,7 @@ def _stream_response(
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
    should_save_conversation: bool,
) -> Generator[str, None, None]:
    """Generate translated SSE chunks for streaming response."""
    completion_id = f"chatcmpl-{int(time.time())}"

@@ -193,6 +213,7 @@ def _stream_response(
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )

@@ -225,6 +246,7 @@ def _non_stream_response(
    processor: StreamProcessor,
    model_name: str,
    continuation: Optional[Dict],
    should_save_conversation: bool,
) -> Response:
    """Collect full response and return as single JSON."""
    stream = helper.complete_stream(

@@ -235,6 +257,7 @@ def _non_stream_response(
        decoded_token=processor.decoded_token,
        agent_id=processor.agent_id,
        model_id=processor.model_id,
        should_save_conversation=should_save_conversation,
        _continuation=continuation,
    )

@@ -272,39 +295,32 @@ def list_models():
        )

    try:
        with db_readonly() as conn:
            agents_repo = AgentsRepository(conn)
            agent = agents_repo.find_by_key(api_key)
            if not agent:
                return make_response(
                    jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
                    401,
                )

            created = agent.get("created_at") or agent.get("createdAt")
            created_ts = (
                int(created.timestamp()) if hasattr(created, "timestamp")
                else int(time.time())
            )
            model_id = str(agent.get("id") or agent.get("_id") or "")
            model = {
                "id": model_id,
                "object": "model",
                "created": created_ts,
                "owned_by": "docsgpt",
                "name": agent.get("name", ""),
                "description": agent.get("description", ""),
            }

        return make_response(
            jsonify({"object": "list", "data": [model]}),
            200,
        )
    except Exception as e:
@@ -80,6 +80,17 @@ def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
     return None


+def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
+    """Extract the first system message content from the messages array.
+
+    Returns None if no system message is present.
+    """
+    for msg in messages:
+        if msg.get("role") == "system":
+            return msg.get("content", "")
+    return None
+
+
 def convert_history(messages: List[Dict]) -> List[Dict]:
     """Convert chat completions messages array to DocsGPT history format.

@@ -148,20 +159,27 @@ def translate_request(
             break

     history = convert_history(messages)
+    system_prompt_override = extract_system_prompt(messages)
+
+    docsgpt = data.get("docsgpt", {})

     result = {
         "question": question,
         "api_key": api_key,
         "history": json.dumps(history),
-        "save_conversation": True,
+        # Conversations are NOT persisted by default on the v1 endpoint.
+        # Callers opt in via ``docsgpt.save_conversation: true``.
+        "save_conversation": bool(docsgpt.get("save_conversation", False)),
     }

+    if system_prompt_override is not None:
+        result["system_prompt_override"] = system_prompt_override
+
     # Client tools
     if data.get("tools"):
         result["client_tools"] = data["tools"]

     # DocsGPT extensions
-    docsgpt = data.get("docsgpt", {})
     if docsgpt.get("attachments"):
         result["attachments"] = docsgpt["attachments"]

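For orientation, a minimal client-side sketch of the opt-in persistence flow these hunks describe. Only the ``docsgpt.save_conversation`` field, the ``tools``/``messages`` shape, and Bearer agent-key auth come from the diff itself; the exact URL, port, and use of the agent id as ``model`` are assumptions, not confirmed here.

# Hypothetical caller sketch, not part of the change set.
import requests

payload = {
    "model": "<agent-id>",  # assumption: agents are exposed as models
    "messages": [
        {"role": "system", "content": "Answer only from the indexed docs."},
        {"role": "user", "content": "How do I set POSTGRES_URI?"},
    ],
    "stream": False,
    # Persistence is opt-in on the v1 endpoint per the hunk above.
    "docsgpt": {"save_conversation": True},
}
resp = requests.post(
    "http://localhost:7091/v1/chat/completions",  # host/port assumed
    headers={"Authorization": "Bearer <agent-api-key>"},
    json=payload,
    timeout=60,
)
print(resp.json())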
@@ -1,3 +1,4 @@
+import logging
 import os
 import platform
 import uuid
@@ -20,6 +21,7 @@ from application.api.connector.routes import connector # noqa: E402
 from application.api.v1 import v1_bp # noqa: E402
 from application.celery_init import celery # noqa: E402
 from application.core.settings import settings # noqa: E402
+from application.storage.db.bootstrap import ensure_database_ready # noqa: E402
 from application.stt.upload_limits import ( # noqa: E402
     build_stt_file_size_limit_message,
     should_reject_stt_request,
@@ -32,6 +34,17 @@ if platform.system() == "Windows":
     pathlib.PosixPath = pathlib.WindowsPath
 dotenv.load_dotenv()
+
+# Self-bootstrap the user-data Postgres DB. Runs before any blueprint or
+# repository touches the engine, so the first request can't race the
+# schema being created. Gated by AUTO_CREATE_DB / AUTO_MIGRATE settings
+# (default ON for dev; disable in prod if schema is managed out-of-band).
+ensure_database_ready(
+    settings.POSTGRES_URI,
+    create_db=settings.AUTO_CREATE_DB,
+    migrate=settings.AUTO_MIGRATE,
+    logger=logging.getLogger("application.app"),
+)

 app = Flask(__name__)
 app.register_blueprint(user)
 app.register_blueprint(answer)
@@ -120,6 +133,12 @@ def enforce_stt_request_size_limits():
 def authenticate_request():
     if request.method == "OPTIONS":
         return "", 200
+    # OpenAI-compatible routes authenticate via opaque agent API keys in the
+    # Authorization header, which the JWT decoder below would reject. Defer
+    # auth to the route handlers (see application/api/v1/routes.py).
+    if request.path.startswith("/v1/"):
+        request.decoded_token = None
+        return None
     decoded_token = handle_auth(request)
     if not decoded_token:
         request.decoded_token = None
@@ -1,6 +1,6 @@
 from celery import Celery
 from application.core.settings import settings
-from celery.signals import setup_logging
+from celery.signals import setup_logging, worker_process_init


 def make_celery(app_name=__name__):
@@ -20,5 +20,24 @@ def config_loggers(*args, **kwargs):
     setup_logging()


+@worker_process_init.connect
+def _dispose_db_engine_on_fork(*args, **kwargs):
+    """Dispose the SQLAlchemy engine pool in each forked Celery worker.
+
+    SQLAlchemy connection pools are not fork-safe: file descriptors shared
+    between the parent and a forked worker will corrupt the pool. Disposing
+    on ``worker_process_init`` gives every worker its own fresh pool on
+    first use.
+
+    Imported lazily so Celery workers that don't touch Postgres (or where
+    ``POSTGRES_URI`` is unset) don't fail at startup.
+    """
+    try:
+        from application.storage.db.engine import dispose_engine
+    except Exception:
+        return
+    dispose_engine()
+
+
 celery = make_celery()
 celery.config_from_object("application.celeryconfig")
application/core/db_uri.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""Normalize user-supplied Postgres URIs for different drivers.

DocsGPT has two Postgres connection strings pointing at potentially
different databases:

* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
  ``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
  (via libpq) in ``application/vectorstore/pgvector.py``. libpq only
  understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
  dialect prefix is an invalid URI from its point of view.

The two fields therefore need opposite normalization so operators don't
have to know which driver a given field feeds. Each normalizer also
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
psycopg2 is no longer in the project.

This module is deliberately separate from ``application/core/settings.py``
so the Settings class stays focused on field declarations, and the
URI-rewriting logic can be unit-tested without triggering ``.env``
file loading from importing Settings.
"""

from __future__ import annotations


def _rewrite_uri_prefixes(v, rewrites):
    """Shared URI prefix rewriter used by both normalizers below.

    Strips whitespace, returns ``None`` for empty / ``"none"`` values,
    applies the first matching rewrite, and passes unrecognised input
    through so downstream consumers (SQLAlchemy, libpq) can produce
    their own error messages rather than us silently eating a
    misconfiguration.
    """
    if v is None:
        return None
    if not isinstance(v, str):
        return v
    v = v.strip()
    if not v or v.lower() == "none":
        return None
    for prefix, target in rewrites:
        if v.startswith(prefix):
            return target + v[len(prefix):]
    return v


# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
# dialect prefix to select the psycopg v3 driver. Normalize the
# operator-friendly forms TOWARD that dialect.
_POSTGRES_URI_REWRITES = (
    ("postgresql+psycopg2://", "postgresql+psycopg://"),
    ("postgresql://", "postgresql+psycopg://"),
    ("postgres://", "postgresql+psycopg://"),
)


# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
# dialect prefix is an invalid URI from libpq's point of view. Strip it
# if the operator accidentally copied their POSTGRES_URI value here.
_PGVECTOR_CONNECTION_STRING_REWRITES = (
    ("postgresql+psycopg2://", "postgresql://"),
    ("postgresql+psycopg://", "postgresql://"),
)


def normalize_postgres_uri(v):
    """Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.

    Accepts the forms operators naturally write (``postgres://``,
    ``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
    Unknown schemes pass through unchanged so SQLAlchemy can produce its
    own dialect-not-found error.
    """
    return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)


def normalize_pgvector_connection_string(v):
    """Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.

    Strips the SQLAlchemy dialect prefix if the operator accidentally
    copied their POSTGRES_URI value here — libpq can't parse it.
    User-friendly forms (``postgres://``, ``postgresql://``) pass
    through unchanged since libpq accepts them natively.
    """
    return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)
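A quick usage sketch of the two normalizers above; this is illustrative only, and the expected values follow directly from the rewrite tables in this module.

# Illustrative only; mirrors the rewrite tables in db_uri.py above.
from application.core.db_uri import (
    normalize_pgvector_connection_string,
    normalize_postgres_uri,
)

# SQLAlchemy side: operator-friendly schemes gain the psycopg3 dialect.
assert normalize_postgres_uri("postgres://u:p@db:5432/docsgpt") == \
    "postgresql+psycopg://u:p@db:5432/docsgpt"
assert normalize_postgres_uri("  none  ") is None  # empty / "none" -> unset

# libpq side: an accidentally copied SQLAlchemy URI loses its dialect prefix.
assert normalize_pgvector_connection_string(
    "postgresql+psycopg://u:p@db:5432/docsgpt"
) == "postgresql://u:p@db:5432/docsgpt"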
@@ -1,24 +0,0 @@
-from application.core.settings import settings
-from pymongo import MongoClient
-
-
-class MongoDB:
-    _client = None
-
-    @classmethod
-    def get_client(cls):
-        """
-        Get the MongoDB client instance, creating it if necessary.
-        """
-        if cls._client is None:
-            cls._client = MongoClient(settings.MONGO_URI)
-        return cls._client
-
-    @classmethod
-    def close_client(cls):
-        """
-        Close the MongoDB client connection.
-        """
-        if cls._client is not None:
-            cls._client.close()
-            cls._client = None
@@ -8,6 +8,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
 current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


+from application.core.db_uri import ( # noqa: E402
+    normalize_pgvector_connection_string,
+    normalize_postgres_uri,
+)
+
+
 class Settings(BaseSettings):
     model_config = SettingsConfigDict(extra="ignore")

@@ -20,8 +26,14 @@ class Settings(BaseSettings):

     CELERY_BROKER_URL: str = "redis://localhost:6379/0"
     CELERY_RESULT_BACKEND: str = "redis://localhost:6379/1"
-    MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
-    MONGO_DB_NAME: str = "docsgpt"
+    # Only consulted when VECTOR_STORE=mongodb or when running scripts/db/backfill.py; user data lives in Postgres.
+    MONGO_URI: Optional[str] = None
+    # User-data Postgres DB.
+    POSTGRES_URI: Optional[str] = None
+    # On app startup, apply pending Alembic migrations. Default ON for dev; disable in prod if you manage schema out-of-band.
+    AUTO_MIGRATE: bool = True
+    # On app startup, create the target Postgres database if it's missing (requires CREATEDB privilege). Dev-friendly default.
+    AUTO_CREATE_DB: bool = True
     LLM_PATH: str = os.path.join(current_dir, "models/docsgpt-7b-f16.gguf")
     DEFAULT_MAX_HISTORY: int = 150
     DEFAULT_LLM_TOKEN_LIMIT: int = 128000 # Fallback when model not found in registry
@@ -59,6 +71,10 @@ class Settings(BaseSettings):
     MICROSOFT_TENANT_ID: Optional[str] = "common" # Azure AD Tenant ID (or 'common' for multi-tenant)
     MICROSOFT_AUTHORITY: Optional[str] = None # e.g., "https://login.microsoftonline.com/{tenant_id}"
+
+    # Confluence Cloud integration
+    CONFLUENCE_CLIENT_ID: Optional[str] = None
+    CONFLUENCE_CLIENT_SECRET: Optional[str] = None

     # GitHub source
     GITHUB_ACCESS_TOKEN: Optional[str] = None # PAT token with read repo access

@@ -117,7 +133,10 @@ class Settings(BaseSettings):
     QDRANT_PATH: Optional[str] = None
     QDRANT_DISTANCE_FUNC: str = "Cosine"

-    # PGVector vectorstore config
+    # PGVector vectorstore config. Write the URI in whichever form you
+    # prefer — ``postgres://``, ``postgresql://``, or even the SQLAlchemy
+    # dialect form (``postgresql+psycopg://``) are all accepted and
+    # normalized internally for ``psycopg.connect()``.
     PGVECTOR_CONNECTION_STRING: Optional[str] = None
     # Milvus vectorstore config
     MILVUS_COLLECTION_NAME: Optional[str] = "docsgpt"
@@ -156,6 +175,16 @@ class Settings(BaseSettings):
     COMPRESSION_PROMPT_VERSION: str = "v1.0" # Track prompt iterations
     COMPRESSION_MAX_HISTORY_POINTS: int = 3 # Keep only last N compression points to prevent DB bloat

+    @field_validator("POSTGRES_URI", mode="before")
+    @classmethod
+    def _normalize_postgres_uri_validator(cls, v):
+        return normalize_postgres_uri(v)
+
+    @field_validator("PGVECTOR_CONNECTION_STRING", mode="before")
+    @classmethod
+    def _normalize_pgvector_connection_string_validator(cls, v):
+        return normalize_pgvector_connection_string(v)
+
     @field_validator(
         "API_KEY",
         "OPENAI_API_KEY",
@@ -127,15 +127,33 @@ class GoogleLLM(BaseLLM):
                     ).uri,
                 )

-                from application.core.mongo_db import MongoDB
-
-                mongo = MongoDB.get_client()
-                db = mongo[settings.MONGO_DB_NAME]
-                attachments_collection = db["attachments"]
-                if "_id" in attachment:
-                    attachments_collection.update_one(
-                        {"_id": attachment["_id"]}, {"$set": {"google_file_uri": file_uri}}
-                    )
+                # Cache the Google file URI on the attachment row so we don't
+                # re-upload on the next LLM call. Accept either a PG UUID
+                # (``id``) or a legacy Mongo ObjectId (``_id``). Opened per
+                # write — this runs mid-LLM-call, so we don't wrap the
+                # surrounding generator in a long-lived session.
+                attachment_id = attachment.get("id") or attachment.get("_id")
+                if attachment_id:
+                    user_id = None
+                    decoded = getattr(self, "decoded_token", None)
+                    if isinstance(decoded, dict):
+                        user_id = decoded.get("sub")
+                    from application.storage.db.repositories.attachments import (
+                        AttachmentsRepository,
+                    )
+                    from application.storage.db.session import db_session
+
+                    try:
+                        with db_session() as conn:
+                            AttachmentsRepository(conn).update_any(
+                                str(attachment_id),
+                                user_id,
+                                {"google_file_uri": file_uri},
+                            )
+                    except Exception as cache_err:
+                        logging.warning(
+                            f"Failed to cache google_file_uri on attachment {attachment_id}: {cache_err}"
+                        )
                 return file_uri
             except Exception as e:
                 logging.error(f"Error uploading file to Google AI: {e}", exc_info=True)
@@ -527,15 +527,34 @@ class OpenAILLM(BaseLLM):
                     ).id,
                 )

-                from application.core.mongo_db import MongoDB
-
-                mongo = MongoDB.get_client()
-                db = mongo[settings.MONGO_DB_NAME]
-                attachments_collection = db["attachments"]
-                if "_id" in attachment:
-                    attachments_collection.update_one(
-                        {"_id": attachment["_id"]}, {"$set": {"openai_file_id": file_id}}
-                    )
+                # Cache the OpenAI file id on the attachment row so we don't
+                # re-upload the same blob on the next LLM call. Prefer the PG
+                # UUID (``id``) when present; fall back to the legacy Mongo
+                # ObjectId string (``_id``). Opened per-write — this runs
+                # inside the hot LLM path, so we don't want a long-lived
+                # session wrapping the generator.
+                attachment_id = attachment.get("id") or attachment.get("_id")
+                if attachment_id:
+                    user_id = None
+                    decoded = getattr(self, "decoded_token", None)
+                    if isinstance(decoded, dict):
+                        user_id = decoded.get("sub")
+                    from application.storage.db.repositories.attachments import (
+                        AttachmentsRepository,
+                    )
+                    from application.storage.db.session import db_session
+
+                    try:
+                        with db_session() as conn:
+                            AttachmentsRepository(conn).update_any(
+                                str(attachment_id),
+                                user_id,
+                                {"openai_file_id": file_id},
+                            )
+                    except Exception as cache_err:
+                        logging.warning(
+                            f"Failed to cache openai_file_id on attachment {attachment_id}: {cache_err}"
+                        )
                 return file_id
             except Exception as e:
                 logging.error(f"Error uploading file to OpenAI: {e}", exc_info=True)
@@ -6,8 +6,8 @@ import logging
 import uuid
 from typing import Any, Callable, Dict, Generator, List

-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
+from application.storage.db.repositories.stack_logs import StackLogsRepository
+from application.storage.db.session import db_session

 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -101,7 +101,7 @@ def _consume_and_log(generator: Generator, context: "LogContext"):
     except Exception as e:
         logging.exception(f"Error in {context.endpoint} - {context.activity_id}: {e}")
         context.stacks.append({"component": "error", "data": {"message": str(e)}})
-        _log_to_mongodb(
+        _log_activity_to_db(
            endpoint=context.endpoint,
            activity_id=context.activity_id,
            user=context.user,
@@ -112,7 +112,7 @@ def _consume_and_log(generator: Generator, context: "LogContext"):
         )
         raise
     finally:
-        _log_to_mongodb(
+        _log_activity_to_db(
            endpoint=context.endpoint,
            activity_id=context.activity_id,
            user=context.user,
@@ -123,7 +123,7 @@ def _consume_and_log(generator: Generator, context: "LogContext"):
         )


-def _log_to_mongodb(
+def _log_activity_to_db(
     endpoint: str,
     activity_id: str,
     user: str,
@@ -132,30 +132,26 @@
     stacks: List[Dict],
     level: str,
 ) -> None:
+    """Append a per-request activity log row to Postgres (``stack_logs``)."""
     try:
-        mongo = MongoDB.get_client()
-        db = mongo[settings.MONGO_DB_NAME]
-        user_logs_collection = db["stack_logs"]
-
-        log_entry = {
-            "endpoint": endpoint,
-            "id": activity_id,
-            "level": level,
-            "user": user,
-            "api_key": api_key,
-            "query": query,
-            "stacks": stacks,
-            "timestamp": datetime.datetime.now(datetime.timezone.utc),
-        }
-        # clean up text fields to be no longer than 10000 characters
-        for key, value in log_entry.items():
-            if isinstance(value, str) and len(value) > 10000:
-                log_entry[key] = value[:10000]
-
-        user_logs_collection.insert_one(log_entry)
-        logging.debug(f"Logged activity to MongoDB: {activity_id}")
+        # Clean up text fields to be no longer than 10000 characters so a
+        # runaway payload can't blow up the insert.
+        def _truncate(val):
+            if isinstance(val, str) and len(val) > 10000:
+                return val[:10000]
+            return val
+
+        with db_session() as conn:
+            StackLogsRepository(conn).insert(
+                activity_id=activity_id,
+                endpoint=_truncate(endpoint),
+                level=_truncate(level),
+                user_id=_truncate(user),
+                api_key=_truncate(api_key),
+                query=_truncate(query),
+                stacks=stacks,
+                timestamp=datetime.datetime.now(datetime.timezone.utc),
+            )
+        logging.debug(f"Logged activity to Postgres: {activity_id}")
     except Exception as e:
-        logging.error(f"Failed to log to MongoDB: {e}", exc_info=True)
+        logging.error(f"Failed to log activity to Postgres: {e}", exc_info=True)
application/parser/connectors/_auth_utils.py (new file, 37 lines)
@@ -0,0 +1,37 @@
"""Shared helpers for connector auth modules.

These helpers exist so that sensitive values (session tokens, bearer
credentials) never end up interpolated into exception messages or log
lines. Exception messages frequently flow into ``stack_logs`` (Postgres)
and Sentry via ``exc_info=True``, so the raw value must never be the
thing we format.
"""

from __future__ import annotations

import hashlib


def session_token_fingerprint(session_token: str) -> str:
    """Return a short, irreversible fingerprint for a session token.

    The returned string is safe to embed in exception messages and log
    lines: it is a prefix of a SHA-256 digest, clearly tagged so an
    operator reading the log knows it is a hash and not the token
    itself. It is stable for a given input, which lets operators
    correlate "which token failed" across log lines without exposing
    the credential.

    Args:
        session_token: The raw session token. Accepts ``None`` or the
            empty string for defensive callers; both yield a distinct
            sentinel rather than raising.

    Returns:
        A string of the form ``"sha256:<6 hex chars>"``, or
        ``"sha256:<empty>"`` when the input is falsy.
    """
    if not session_token:
        return "sha256:<empty>"
    digest = hashlib.sha256(session_token.encode("utf-8")).hexdigest()
    return f"sha256:{digest[:6]}"
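A small usage sketch, illustrative only (the token value below is made up):

# Illustrative only: log the fingerprint, never the raw token.
from application.parser.connectors._auth_utils import session_token_fingerprint

token = "made-up-session-token"
print(session_token_fingerprint(token))   # e.g. "sha256:ab12cd" (stable per token)
print(session_token_fingerprint(""))      # "sha256:<empty>"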
application/parser/connectors/confluence/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .auth import ConfluenceAuth
from .loader import ConfluenceLoader

__all__ = ["ConfluenceAuth", "ConfluenceLoader"]
application/parser/connectors/confluence/auth.py (new file, 221 lines)
@@ -0,0 +1,221 @@
|
import datetime
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from application.core.settings import settings
|
||||||
|
from application.parser.connectors._auth_utils import session_token_fingerprint
|
||||||
|
from application.parser.connectors.base import BaseConnectorAuth
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfluenceAuth(BaseConnectorAuth):
|
||||||
|
|
||||||
|
SCOPES = [
|
||||||
|
"read:page:confluence",
|
||||||
|
"read:space:confluence",
|
||||||
|
"read:attachment:confluence",
|
||||||
|
"read:me",
|
||||||
|
"offline_access",
|
||||||
|
]
|
||||||
|
|
||||||
|
AUTH_URL = "https://auth.atlassian.com/authorize"
|
||||||
|
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||||
|
RESOURCES_URL = "https://api.atlassian.com/oauth/token/accessible-resources"
|
||||||
|
ME_URL = "https://api.atlassian.com/me"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.client_id = settings.CONFLUENCE_CLIENT_ID
|
||||||
|
self.client_secret = settings.CONFLUENCE_CLIENT_SECRET
|
||||||
|
self.redirect_uri = settings.CONNECTOR_REDIRECT_BASE_URI
|
||||||
|
|
||||||
|
if not self.client_id or not self.client_secret:
|
||||||
|
raise ValueError(
|
||||||
|
"Confluence OAuth credentials not configured. "
|
||||||
|
"Please set CONFLUENCE_CLIENT_ID and CONFLUENCE_CLIENT_SECRET in settings."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_authorization_url(self, state: Optional[str] = None) -> str:
|
||||||
|
params = {
|
||||||
|
"audience": "api.atlassian.com",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"scope": " ".join(self.SCOPES),
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"state": state,
|
||||||
|
"response_type": "code",
|
||||||
|
"prompt": "consent",
|
||||||
|
}
|
||||||
|
return f"{self.AUTH_URL}?{urlencode(params)}"
|
||||||
|
|
||||||
|
def exchange_code_for_tokens(self, authorization_code: str) -> Dict[str, Any]:
|
||||||
|
if not authorization_code:
|
||||||
|
raise ValueError("Authorization code is required")
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.TOKEN_URL,
|
||||||
|
json={
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"code": authorization_code,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
token_data = response.json()
|
||||||
|
|
||||||
|
access_token = token_data.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise ValueError("OAuth flow did not return an access token")
|
||||||
|
|
||||||
|
refresh_token = token_data.get("refresh_token")
|
||||||
|
if not refresh_token:
|
||||||
|
raise ValueError("OAuth flow did not return a refresh token")
|
||||||
|
|
||||||
|
expires_in = token_data.get("expires_in", 3600)
|
||||||
|
expiry = (
|
||||||
|
datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
+ datetime.timedelta(seconds=expires_in)
|
||||||
|
).isoformat()
|
||||||
|
|
||||||
|
cloud_id = self._fetch_cloud_id(access_token)
|
||||||
|
user_info = self._fetch_user_info(access_token)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"token_uri": self.TOKEN_URL,
|
||||||
|
"scopes": self.SCOPES,
|
||||||
|
"expiry": expiry,
|
||||||
|
"cloud_id": cloud_id,
|
||||||
|
"user_info": {
|
||||||
|
"name": user_info.get("display_name", ""),
|
||||||
|
"email": user_info.get("email", ""),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def refresh_access_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||||
|
if not refresh_token:
|
||||||
|
raise ValueError("Refresh token is required")
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.TOKEN_URL,
|
||||||
|
json={
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
},
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
token_data = response.json()
|
||||||
|
|
||||||
|
access_token = token_data.get("access_token")
|
||||||
|
new_refresh_token = token_data.get("refresh_token", refresh_token)
|
||||||
|
|
||||||
|
expires_in = token_data.get("expires_in", 3600)
|
||||||
|
expiry = (
|
||||||
|
datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
+ datetime.timedelta(seconds=expires_in)
|
||||||
|
).isoformat()
|
||||||
|
|
||||||
|
cloud_id = self._fetch_cloud_id(access_token)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"refresh_token": new_refresh_token,
|
||||||
|
"token_uri": self.TOKEN_URL,
|
||||||
|
"scopes": self.SCOPES,
|
||||||
|
"expiry": expiry,
|
||||||
|
"cloud_id": cloud_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
def is_token_expired(self, token_info: Dict[str, Any]) -> bool:
|
||||||
|
if not token_info:
|
||||||
|
return True
|
||||||
|
|
||||||
|
expiry = token_info.get("expiry")
|
||||||
|
if not expiry:
|
||||||
|
return bool(token_info.get("access_token"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
expiry_dt = datetime.datetime.fromisoformat(expiry)
|
||||||
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
return now >= expiry_dt - datetime.timedelta(seconds=60)
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
|
||||||
|
from application.storage.db.repositories.connector_sessions import (
|
||||||
|
ConnectorSessionsRepository,
|
||||||
|
)
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
|
|
||||||
|
with db_readonly() as conn:
|
||||||
|
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||||
|
session_token
|
||||||
|
)
|
||||||
|
if not session:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid session token ({session_token_fingerprint(session_token)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
token_info = session.get("token_info")
|
||||||
|
if not token_info:
|
||||||
|
raise ValueError("Session missing token information")
|
||||||
|
|
||||||
|
required = ["access_token", "refresh_token", "cloud_id"]
|
||||||
|
missing = [f for f in required if not token_info.get(f)]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing required token fields: {missing}")
|
||||||
|
|
||||||
|
return token_info
|
||||||
|
|
||||||
|
def sanitize_token_info(
|
||||||
|
self, token_info: Dict[str, Any], **extra_fields
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return super().sanitize_token_info(
|
||||||
|
token_info,
|
||||||
|
cloud_id=token_info.get("cloud_id"),
|
||||||
|
**extra_fields,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fetch_cloud_id(self, access_token: str) -> str:
|
||||||
|
response = requests.get(
|
||||||
|
self.RESOURCES_URL,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
resources = response.json()
|
||||||
|
|
||||||
|
if not resources:
|
||||||
|
raise ValueError("No accessible Confluence sites found for this account")
|
||||||
|
|
||||||
|
return resources[0]["id"]
|
||||||
|
|
||||||
|
def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
self.ME_URL,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {access_token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not fetch user info: %s", e)
|
||||||
|
return {}
|
||||||
application/parser/connectors/confluence/loader.py (new file, 417 lines)
@@ -0,0 +1,417 @@
|
import functools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from application.parser.connectors.base import BaseConnectorLoader
|
||||||
|
from application.parser.connectors.confluence.auth import ConfluenceAuth
|
||||||
|
from application.parser.schema.base import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
API_V2 = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki/api/v2"
|
||||||
|
DOWNLOAD_BASE = "https://api.atlassian.com/ex/confluence/{cloud_id}/wiki"
|
||||||
|
|
||||||
|
SUPPORTED_ATTACHMENT_TYPES = {
|
||||||
|
"application/pdf": ".pdf",
|
||||||
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||||
|
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||||
|
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||||
|
"application/msword": ".doc",
|
||||||
|
"application/vnd.ms-powerpoint": ".ppt",
|
||||||
|
"application/vnd.ms-excel": ".xls",
|
||||||
|
"text/plain": ".txt",
|
||||||
|
"text/csv": ".csv",
|
||||||
|
"text/html": ".html",
|
||||||
|
"text/markdown": ".md",
|
||||||
|
"application/json": ".json",
|
||||||
|
"application/epub+zip": ".epub",
|
||||||
|
"image/jpeg": ".jpg",
|
||||||
|
"image/png": ".png",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_on_auth_failure(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
if e.response is not None and e.response.status_code in (401, 403):
|
||||||
|
logger.info(
|
||||||
|
"Auth failure in %s, refreshing token and retrying", func.__name__
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
new_token_info = self.auth.refresh_access_token(self.refresh_token)
|
||||||
|
self.access_token = new_token_info["access_token"]
|
||||||
|
self.refresh_token = new_token_info.get(
|
||||||
|
"refresh_token", self.refresh_token
|
||||||
|
)
|
||||||
|
self._persist_refreshed_tokens(new_token_info)
|
||||||
|
except Exception as refresh_err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Authentication failed and could not be refreshed: {refresh_err}"
|
||||||
|
) from e
|
||||||
|
return func(self, *args, **kwargs)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class ConfluenceLoader(BaseConnectorLoader):
|
||||||
|
|
||||||
|
def __init__(self, session_token: str):
|
||||||
|
self.auth = ConfluenceAuth()
|
||||||
|
self.session_token = session_token
|
||||||
|
|
||||||
|
token_info = self.auth.get_token_info_from_session(session_token)
|
||||||
|
self.access_token = token_info["access_token"]
|
||||||
|
self.refresh_token = token_info["refresh_token"]
|
||||||
|
self.cloud_id = token_info["cloud_id"]
|
||||||
|
|
||||||
|
self.base_url = API_V2.format(cloud_id=self.cloud_id)
|
||||||
|
self.download_base = DOWNLOAD_BASE.format(cloud_id=self.cloud_id)
|
||||||
|
self.next_page_token = None
|
||||||
|
|
||||||
|
def _headers(self) -> Dict[str, str]:
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {self.access_token}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _persist_refreshed_tokens(self, token_info: Dict[str, Any]) -> None:
|
||||||
|
try:
|
||||||
|
from application.storage.db.repositories.connector_sessions import (
|
||||||
|
ConnectorSessionsRepository,
|
||||||
|
)
|
||||||
|
from application.storage.db.session import db_session
|
||||||
|
|
||||||
|
sanitized = self.auth.sanitize_token_info(token_info)
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = ConnectorSessionsRepository(conn)
|
||||||
|
session = repo.get_by_session_token(self.session_token)
|
||||||
|
if session:
|
||||||
|
repo.update(str(session["id"]), {"token_info": sanitized})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to persist refreshed tokens: %s", e)
|
||||||
|
|
||||||
|
@_retry_on_auth_failure
|
||||||
|
def load_data(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
|
folder_id = inputs.get("folder_id")
|
||||||
|
file_ids = inputs.get("file_ids", [])
|
||||||
|
limit = inputs.get("limit", 100)
|
||||||
|
list_only = inputs.get("list_only", False)
|
||||||
|
page_token = inputs.get("page_token")
|
||||||
|
search_query = inputs.get("search_query")
|
||||||
|
self.next_page_token = None
|
||||||
|
|
||||||
|
if file_ids:
|
||||||
|
return self._load_pages_by_ids(file_ids, list_only, search_query)
|
||||||
|
|
||||||
|
if folder_id:
|
||||||
|
return self._list_pages_in_space(
|
||||||
|
folder_id, limit, list_only, page_token, search_query
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._list_spaces(limit, page_token, search_query)
|
||||||
|
|
||||||
|
@_retry_on_auth_failure
|
||||||
|
def download_to_directory(self, local_dir: str, source_config: dict = None) -> dict:
|
||||||
|
config = source_config or getattr(self, "config", {})
|
||||||
|
file_ids = config.get("file_ids", [])
|
||||||
|
folder_ids = config.get("folder_ids", [])
|
||||||
|
files_downloaded = 0
|
||||||
|
|
||||||
|
os.makedirs(local_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if isinstance(file_ids, str):
|
||||||
|
file_ids = [file_ids]
|
||||||
|
if isinstance(folder_ids, str):
|
||||||
|
folder_ids = [folder_ids]
|
||||||
|
|
||||||
|
for page_id in file_ids:
|
||||||
|
if self._download_page(page_id, local_dir):
|
||||||
|
files_downloaded += 1
|
||||||
|
files_downloaded += self._download_page_attachments(page_id, local_dir)
|
||||||
|
|
||||||
|
for space_id in folder_ids:
|
||||||
|
files_downloaded += self._download_space(space_id, local_dir)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"files_downloaded": files_downloaded,
|
||||||
|
"directory_path": local_dir,
|
||||||
|
"empty_result": files_downloaded == 0,
|
||||||
|
"source_type": "confluence",
|
||||||
|
"config_used": config,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _list_spaces(
|
||||||
|
self, limit: int, cursor: Optional[str], search_query: Optional[str]
|
||||||
|
) -> List[Document]:
|
||||||
|
documents: List[Document] = []
|
||||||
|
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/spaces",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for space in data.get("results", []):
|
||||||
|
name = space.get("name", "")
|
||||||
|
if search_query and search_query.lower() not in name.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
documents.append(
|
||||||
|
Document(
|
||||||
|
text="",
|
||||||
|
doc_id=space["id"],
|
||||||
|
extra_info={
|
||||||
|
"file_name": name,
|
||||||
|
"mime_type": "folder",
|
||||||
|
"size": None,
|
||||||
|
"created_time": space.get("createdAt"),
|
||||||
|
"modified_time": None,
|
||||||
|
"source": "confluence",
|
||||||
|
"is_folder": True,
|
||||||
|
"space_key": space.get("key"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
self.next_page_token = self._extract_cursor(next_link)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def _list_pages_in_space(
|
||||||
|
self,
|
||||||
|
space_id: str,
|
||||||
|
limit: int,
|
||||||
|
list_only: bool,
|
||||||
|
cursor: Optional[str],
|
||||||
|
search_query: Optional[str],
|
||||||
|
) -> List[Document]:
|
||||||
|
documents: List[Document] = []
|
||||||
|
params: Dict[str, Any] = {"limit": min(limit, 250)}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/spaces/{space_id}/pages",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for page in data.get("results", []):
|
||||||
|
title = page.get("title", "")
|
||||||
|
if search_query and search_query.lower() not in title.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc = self._page_to_document(
|
||||||
|
page, load_content=not list_only, space_id=space_id
|
||||||
|
)
|
||||||
|
if doc:
|
||||||
|
documents.append(doc)
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
self.next_page_token = self._extract_cursor(next_link)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def _load_pages_by_ids(
|
||||||
|
self, page_ids: List[str], list_only: bool, search_query: Optional[str]
|
||||||
|
) -> List[Document]:
|
||||||
|
documents: List[Document] = []
|
||||||
|
for page_id in page_ids:
|
||||||
|
try:
|
||||||
|
params: Dict[str, str] = {}
|
||||||
|
if not list_only:
|
||||||
|
params["body-format"] = "storage"
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/pages/{page_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
page = response.json()
|
||||||
|
|
||||||
|
title = page.get("title", "")
|
||||||
|
if search_query and search_query.lower() not in title.lower():
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc = self._page_to_document(page, load_content=not list_only)
|
||||||
|
if doc:
|
||||||
|
documents.append(doc)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error loading page %s: %s", page_id, e)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
def _page_to_document(
|
||||||
|
self,
|
||||||
|
page: Dict[str, Any],
|
||||||
|
load_content: bool = False,
|
||||||
|
space_id: Optional[str] = None,
|
||||||
|
) -> Optional[Document]:
|
||||||
|
page_id = page.get("id")
|
||||||
|
title = page.get("title", "Unknown")
|
||||||
|
version = page.get("version", {})
|
||||||
|
modified_time = version.get("createdAt") if isinstance(version, dict) else None
|
||||||
|
created_time = page.get("createdAt")
|
||||||
|
resolved_space_id = space_id or page.get("spaceId")
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
if load_content:
|
||||||
|
body = page.get("body", {})
|
||||||
|
storage = body.get("storage", {}) if isinstance(body, dict) else {}
|
||||||
|
text = storage.get("value", "") if isinstance(storage, dict) else ""
|
||||||
|
|
||||||
|
return Document(
|
||||||
|
text=text,
|
||||||
|
doc_id=str(page_id),
|
||||||
|
extra_info={
|
||||||
|
"file_name": title,
|
||||||
|
"mime_type": "text/html",
|
||||||
|
"size": len(text) if text else None,
|
||||||
|
"created_time": created_time,
|
||||||
|
"modified_time": modified_time,
|
||||||
|
"source": "confluence",
|
||||||
|
"is_folder": False,
|
||||||
|
"page_id": str(page_id),
|
||||||
|
"space_id": resolved_space_id,
|
||||||
|
"cloud_id": self.cloud_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _download_page(self, page_id: str, local_dir: str) -> bool:
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/pages/{page_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
params={"body-format": "storage"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
page = response.json()
|
||||||
|
|
||||||
|
title = page.get("title", page_id)
|
||||||
|
safe_name = "".join(c if c.isalnum() or c in " -_" else "_" for c in title)
|
||||||
|
body = page.get("body", {}).get("storage", {}).get("value", "")
|
||||||
|
|
||||||
|
file_path = os.path.join(local_dir, f"{safe_name}.html")
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(body)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error downloading page %s: %s", page_id, e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _download_page_attachments(self, page_id: str, local_dir: str) -> int:
|
||||||
|
downloaded = 0
|
||||||
|
try:
|
||||||
|
cursor = None
|
||||||
|
while True:
|
||||||
|
params: Dict[str, Any] = {"limit": 100}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/pages/{page_id}/attachments",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
for att in data.get("results", []):
|
||||||
|
media_type = att.get("mediaType", "")
|
||||||
|
if media_type not in SUPPORTED_ATTACHMENT_TYPES:
|
||||||
|
continue
|
||||||
|
|
||||||
|
download_link = att.get("_links", {}).get("download")
|
||||||
|
if not download_link:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_name = att.get("title", att.get("id", "attachment"))
|
||||||
|
file_name = "".join(
|
||||||
|
c if c.isalnum() or c in " -_." else "_"
|
||||||
|
for c in os.path.basename(raw_name)
|
||||||
|
) or "attachment"
|
||||||
|
file_path = os.path.join(local_dir, file_name)
|
||||||
|
|
||||||
|
url = f"{self.download_base}{download_link}"
|
||||||
|
file_resp = requests.get(
|
||||||
|
url, headers=self._headers(), timeout=60, stream=True
|
||||||
|
)
|
||||||
|
file_resp.raise_for_status()
|
||||||
|
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
for chunk in file_resp.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
downloaded += 1
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
cursor = self._extract_cursor(next_link)
|
||||||
|
if not cursor:
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error downloading attachments for page %s: %s", page_id, e)
|
||||||
|
return downloaded
|
||||||
|
|
||||||
|
def _download_space(self, space_id: str, local_dir: str) -> int:
|
||||||
|
downloaded = 0
|
||||||
|
cursor = None
|
||||||
|
while True:
|
||||||
|
params: Dict[str, Any] = {"limit": 250}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = cursor
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
f"{self.base_url}/spaces/{space_id}/pages",
|
||||||
|
headers=self._headers(),
|
||||||
|
params=params,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error listing pages in space %s: %s", space_id, e)
|
||||||
|
break
|
||||||
|
|
||||||
|
for page in data.get("results", []):
|
||||||
|
page_id = page.get("id")
|
||||||
|
if self._download_page(str(page_id), local_dir):
|
||||||
|
downloaded += 1
|
||||||
|
downloaded += self._download_page_attachments(str(page_id), local_dir)
|
||||||
|
|
||||||
|
next_link = data.get("_links", {}).get("next")
|
||||||
|
cursor = self._extract_cursor(next_link)
|
||||||
|
if not cursor:
|
||||||
|
break
|
||||||
|
|
||||||
|
return downloaded
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_cursor(next_link: Optional[str]) -> Optional[str]:
|
||||||
|
if not next_link:
|
||||||
|
return None
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
parsed = urlparse(next_link)
|
||||||
|
cursors = parse_qs(parsed.query).get("cursor")
|
||||||
|
return cursors[0] if cursors else None
|
||||||
@@ -1,5 +1,7 @@
-from application.parser.connectors.google_drive.loader import GoogleDriveLoader
+from application.parser.connectors.confluence.auth import ConfluenceAuth
+from application.parser.connectors.confluence.loader import ConfluenceLoader
 from application.parser.connectors.google_drive.auth import GoogleDriveAuth
+from application.parser.connectors.google_drive.loader import GoogleDriveLoader
 from application.parser.connectors.share_point.auth import SharePointAuth
 from application.parser.connectors.share_point.loader import SharePointLoader

@@ -13,11 +15,13 @@ class ConnectorCreator:
     """

     connectors = {
+        "confluence": ConfluenceLoader,
         "google_drive": GoogleDriveLoader,
         "share_point": SharePointLoader,
     }

     auth_providers = {
+        "confluence": ConfluenceAuth,
         "google_drive": GoogleDriveAuth,
         "share_point": SharePointAuth,
     }
@@ -8,6 +8,7 @@ from googleapiclient.discovery import build
 from googleapiclient.errors import HttpError

 from application.core.settings import settings
+from application.parser.connectors._auth_utils import session_token_fingerprint
 from application.parser.connectors.base import BaseConnectorAuth


@@ -209,23 +210,23 @@ class GoogleDriveAuth(BaseConnectorAuth):

     def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
         try:
-            from application.core.mongo_db import MongoDB
-            from application.core.settings import settings
-
-            mongo = MongoDB.get_client()
-            db = mongo[settings.MONGO_DB_NAME]
-            sessions_collection = db["connector_sessions"]
-            session = sessions_collection.find_one({"session_token": session_token})
+            from application.storage.db.repositories.connector_sessions import (
+                ConnectorSessionsRepository,
+            )
+            from application.storage.db.session import db_readonly
+
+            with db_readonly() as conn:
+                session = ConnectorSessionsRepository(conn).get_by_session_token(
+                    session_token
+                )
+
             if not session:
-                raise ValueError(f"Invalid session token: {session_token}")
+                raise ValueError(
+                    f"Invalid session token ({session_token_fingerprint(session_token)})"
+                )

-            if "token_info" not in session:
-                raise ValueError("Session missing token information")
-
-            token_info = session["token_info"]
+            token_info = session.get("token_info")
             if not token_info:
-                raise ValueError("Invalid token information")
+                raise ValueError("Session missing token information")

             required_fields = ["access_token", "refresh_token"]
             missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
@@ -5,6 +5,7 @@ from typing import Optional, Dict, Any
 from msal import ConfidentialClientApplication
 
 from application.core.settings import settings
+from application.parser.connectors._auth_utils import session_token_fingerprint
 from application.parser.connectors.base import BaseConnectorAuth
 
 logger = logging.getLogger(__name__)
@@ -77,24 +78,24 @@ class SharePointAuth(BaseConnectorAuth):
 
     def get_token_info_from_session(self, session_token: str) -> Dict[str, Any]:
         try:
-            from application.core.mongo_db import MongoDB
-            from application.core.settings import settings
+            from application.storage.db.repositories.connector_sessions import (
+                ConnectorSessionsRepository,
+            )
+            from application.storage.db.session import db_readonly
 
-            mongo = MongoDB.get_client()
-            db = mongo[settings.MONGO_DB_NAME]
-            sessions_collection = db["connector_sessions"]
-            session = sessions_collection.find_one({"session_token": session_token})
+            with db_readonly() as conn:
+                session = ConnectorSessionsRepository(conn).get_by_session_token(
+                    session_token
+                )
 
             if not session:
-                raise ValueError(f"Invalid session token: {session_token}")
+                raise ValueError(
+                    f"Invalid session token ({session_token_fingerprint(session_token)})"
+                )
 
-            if "token_info" not in session:
-                raise ValueError("Session missing token information")
-
-            token_info = session["token_info"]
+            token_info = session.get("token_info")
+
             if not token_info:
-                raise ValueError("Invalid token information")
+                raise ValueError("Session missing token information")
 
             required_fields = ["access_token", "refresh_token"]
             missing_fields = [field for field in required_fields if field not in token_info or not token_info.get(field)]
@@ -205,7 +205,7 @@ class SharePointLoader(BaseConnectorLoader):
         try:
             url = self._get_item_url(file_id)
             params = {'$select': 'id,name,file,createdDateTime,lastModifiedDateTime,size'}
-            response = requests.get(url, headers=self._get_headers(), params=params)
+            response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
             response.raise_for_status()
 
             file_metadata = response.json()
@@ -236,9 +236,9 @@ class SharePointLoader(BaseConnectorLoader):
                 search_url = f"{self.GRAPH_API_BASE}/drives/{drive_id}/root/search(q='{encoded_query}')"
             else:
                 search_url = f"{self.GRAPH_API_BASE}/me/drive/search(q='{encoded_query}')"
-            response = requests.get(search_url, headers=self._get_headers(), params=params)
+            response = requests.get(search_url, headers=self._get_headers(), params=params, timeout=100)
         else:
-            response = requests.get(url, headers=self._get_headers(), params=params)
+            response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
 
         response.raise_for_status()
 
@@ -307,7 +307,8 @@ class SharePointLoader(BaseConnectorLoader):
         response = requests.get(
             f"{self.GRAPH_API_BASE}/me/drive",
             headers=self._get_headers(),
-            params={'$select': 'webUrl'}
+            params={'$select': 'webUrl'},
+            timeout=100,
         )
         response.raise_for_status()
         return response.json().get('webUrl')
@@ -352,7 +353,7 @@ class SharePointLoader(BaseConnectorLoader):
 
         headers = self._get_headers()
         headers["Content-Type"] = "application/json"
-        response = requests.post(url, headers=headers, json=body)
+        response = requests.post(url, headers=headers, json=body, timeout=100)
         response.raise_for_status()
         results = response.json()
 
@@ -472,7 +473,7 @@ class SharePointLoader(BaseConnectorLoader):
 
         try:
             url = f"{self._get_item_url(file_id)}/content"
-            response = requests.get(url, headers=self._get_headers())
+            response = requests.get(url, headers=self._get_headers(), timeout=100)
             response.raise_for_status()
 
             try:
@@ -491,7 +492,7 @@ class SharePointLoader(BaseConnectorLoader):
         try:
             url = self._get_item_url(file_id)
             params = {'$select': 'id,name,file'}
-            response = requests.get(url, headers=self._get_headers(), params=params)
+            response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
             response.raise_for_status()
 
             metadata = response.json()
@@ -507,7 +508,7 @@ class SharePointLoader(BaseConnectorLoader):
             full_path = os.path.join(local_dir, file_name)
 
             download_url = f"{self._get_item_url(file_id)}/content"
-            download_response = requests.get(download_url, headers=self._get_headers())
+            download_response = requests.get(download_url, headers=self._get_headers(), timeout=100)
             download_response.raise_for_status()
 
             with open(full_path, 'wb') as f:
@@ -527,7 +528,7 @@ class SharePointLoader(BaseConnectorLoader):
         params = {'$top': 1000}
 
         while url:
-            response = requests.get(url, headers=self._get_headers(), params=params)
+            response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
             response.raise_for_status()
 
             results = response.json()
@@ -609,7 +610,7 @@ class SharePointLoader(BaseConnectorLoader):
         try:
             url = self._get_item_url(folder_id)
             params = {'$select': 'id,name'}
-            response = requests.get(url, headers=self._get_headers(), params=params)
+            response = requests.get(url, headers=self._get_headers(), params=params, timeout=100)
             response.raise_for_status()
 
             folder_metadata = response.json()
@@ -24,7 +24,7 @@ class PDFParser(BaseParser):
             # alternatively you can use local vision capable LLM
             with open(file, "rb") as file_loaded:
                 files = {'file': file_loaded}
-                response = requests.post(doc2md_service, files=files)
+                response = requests.post(doc2md_service, files=files, timeout=100)
                 data = response.json()["markdown"]
             return data
 
@@ -19,25 +19,10 @@ class EpubParser(BaseParser):
     def parse_file(self, file: Path, errors: str = "ignore") -> str:
         """Parse file."""
         try:
-            import ebooklib
-            from ebooklib import epub
+            from fast_ebook import epub
         except ImportError:
-            raise ValueError("`EbookLib` is required to read Epub files.")
-        try:
-            import html2text
-        except ImportError:
-            raise ValueError("`html2text` is required to parse Epub files.")
+            raise ValueError("`fast-ebook` is required to read Epub files.")
 
-        text_list = []
-        book = epub.read_epub(file, options={"ignore_ncx": True})
-
-        # Iterate through all chapters.
-        for item in book.get_items():
-            # Chapters are typically located in epub documents items.
-            if item.get_type() == ebooklib.ITEM_DOCUMENT:
-                text_list.append(
-                    html2text.html2text(item.get_content().decode("utf-8"))
-                )
-
-        text = "\n".join(text_list)
+        book = epub.read_epub(file)
+        text = book.to_markdown()
         return text
@@ -24,7 +24,7 @@ class ImageParser(BaseParser):
             # alternatively you can use local vision capable LLM
             with open(file, "rb") as file_loaded:
                 files = {'file': file_loaded}
-                response = requests.post(doc2md_service, files=files)
+                response = requests.post(doc2md_service, files=files, timeout=100)
                 data = response.json()["markdown"]
         else:
             data = ""
@@ -77,7 +77,7 @@ class GitHubLoader(BaseRemote):
    def _make_request(self, url: str, max_retries: int = 3) -> requests.Response:
        """Make a request with retry logic for rate limiting"""
        for attempt in range(max_retries):
-            response = requests.get(url, headers=self.headers)
+            response = requests.get(url, headers=self.headers, timeout=100)
 
            if response.status_code == 200:
                return response
@@ -4,6 +4,7 @@ import os
 import tempfile
 import mimetypes
 from typing import List, Optional
+from application.core.url_validation import SSRFError, validate_url
 from application.parser.remote.base import BaseRemote
 from application.parser.schema.base import Document
 
@@ -108,6 +109,11 @@ class S3Loader(BaseRemote):
             logger.info(f"Normalized endpoint URL: {normalized_endpoint}")
             logger.info(f"Bucket name: '{corrected_bucket}'")
 
+            try:
+                normalized_endpoint = validate_url(normalized_endpoint)
+            except SSRFError as e:
+                raise ValueError(f"Invalid S3 endpoint_url: {e}") from e
+
             client_kwargs["endpoint_url"] = normalized_endpoint
             # Use path-style addressing for S3-compatible services
             # (DigitalOcean Spaces, MinIO, etc.)
@@ -36,6 +36,11 @@ class SitemapLoader(BaseRemote):
            if self.limit is not None and processed_urls >= self.limit:
                break  # Stop processing if the limit is reached
 
+            try:
+                url = validate_url(url)
+            except SSRFError as e:
+                logging.error(f"URL validation failed for sitemap entry {url}: {e}")
+                continue
            try:
                loader = self.loader([url])
                documents.extend(loader.load())
@@ -90,6 +95,15 @@
        # Check for nested sitemaps
        for sitemap in root.findall('.//sitemap/loc'):
            nested_sitemap_url = sitemap.text
+            if not nested_sitemap_url:
+                continue
+            try:
+                nested_sitemap_url = validate_url(nested_sitemap_url)
+            except SSRFError as e:
+                logging.error(
+                    f"URL validation failed for nested sitemap {nested_sitemap_url}: {e}"
+                )
+                continue
            urls.extend(self._extract_urls(nested_sitemap_url))
 
        return urls
@@ -1,8 +1,8 @@
 import logging
+from application.core.url_validation import SSRFError, validate_url
 from application.parser.remote.base import BaseRemote
 from application.parser.schema.base import Document
 from langchain_community.document_loaders import WebBaseLoader
-from urllib.parse import urlparse
 
 headers = {
     "User-Agent": "Mozilla/5.0",
@@ -26,9 +26,13 @@ class WebLoader(BaseRemote):
            urls = [urls]
        documents = []
        for url in urls:
-            # Check if the URL scheme is provided, if not, assume http
-            if not urlparse(url).scheme:
-                url = "http://" + url
+            try:
+                url = validate_url(url)
+            except SSRFError as e:
+                logging.warning(
+                    f"Skipping URL due to SSRF validation failure: {url} - {e}"
+                )
+                continue
            try:
                loader = self.loader([url], header_template=headers)
                loaded_docs = loader.load()
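The S3, sitemap, and web loaders above now all route candidate URLs through the same guard before fetching. A hedged sketch of how that guard is exercised (what exactly validate_url rejects is defined in application.core.url_validation, which is not part of this diff; the metadata-service address below is just the classic SSRF example):

from application.core.url_validation import SSRFError, validate_url

for candidate in ["https://example.com/sitemap.xml", "http://169.254.169.254/latest/meta-data"]:
    try:
        safe = validate_url(candidate)
    except SSRFError as exc:
        print(f"rejected {candidate}: {exc}")   # internal/link-local targets are refused
    else:
        print(f"accepted {safe}")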
@@ -1,9 +1,10 @@
+alembic>=1.13,<2
 anthropic==0.88.0
-boto3==1.42.24
+boto3==1.42.83
 beautifulsoup4==4.14.3
 cel-python==0.5.0
 celery==5.6.3
-cryptography==46.0.6
+cryptography==46.0.7
 dataclasses-json==0.6.7
 defusedxml==0.7.1
 docling>=2.16.0
@@ -11,32 +12,31 @@ rapidocr>=1.4.0
 onnxruntime>=1.19.0
 docx2txt==0.9
 ddgs>=8.0.0
-ebooklib==0.20
+fast-ebook
-elevenlabs==2.40.0
+elevenlabs==2.43.0
 Flask==3.1.3
 faiss-cpu==1.13.2
-fastmcp==3.2.0
+fastmcp==3.2.4
 flask-restx==1.3.2
-google-genai==1.69.0
+google-genai==1.73.1
 google-api-python-client==2.193.0
 google-auth-httplib2==0.3.1
 google-auth-oauthlib==1.3.1
 gTTS==2.5.4
 gunicorn==25.3.0
-html2text==2025.4.15
 jinja2==3.1.6
 jiter==0.13.0
-jmespath==1.0.1
+jmespath==1.1.0
 joblib==1.5.3
 jsonpatch==1.33
-jsonpointer==3.0.0
+jsonpointer==3.1.1
 kombu==5.6.2
 langchain==1.2.3
 langchain-community==0.4.1
-langchain-core==1.2.23
+langchain-core==1.2.29
 langchain-openai==1.1.12
 langchain-text-splitters==1.1.1
-langsmith==0.7.23
+langsmith==0.7.31
 lazy-object-proxy==1.12.0
 lxml==6.0.2
 markupsafe==3.0.3
@@ -47,7 +47,7 @@ msal==1.35.1
 mypy-extensions==1.1.0
 networkx==3.6.1
 numpy==2.4.4
-openai==2.30.0
+openai==2.32.0
 openapi3-parser==1.1.22
 orjson==3.11.7
 packaging==26.0
@@ -59,12 +59,11 @@ pillow
 portalocker>=2.7.0,<4.0.0
 prompt-toolkit==3.0.52
 protobuf==7.34.1
-psycopg2-binary==2.9.11
+psycopg[binary,pool]>=3.1,<4
 py==1.11.0
 pydantic
 pydantic-core
 pydantic-settings
-pymongo==4.16.0
 pypdf==6.9.2
 python-dateutil==2.9.0.post0
 python-dotenv
@@ -72,10 +71,11 @@ python-jose==3.5.0
 python-pptx==1.0.2
 redis==7.4.0
 referencing>=0.28.0,<0.38.0
-regex==2026.3.32
+regex==2026.4.4
 requests==2.33.1
 retry==0.9.2
 sentence-transformers==5.3.0
+sqlalchemy>=2.0,<3
 tiktoken==0.12.0
 tokenizers==0.22.2
 torch==2.11.0
@@ -83,7 +83,7 @@ tqdm==4.67.3
 transformers==5.4.0
 typing-extensions==4.15.0
 typing-inspect==0.9.0
-tzdata==2025.3
+tzdata==2026.1
 urllib3==2.6.3
 vine==5.1.0
 wcwidth==0.6.0
@@ -1,7 +1,5 @@
 import click
 
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
 from application.seed.seeder import DatabaseSeeder
 
 
@@ -15,10 +13,7 @@ def seed():
 @click.option("--force", is_flag=True, help="Force reseeding even if data exists")
 def init(force):
     """Initialize database with seed data"""
-    mongo = MongoDB.get_client()
-    db = mongo[settings.MONGO_DB_NAME]
-
-    seeder = DatabaseSeeder(db)
+    seeder = DatabaseSeeder()
     seeder.seed_initial_data(force=force)
 
 
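The CLI entry point above no longer wires up a Mongo client; per the seeder module docstring later in this diff it is driven by `python -m application.seed.commands init`. A minimal programmatic sketch of the same call, using only names shown in this diff:

from application.seed.seeder import DatabaseSeeder

seeder = DatabaseSeeder()             # legacy positional db argument is accepted but ignored
seeder.seed_initial_data(force=True)  # force=True mirrors the --force CLI flag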
@@ -1,35 +1,56 @@
+"""Database seeder — Postgres-native.
+
+Post-Part-2 cutover: writes template prompts/tools/agents/sources directly
+into Postgres via the repository layer. No MongoDB dependencies.
+
+The seeder is invoked by the ``python -m application.seed.commands init``
+CLI (not at Flask app startup). All template rows are owned by the
+sentinel user id ``__system__`` — kept in sync with the migration
+backfill/cleanup-trigger sentinel so template ownership is predictable.
+"""
+
 import logging
 import os
-from datetime import datetime, timezone
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional
 
 import yaml
-from bson import ObjectId
-from bson.dbref import DBRef
 
 from dotenv import load_dotenv
-from pymongo import MongoClient
 
 from application.agents.tools.tool_manager import ToolManager
 from application.api.user.tasks import ingest_remote
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.prompts import PromptsRepository
+from application.storage.db.repositories.sources import SourcesRepository
+from application.storage.db.repositories.user_tools import UserToolsRepository
+from application.storage.db.session import db_readonly, db_session
 
 load_dotenv()
 tool_config = {}
 tool_manager = ToolManager(config=tool_config)
 
 
+# Sentinel user id for template rows (agents/prompts/sources/tools).
+# Kept in sync with the Postgres backfill / cleanup-trigger sentinel so
+# template ownership is predictable across the cutover.
+SYSTEM_USER_ID = "__system__"
+
+
 class DatabaseSeeder:
-    def __init__(self, db):
-        self.db = db
-        self.tools_collection = self.db["user_tools"]
-        self.sources_collection = self.db["sources"]
-        self.agents_collection = self.db["agents"]
-        self.prompts_collection = self.db["prompts"]
-        self.system_user_id = "system"
+    """Postgres-backed seeder.
+
+    The constructor accepts an optional positional argument for back
+    compatibility with legacy callers that used to pass a Mongo ``db``
+    handle. The value is ignored — all persistence goes through the
+    Postgres repositories.
+    """
+
+    def __init__(self, db=None):
+        self._legacy_db = db  # unused; retained for call-site compatibility
+        self.system_user_id = SYSTEM_USER_ID
         self.logger = logging.getLogger(__name__)
 
     def seed_initial_data(self, config_path: str = None, force=False):
-        """Main entry point for seeding all initial data"""
+        """Main entry point for seeding all initial data."""
         if not force and self._is_already_seeded():
             self.logger.info("Database already seeded. Use force=True to reseed.")
             return
@@ -46,20 +67,18 @@ class DatabaseSeeder:
             raise
 
     def _seed_from_config(self, config: Dict):
-        """Seed all data from configuration"""
-        self.logger.info("🌱 Starting seeding...")
+        """Seed all data from configuration."""
+        self.logger.info("Starting seeding...")
 
         if not config.get("agents"):
             self.logger.warning("No agents found in config")
             return
-        used_tool_ids = set()
 
         for agent_config in config["agents"]:
            try:
                self.logger.info(f"Processing agent: {agent_config['name']}")

                # 1. Handle Source

                source_result = self._handle_source(agent_config)
                if source_result is False:
                    self.logger.error(
|
||||||
@@ -67,64 +86,100 @@ class DatabaseSeeder:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
source_id = source_result
|
source_id = source_result
|
||||||
# 2. Handle Tools
|
|
||||||
|
|
||||||
|
# 2. Handle Tools
|
||||||
tool_ids = self._handle_tools(agent_config)
|
tool_ids = self._handle_tools(agent_config)
|
||||||
if len(tool_ids) == 0:
|
if len(tool_ids) == 0:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"No valid tools for agent {agent_config['name']}"
|
f"No valid tools for agent {agent_config['name']}"
|
||||||
)
|
)
|
||||||
used_tool_ids.update(tool_ids)
|
|
||||||
|
|
||||||
# 3. Handle Prompt
|
# 3. Handle Prompt
|
||||||
|
|
||||||
prompt_id = self._handle_prompt(agent_config)
|
prompt_id = self._handle_prompt(agent_config)
|
||||||
|
|
||||||
# 4. Create Agent
|
# 4. Create or update Agent
|
||||||
|
self._upsert_agent(agent_config, source_id, tool_ids, prompt_id)
|
||||||
|
|
||||||
agent_data = {
|
|
||||||
"user": self.system_user_id,
|
|
||||||
"name": agent_config["name"],
|
|
||||||
"description": agent_config["description"],
|
|
||||||
"image": agent_config.get("image", ""),
|
|
||||||
"source": (
|
|
||||||
DBRef("sources", ObjectId(source_id)) if source_id else ""
|
|
||||||
),
|
|
||||||
"tools": [str(tid) for tid in tool_ids],
|
|
||||||
"agent_type": agent_config["agent_type"],
|
|
||||||
"prompt_id": prompt_id or agent_config.get("prompt_id", "default"),
|
|
||||||
"chunks": agent_config.get("chunks", "0"),
|
|
||||||
"retriever": agent_config.get("retriever", ""),
|
|
||||||
"status": "template",
|
|
||||||
"createdAt": datetime.now(timezone.utc),
|
|
||||||
"updatedAt": datetime.now(timezone.utc),
|
|
||||||
}
|
|
||||||
|
|
||||||
existing = self.agents_collection.find_one(
|
|
||||||
{"user": self.system_user_id, "name": agent_config["name"]}
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
self.logger.info(f"Updating existing agent: {agent_config['name']}")
|
|
||||||
self.agents_collection.update_one(
|
|
||||||
{"_id": existing["_id"]}, {"$set": agent_data}
|
|
||||||
)
|
|
||||||
agent_id = existing["_id"]
|
|
||||||
else:
|
|
||||||
self.logger.info(f"Creating new agent: {agent_config['name']}")
|
|
||||||
result = self.agents_collection.insert_one(agent_data)
|
|
||||||
agent_id = result.inserted_id
|
|
||||||
self.logger.info(
|
|
||||||
f"Successfully processed agent: {agent_config['name']} (ID: {agent_id})"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(
|
self.logger.error(
|
||||||
f"Error processing agent {agent_config['name']}: {str(e)}"
|
f"Error processing agent {agent_config['name']}: {str(e)}"
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
self.logger.info("✅ Database seeding completed")
|
self.logger.info("Database seeding completed")
|
||||||
|
|
||||||
def _handle_source(self, agent_config: Dict) -> Union[ObjectId, None, bool]:
|
@staticmethod
|
||||||
"""Handle source ingestion and return source ID"""
|
def _coerce_uuid_fk(raw) -> Optional[str]:
|
||||||
|
"""Coerce sentinel/blank values to ``None`` for nullable UUID FK columns.
|
||||||
|
|
||||||
|
Mirrors the route-side handling in ``application/api/user/agents/routes.py``:
|
||||||
|
the literal string ``"default"``, empty string, and ``None`` all map
|
||||||
|
to ``None`` so the repository layer skips the column and Postgres
|
||||||
|
keeps the FK NULL (FKs are ``ON DELETE SET NULL``).
|
||||||
|
"""
|
||||||
|
if raw in (None, "", "default"):
|
||||||
|
return None
|
||||||
|
return str(raw)
|
||||||
|
|
||||||
|
def _upsert_agent(
|
||||||
|
self,
|
||||||
|
agent_config: Dict,
|
||||||
|
source_id: Optional[str],
|
||||||
|
tool_ids: List[str],
|
||||||
|
prompt_id: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Create or update a template agent owned by ``__system__``."""
|
||||||
|
name = agent_config["name"]
|
||||||
|
prompt_id_val = self._coerce_uuid_fk(
|
||||||
|
prompt_id if prompt_id is not None else agent_config.get("prompt_id")
|
||||||
|
)
|
||||||
|
folder_id_val = self._coerce_uuid_fk(agent_config.get("folder_id"))
|
||||||
|
workflow_id_val = self._coerce_uuid_fk(agent_config.get("workflow_id"))
|
||||||
|
source_id_val = self._coerce_uuid_fk(source_id)
|
||||||
|
agent_fields = {
|
||||||
|
"description": agent_config["description"],
|
||||||
|
"image": agent_config.get("image", ""),
|
||||||
|
"tools": [str(tid) for tid in tool_ids],
|
||||||
|
"agent_type": agent_config["agent_type"],
|
||||||
|
"prompt_id": prompt_id_val,
|
||||||
|
"chunks": agent_config.get("chunks", "0"),
|
||||||
|
"retriever": agent_config.get("retriever", ""),
|
||||||
|
}
|
||||||
|
if folder_id_val is not None:
|
||||||
|
agent_fields["folder_id"] = folder_id_val
|
||||||
|
if workflow_id_val is not None:
|
||||||
|
agent_fields["workflow_id"] = workflow_id_val
|
||||||
|
if source_id_val is not None:
|
||||||
|
agent_fields["source_id"] = source_id_val
|
||||||
|
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = AgentsRepository(conn)
|
||||||
|
existing = self._find_system_agent_by_name(repo, name)
|
||||||
|
if existing:
|
||||||
|
self.logger.info(f"Updating existing agent: {name}")
|
||||||
|
repo.update(str(existing["id"]), self.system_user_id, agent_fields)
|
||||||
|
self.logger.info(f"Successfully updated agent: {name} (ID: {existing['id']})")
|
||||||
|
else:
|
||||||
|
self.logger.info(f"Creating new agent: {name}")
|
||||||
|
created = repo.create(
|
||||||
|
user_id=self.system_user_id,
|
||||||
|
name=name,
|
||||||
|
status="template",
|
||||||
|
**agent_fields,
|
||||||
|
)
|
||||||
|
self.logger.info(
|
||||||
|
f"Successfully created agent: {name} (ID: {created.get('id')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _find_system_agent_by_name(repo: AgentsRepository, name: str) -> Optional[dict]:
|
||||||
|
"""Find a system-owned agent by name among the template rows."""
|
||||||
|
for row in repo.list_for_user(SYSTEM_USER_ID):
|
||||||
|
if row.get("name") == name:
|
||||||
|
return row
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_source(self, agent_config: Dict):
|
||||||
|
"""Handle source ingestion and return a source id (UUID string) or ``None``/``False``."""
|
||||||
if not agent_config.get("source"):
|
if not agent_config.get("source"):
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"No source provided for agent - will create agent without source"
|
"No source provided for agent - will create agent without source"
|
||||||
@@ -134,14 +189,15 @@ class DatabaseSeeder:
|
|||||||
self.logger.info(f"Ingesting source: {source_config['url']}")
|
self.logger.info(f"Ingesting source: {source_config['url']}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
existing = self.sources_collection.find_one(
|
with db_readonly() as conn:
|
||||||
{"user": self.system_user_id, "remote_data": source_config["url"]}
|
existing = self._find_system_source_by_remote_url(
|
||||||
)
|
SourcesRepository(conn), source_config["url"]
|
||||||
|
)
|
||||||
if existing:
|
if existing:
|
||||||
self.logger.info(f"Source already exists: {existing['_id']}")
|
self.logger.info(f"Source already exists: {existing['id']}")
|
||||||
return existing["_id"]
|
return existing["id"]
|
||||||
# Ingest new source using worker
|
|
||||||
|
|
||||||
|
# Ingest new source using worker
|
||||||
task = ingest_remote.delay(
|
task = ingest_remote.delay(
|
||||||
source_data=source_config["url"],
|
source_data=source_config["url"],
|
||||||
job_name=source_config["name"],
|
job_name=source_config["name"],
|
||||||
@@ -164,9 +220,29 @@ class DatabaseSeeder:
|
|||||||
self.logger.error(f"Failed to ingest source: {str(e)}")
|
self.logger.error(f"Failed to ingest source: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_tools(self, agent_config: Dict) -> List[ObjectId]:
|
@staticmethod
|
||||||
"""Handle tool creation and return list of tool IDs"""
|
def _find_system_source_by_remote_url(
|
||||||
tool_ids = []
|
repo: SourcesRepository, url: str
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Scan system-owned sources for a row whose remote_data matches ``url``."""
|
||||||
|
# TODO(migration-postgres): push this into SourcesRepository once a
|
||||||
|
# remote_data search helper exists; today we keep the scan here to
|
||||||
|
# stay within this slice's boundaries.
|
||||||
|
try:
|
||||||
|
rows = repo.list_for_user(SYSTEM_USER_ID) # type: ignore[attr-defined]
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
for row in rows:
|
||||||
|
remote = row.get("remote_data")
|
||||||
|
if remote == url:
|
||||||
|
return row
|
||||||
|
if isinstance(remote, dict) and remote.get("url") == url:
|
||||||
|
return row
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_tools(self, agent_config: Dict) -> List[str]:
|
||||||
|
"""Handle tool creation and return list of tool ids (UUID strings)."""
|
||||||
|
tool_ids: List[str] = []
|
||||||
if not agent_config.get("tools"):
|
if not agent_config.get("tools"):
|
||||||
return tool_ids
|
return tool_ids
|
||||||
for tool_config in agent_config["tools"]:
|
for tool_config in agent_config["tools"]:
|
||||||
@@ -175,37 +251,43 @@ class DatabaseSeeder:
|
|||||||
processed_config = self._process_config(tool_config.get("config", {}))
|
processed_config = self._process_config(tool_config.get("config", {}))
|
||||||
self.logger.info(f"Processing tool: {tool_name}")
|
self.logger.info(f"Processing tool: {tool_name}")
|
||||||
|
|
||||||
existing = self.tools_collection.find_one(
|
with db_session() as conn:
|
||||||
{
|
repo = UserToolsRepository(conn)
|
||||||
"user": self.system_user_id,
|
existing = self._find_system_tool(
|
||||||
"name": tool_name,
|
repo, tool_name, processed_config
|
||||||
"config": processed_config,
|
)
|
||||||
}
|
if existing:
|
||||||
)
|
self.logger.info(f"Tool already exists: {existing['id']}")
|
||||||
if existing:
|
tool_ids.append(existing["id"])
|
||||||
self.logger.info(f"Tool already exists: {existing['_id']}")
|
continue
|
||||||
tool_ids.append(existing["_id"])
|
created = repo.create(
|
||||||
continue
|
user_id=self.system_user_id,
|
||||||
tool_data = {
|
name=tool_name,
|
||||||
"user": self.system_user_id,
|
display_name=tool_config.get("display_name", tool_name),
|
||||||
"name": tool_name,
|
description=tool_config.get("description", ""),
|
||||||
"displayName": tool_config.get("display_name", tool_name),
|
actions=tool_manager.tools[tool_name].get_actions_metadata(),
|
||||||
"description": tool_config.get("description", ""),
|
config=processed_config,
|
||||||
"actions": tool_manager.tools[tool_name].get_actions_metadata(),
|
status=True,
|
||||||
"config": processed_config,
|
)
|
||||||
"status": True,
|
tool_ids.append(created["id"])
|
||||||
}
|
self.logger.info(f"Created new tool: {created['id']}")
|
||||||
|
|
||||||
result = self.tools_collection.insert_one(tool_data)
|
|
||||||
tool_ids.append(result.inserted_id)
|
|
||||||
self.logger.info(f"Created new tool: {result.inserted_id}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Failed to process tool {tool_name}: {str(e)}")
|
self.logger.error(f"Failed to process tool {tool_name}: {str(e)}")
|
||||||
continue
|
continue
|
||||||
return tool_ids
|
return tool_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _find_system_tool(
|
||||||
|
repo: UserToolsRepository, name: str, config: dict
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Locate a system-owned tool by (name, config) among existing rows."""
|
||||||
|
existing = repo.find_by_user_and_name(SYSTEM_USER_ID, name)
|
||||||
|
if existing and existing.get("config") == config:
|
||||||
|
return existing
|
||||||
|
return None
|
||||||
|
|
||||||
def _handle_prompt(self, agent_config: Dict) -> Optional[str]:
|
def _handle_prompt(self, agent_config: Dict) -> Optional[str]:
|
||||||
"""Handle prompt creation and return prompt ID"""
|
"""Handle prompt creation and return prompt id (UUID string)."""
|
||||||
if not agent_config.get("prompt"):
|
if not agent_config.get("prompt"):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -222,34 +304,20 @@ class DatabaseSeeder:
|
|||||||
self.logger.info(f"Processing prompt: {prompt_name}")
|
self.logger.info(f"Processing prompt: {prompt_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
existing = self.prompts_collection.find_one(
|
with db_session() as conn:
|
||||||
{
|
repo = PromptsRepository(conn)
|
||||||
"user": self.system_user_id,
|
row = repo.find_or_create(
|
||||||
"name": prompt_name,
|
self.system_user_id, prompt_name, prompt_content
|
||||||
"content": prompt_content,
|
)
|
||||||
}
|
prompt_id = str(row["id"])
|
||||||
)
|
self.logger.info(f"Prompt ready: {prompt_id}")
|
||||||
if existing:
|
return prompt_id
|
||||||
self.logger.info(f"Prompt already exists: {existing['_id']}")
|
|
||||||
return str(existing["_id"])
|
|
||||||
|
|
||||||
prompt_data = {
|
|
||||||
"name": prompt_name,
|
|
||||||
"content": prompt_content,
|
|
||||||
"user": self.system_user_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
result = self.prompts_collection.insert_one(prompt_data)
|
|
||||||
prompt_id = str(result.inserted_id)
|
|
||||||
self.logger.info(f"Created new prompt: {prompt_id}")
|
|
||||||
return prompt_id
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Failed to process prompt {prompt_name}: {str(e)}")
|
self.logger.error(f"Failed to process prompt {prompt_name}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _process_config(self, config: Dict) -> Dict:
|
def _process_config(self, config: Dict) -> Dict:
|
||||||
"""Process config values to replace environment variables"""
|
"""Process config values to replace environment variables."""
|
||||||
processed = {}
|
processed = {}
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
if (
|
if (
|
||||||
@@ -264,14 +332,18 @@ class DatabaseSeeder:
|
|||||||
return processed
|
return processed
|
||||||
|
|
||||||
def _is_already_seeded(self) -> bool:
|
def _is_already_seeded(self) -> bool:
|
||||||
"""Check if premade agents already exist"""
|
"""Check if premade (system-owned) agents already exist in Postgres."""
|
||||||
return self.agents_collection.count_documents({"user": self.system_user_id}) > 0
|
with db_readonly() as conn:
|
||||||
|
repo = AgentsRepository(conn)
|
||||||
|
return len(repo.list_for_user(SYSTEM_USER_ID)) > 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize_from_env(cls, worker=None):
|
def initialize_from_env(cls, worker=None):
|
||||||
"""Factory method to create seeder from environment"""
|
"""Factory method to create seeder from environment.
|
||||||
mongo_uri = os.getenv("MONGO_URI", "mongodb://localhost:27017")
|
|
||||||
db_name = os.getenv("MONGO_DB_NAME", "docsgpt")
|
Retained for back compatibility with existing call sites. The
|
||||||
client = MongoClient(mongo_uri)
|
Postgres connection is resolved lazily via the repository layer
|
||||||
db = client[db_name]
|
(``application.storage.db.engine``), so no explicit wiring is
|
||||||
return cls(db)
|
required here.
|
||||||
|
"""
|
||||||
|
return cls()
|
||||||
|
|||||||
10
application/storage/db/__init__.py
Normal file
@@ -0,0 +1,10 @@
"""PostgreSQL storage layer for user-level data.

This package holds the SQLAlchemy Core engine, metadata, repositories, and
migration infrastructure for the user-data Postgres database. It is separate
from ``application/vectorstore/pgvector.py`` — the two may point at the same
cluster or at different clusters depending on operator configuration.

Repository modules are added in later phases
as individual collections are ported.
"""
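To make the new storage package concrete, here is a minimal usage sketch of the pattern the Postgres-native seeder in this diff follows (session helpers plus repositories that take a connection). The names come from the seeder changes above; treat the exact method signatures as assumptions:

from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly, db_session

# Read path: borrow a connection for the duration of the block.
with db_readonly() as conn:
    templates = AgentsRepository(conn).list_for_user("__system__")

# Write path: db_session() is the read-write variant used for create/update.
if templates:
    with db_session() as conn:
        AgentsRepository(conn).update(templates[0]["id"], "__system__", {"chunks": "0"})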
61
application/storage/db/base_repository.py
Normal file
@@ -0,0 +1,61 @@
"""Common helpers shared by all repositories.

Repositories are thin wrappers around SQLAlchemy Core query construction.
They take a ``Connection`` on call and return plain ``dict`` rows during the
Mongo→Postgres cutover so that call sites don't have to change shape. Once
cutover is complete, a follow-up phase may migrate repo return types to
Pydantic DTOs (tracked in the migration plan as a post-migration item).
"""

import re
from typing import Any, Mapping
from uuid import UUID


_UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)


def looks_like_uuid(value: Any) -> bool:
    """Return True if ``value`` is a canonical UUID (string or ``UUID`` instance).

    Used by ``get_any`` accessors to pick the UUID lookup path vs. the
    ``legacy_mongo_id`` fallback during the Mongo→PG cutover window.
    Accepting ``uuid.UUID`` directly matters for callers that receive an
    id straight from a PG column (SQLAlchemy maps ``UUID`` columns to the
    Python ``UUID`` type) — without this, the call falls through to the
    legacy-text lookup and crashes on ``operator does not exist: text = uuid``.
    """
    if isinstance(value, UUID):
        return True
    return isinstance(value, str) and bool(_UUID_RE.match(value))


def row_to_dict(row: Any) -> dict:
    """Convert a SQLAlchemy ``Row`` to a plain dict with Mongo-compatible ids.

    During the migration window, API responses and downstream code still
    expect a string ``_id`` field (matching the Mongo shape). This helper
    normalizes UUID columns to strings and emits both ``id`` and ``_id`` so
    existing serializers keep working unchanged.

    Args:
        row: A SQLAlchemy ``Row`` object, or ``None``.

    Returns:
        A plain dict, or an empty dict if ``row`` is ``None``.
    """
    if row is None:
        return {}

    # Row has a ``._mapping`` attribute exposing a MappingProxy view.
    mapping: Mapping[str, Any] = row._mapping  # type: ignore[attr-defined]
    out = dict(mapping)

    if "id" in out and out["id"] is not None:
        out["id"] = str(out["id"]) if isinstance(out["id"], UUID) else out["id"]
        out["_id"] = out["id"]

    return out
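As a quick illustration of how these helpers are meant to compose inside a repository (only ``legacy_mongo_id`` is taken from the docstring above; the table name and query are made-up examples, not the project's actual repository code):

from sqlalchemy import text

from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.session import db_readonly

def get_agent_any(agent_id) -> dict:
    # Canonical UUIDs go to the UUID primary key; anything else falls back
    # to the legacy Mongo id column kept around for the cutover window.
    column = "id" if looks_like_uuid(agent_id) else "legacy_mongo_id"
    with db_readonly() as conn:
        row = conn.execute(
            text(f"SELECT * FROM agents WHERE {column} = :v"), {"v": str(agent_id)}
        ).first()
    return row_to_dict(row)  # {} when no row; otherwise includes both "id" and "_id"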
320
application/storage/db/bootstrap.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
"""Self-bootstrapping database setup for the DocsGPT user-data Postgres DB.
|
||||||
|
|
||||||
|
On app startup the Flask factory (and Celery worker init) can call
|
||||||
|
:func:`ensure_database_ready` to:
|
||||||
|
|
||||||
|
1. Create the target database if it's missing (dev-friendly; requires the
|
||||||
|
configured role to have ``CREATEDB`` privilege).
|
||||||
|
2. Apply every pending Alembic migration up to ``head``.
|
||||||
|
|
||||||
|
Both steps are gated by settings that default ON for dev convenience and
|
||||||
|
can be turned off in prod (``AUTO_CREATE_DB`` / ``AUTO_MIGRATE``) where
|
||||||
|
schema is managed out-of-band by a deploy pipeline.
|
||||||
|
|
||||||
|
All heavy imports (alembic, psycopg, sqlalchemy.exc sub-symbols) are
|
||||||
|
deferred to inside the function so merely importing this module has no
|
||||||
|
side effects and is cheap for test collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_database_ready(
|
||||||
|
uri: Optional[str],
|
||||||
|
*,
|
||||||
|
create_db: bool,
|
||||||
|
migrate: bool,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Make sure the target Postgres DB exists and is migrated to ``head``.
|
||||||
|
|
||||||
|
This is idempotent and safe to call once per process. Each step is
|
||||||
|
independently gated so prod deployments that manage schema externally
|
||||||
|
can disable the migrate step while still allowing the process to boot
|
||||||
|
against an already-provisioned database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
uri: SQLAlchemy URI for the user-data Postgres database. If
|
||||||
|
``None`` or empty, the function logs and returns — the app
|
||||||
|
supports running without a configured URI for certain dev
|
||||||
|
flows that don't touch user data.
|
||||||
|
create_db: If ``True``, auto-create the database when it's
|
||||||
|
missing. Requires the configured role to have ``CREATEDB``.
|
||||||
|
migrate: If ``True``, run ``alembic upgrade head`` after the
|
||||||
|
database is reachable.
|
||||||
|
logger: Optional logger to use. Defaults to this module's logger.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: Any failure in an explicitly-enabled step is re-raised
|
||||||
|
so the app fails fast rather than booting into a broken state.
|
||||||
|
Missing-role / auth errors surface cleanly without a
|
||||||
|
mis-directed auto-create attempt.
|
||||||
|
"""
|
||||||
|
log = logger or logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if not uri:
|
||||||
|
log.info(
|
||||||
|
"ensure_database_ready: POSTGRES_URI is not set; "
|
||||||
|
"skipping database bootstrap."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if create_db:
|
||||||
|
_ensure_database_exists(uri, log)
|
||||||
|
|
||||||
|
if migrate:
|
||||||
|
_run_migrations(log)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_database_exists(uri: str, log: logging.Logger) -> None:
|
||||||
|
"""Create the target database if a connection reveals it's missing.
|
||||||
|
|
||||||
|
We probe with a lightweight ``connect().close()``. If Postgres
|
||||||
|
reports ``InvalidCatalogName`` (SQLSTATE ``3D000``), we reconnect to
|
||||||
|
the server's ``postgres`` maintenance DB and issue ``CREATE DATABASE``
|
||||||
|
in AUTOCOMMIT mode (required — CREATE DATABASE can't run in a
|
||||||
|
transaction). Any other connection failure (bad host, auth failure,
|
||||||
|
missing role) is re-raised untouched so the operator sees the true
|
||||||
|
cause instead of a mis-directed auto-create attempt.
|
||||||
|
"""
|
||||||
|
# Lazy imports keep module import side-effect free.
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.engine import make_url
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
url = make_url(uri)
|
||||||
|
target_db = url.database
|
||||||
|
if not target_db:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"POSTGRES_URI is missing a database name: {uri!r}. "
|
||||||
|
"Expected something like "
|
||||||
|
"'postgresql+psycopg://user:pass@host:5432/docsgpt'."
|
||||||
|
)
|
||||||
|
|
||||||
|
probe_engine = create_engine(uri, pool_pre_ping=False)
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
conn = probe_engine.connect()
|
||||||
|
except OperationalError as exc:
|
||||||
|
if _is_missing_database(exc):
|
||||||
|
log.info(
|
||||||
|
"ensure_database_ready: database %r is missing; "
|
||||||
|
"creating it...",
|
||||||
|
target_db,
|
||||||
|
)
|
||||||
|
_create_database(url, target_db, log)
|
||||||
|
log.info("ensure_database_ready: database %r ready.", target_db)
|
||||||
|
return
|
||||||
|
# Not a missing-DB error — surface it as-is. This is the path
|
||||||
|
# for bad host/auth/role-missing, and auto-creating would be
|
||||||
|
# actively wrong there.
|
||||||
|
log.error(
|
||||||
|
"ensure_database_ready: cannot connect to Postgres for "
|
||||||
|
"database %r: %s",
|
||||||
|
target_db,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
conn.close()
|
||||||
|
log.info("ensure_database_ready: database %r ready.", target_db)
|
||||||
|
finally:
|
||||||
|
probe_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_database(url, target_db: str, log: logging.Logger) -> None:
|
||||||
|
"""Issue ``CREATE DATABASE`` against the server's ``postgres`` DB.
|
||||||
|
|
||||||
|
Uses AUTOCOMMIT (required by Postgres — ``CREATE DATABASE`` cannot run
|
||||||
|
inside a transaction). The database identifier is quoted via
|
||||||
|
``psycopg.sql.Identifier`` so unusual names (hyphens, reserved words)
|
||||||
|
are handled correctly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Parsed SQLAlchemy URL for the target DB; we reuse
|
||||||
|
host/port/credentials and swap the database to ``postgres``.
|
||||||
|
target_db: The target database name to create.
|
||||||
|
log: Logger for INFO/ERROR breadcrumbs.
|
||||||
|
"""
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||||
|
|
||||||
|
# psycopg is imported lazily — its error classes are the canonical
|
||||||
|
# cause markers Postgres hands us back.
|
||||||
|
import psycopg
|
||||||
|
from psycopg import sql as pg_sql
|
||||||
|
|
||||||
|
maintenance_url = url.set(database="postgres")
|
||||||
|
maintenance_engine = create_engine(
|
||||||
|
maintenance_url,
|
||||||
|
isolation_level="AUTOCOMMIT",
|
||||||
|
pool_pre_ping=False,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with maintenance_engine.connect() as conn:
|
||||||
|
# Use psycopg's Identifier to quote the DB name safely. The
|
||||||
|
# SQL object renders as a literal ``CREATE DATABASE "<name>"``
|
||||||
|
# which SQLAlchemy passes through to psycopg verbatim.
|
||||||
|
stmt = pg_sql.SQL("CREATE DATABASE {}").format(
|
||||||
|
pg_sql.Identifier(target_db)
|
||||||
|
)
|
||||||
|
raw = conn.connection.dbapi_connection # psycopg connection
|
||||||
|
with raw.cursor() as cur:
|
||||||
|
try:
|
||||||
|
cur.execute(stmt)
|
||||||
|
except psycopg.errors.DuplicateDatabase:
|
||||||
|
# Another worker won the race — benign.
|
||||||
|
log.info(
|
||||||
|
"ensure_database_ready: database %r already "
|
||||||
|
"created by a concurrent worker; continuing.",
|
||||||
|
target_db,
|
||||||
|
)
|
||||||
|
except psycopg.errors.InsufficientPrivilege as exc:
|
||||||
|
log.error(
|
||||||
|
"ensure_database_ready: role lacks CREATEDB "
|
||||||
|
"privilege to create %r. Either GRANT CREATEDB "
|
||||||
|
"to the role, create the database manually, or "
|
||||||
|
"set AUTO_CREATE_DB=False and provision it "
|
||||||
|
"out-of-band. See docs/Deploying/Postgres-"
|
||||||
|
"Migration for guidance. Underlying error: %s",
|
||||||
|
target_db,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except (OperationalError, ProgrammingError) as exc:
|
||||||
|
log.error(
|
||||||
|
"ensure_database_ready: failed to create database %r: %s. "
|
||||||
|
"See docs/Deploying/Postgres-Migration for manual setup.",
|
||||||
|
target_db,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
maintenance_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_missing_database(exc: Exception) -> bool:
|
||||||
|
"""Return True if ``exc`` indicates the target database doesn't exist.
|
||||||
|
|
||||||
|
We check three signals in the cause chain:
|
||||||
|
|
||||||
|
1. ``psycopg.errors.InvalidCatalogName`` — the canonical class for
|
||||||
|
SQLSTATE ``3D000`` when raised during a query.
|
||||||
|
2. ``pgcode`` / ``diag.sqlstate`` equal to ``3D000`` — defensive, for
|
||||||
|
driver versions that surface the code on a generic class.
|
||||||
|
3. The canonical server message phrasing ``database "..." does not
|
||||||
|
exist`` — **required** for connection-time failures, because
|
||||||
|
psycopg 3's ``OperationalError`` raised by ``connect()`` does NOT
|
||||||
|
populate ``sqlstate`` (the connection never completed the protocol
|
||||||
|
handshake, so the attributes stay ``None``). The server's error
|
||||||
|
message itself is stable across Postgres versions, so this is a
|
||||||
|
reliable fallback for the only case that matters: DB missing at
|
||||||
|
boot.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import psycopg
|
||||||
|
|
||||||
|
invalid_catalog = psycopg.errors.InvalidCatalogName
|
||||||
|
except Exception: # noqa: BLE001 — defensive; never break on import
|
||||||
|
invalid_catalog = None
|
||||||
|
|
||||||
|
seen: set[int] = set()
|
||||||
|
cursor: Optional[BaseException] = exc
|
||||||
|
while cursor is not None and id(cursor) not in seen:
|
||||||
|
seen.add(id(cursor))
|
||||||
|
if invalid_catalog is not None and isinstance(cursor, invalid_catalog):
|
||||||
|
return True
|
||||||
|
pgcode = getattr(cursor, "pgcode", None) or getattr(
|
||||||
|
getattr(cursor, "diag", None), "sqlstate", None
|
||||||
|
)
|
||||||
|
if pgcode == "3D000":
|
||||||
|
return True
|
||||||
|
msg = str(cursor)
|
||||||
|
if 'database "' in msg and "does not exist" in msg:
|
||||||
|
return True
|
||||||
|
cursor = cursor.__cause__ or cursor.__context__
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _run_migrations(log: logging.Logger) -> None:
|
||||||
|
"""Run ``alembic upgrade head`` against ``POSTGRES_URI``.
|
||||||
|
|
||||||
|
Alembic serializes concurrent workers via its ``alembic_version``
|
||||||
|
table, so no extra application-level locking is needed. Failures are
|
||||||
|
logged and re-raised so the app fails fast.
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Lazy imports — alembic pulls in a fair amount of code.
|
||||||
|
from alembic import command
|
||||||
|
from alembic.config import Config
|
||||||
|
from alembic.runtime.migration import MigrationContext
|
||||||
|
from alembic.script import ScriptDirectory
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
# Mirror the discovery path used by scripts/db/init_postgres.py so
|
||||||
|
# both entry points resolve the same alembic.ini regardless of cwd.
|
||||||
|
alembic_ini = Path(__file__).resolve().parents[2] / "alembic.ini"
|
||||||
|
if not alembic_ini.exists():
|
||||||
|
raise RuntimeError(f"alembic.ini not found at {alembic_ini}")
|
||||||
|
|
||||||
|
cfg = Config(str(alembic_ini))
|
||||||
|
cfg.set_main_option("script_location", str(alembic_ini.parent / "alembic"))
|
||||||
|
|
||||||
|
# Cheap pre-check: if we're already at head, say so explicitly.
|
||||||
|
try:
|
||||||
|
script = ScriptDirectory.from_config(cfg)
|
||||||
|
head_rev = script.get_current_head()
|
||||||
|
url = cfg.get_main_option("sqlalchemy.url")
|
||||||
|
# env.py populates sqlalchemy.url from settings.POSTGRES_URI when
|
||||||
|
# it's imported, but our Config instance hasn't loaded env.py
|
||||||
|
# yet. Fall back to reading settings directly for the precheck.
|
||||||
|
if not url:
|
||||||
|
from application.core.settings import settings as _settings
|
||||||
|
|
||||||
|
url = _settings.POSTGRES_URI
|
||||||
|
current_rev: Optional[str] = None
|
||||||
|
if url:
|
||||||
|
precheck_engine = create_engine(url, pool_pre_ping=False)
|
||||||
|
try:
|
||||||
|
with precheck_engine.connect() as conn:
|
||||||
|
ctx = MigrationContext.configure(conn)
|
||||||
|
current_rev = ctx.get_current_revision()
|
||||||
|
finally:
|
||||||
|
precheck_engine.dispose()
|
||||||
|
if current_rev is not None and current_rev == head_rev:
|
||||||
|
log.info(
|
||||||
|
"ensure_database_ready: migrations already at head (%s); "
|
||||||
|
"nothing to do.",
|
||||||
|
head_rev,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
log.info(
|
||||||
|
"ensure_database_ready: applying Alembic migrations "
|
||||||
|
"(current=%s, target=%s)...",
|
||||||
|
current_rev,
|
||||||
|
head_rev,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001 — precheck is best-effort
|
||||||
|
# If the precheck itself fails we still want to try the upgrade;
|
||||||
|
# alembic will give a more actionable error if something's off.
|
||||||
|
log.info(
|
||||||
|
"ensure_database_ready: revision precheck failed (%s); "
|
||||||
|
"proceeding with upgrade anyway.",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
command.upgrade(cfg, "head")
|
||||||
|
except Exception as exc: # noqa: BLE001 — surface everything
|
||||||
|
log.error(
|
||||||
|
"ensure_database_ready: alembic upgrade failed: %s. "
|
||||||
|
"Check migration logs and DB connectivity; the app will not "
|
||||||
|
"boot until this is resolved (or AUTO_MIGRATE is disabled).",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
log.info("ensure_database_ready: migrations applied.")
|
||||||
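For illustration (not part of the diff): a minimal sketch of how a boot hook could combine the two helpers above. The missing-database check's ``def`` line sits above this excerpt, so the name `_is_database_missing` is assumed here, as is the wrapping of the failure in SQLAlchemy's `OperationalError`.

```python
import logging

from sqlalchemy.exc import OperationalError

log = logging.getLogger(__name__)

try:
    _run_migrations(log)
except OperationalError as exc:
    if _is_database_missing(exc):  # assumed name of the helper above
        # The only boot-time case the docstring cares about: the target
        # database does not exist yet. Create it (or fix POSTGRES_URI),
        # then let the process retry the upgrade.
        log.error("configured Postgres database does not exist yet")
    raise
```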
98  application/storage/db/engine.py  Normal file
@@ -0,0 +1,98 @@
"""SQLAlchemy Core engine factory for the user-data Postgres database.

The engine is lazily constructed on first use and cached as a module-level
singleton. Repositories and the Alembic env module both obtain connections
through this factory, so pool tuning lives in one place.

``POSTGRES_URI`` can be written in any of the common Postgres URI forms::

    postgres://user:pass@host:5432/docsgpt
    postgresql://user:pass@host:5432/docsgpt

Both are accepted and normalized internally to the psycopg3 dialect
(``postgresql+psycopg://``) by ``application.core.settings``. Operators
don't need to know about SQLAlchemy dialect prefixes.
"""

from typing import Optional

from sqlalchemy import Engine, create_engine, event

from application.core.settings import settings

_engine: Optional[Engine] = None


def _resolve_uri() -> str:
    """Return the Postgres URI for user-data tables.

    Raises:
        RuntimeError: If ``settings.POSTGRES_URI`` is unset. Callers that
            reach this path without a configured URI have a setup bug — the
            error message points them at the right setting.
    """
    if not settings.POSTGRES_URI:
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    return settings.POSTGRES_URI


#: Per-statement wall-clock cap applied to every connection handed out by
#: the engine. 30s is generous for interactive hot paths (reads under a few
#: hundred ms are normal) but still catches a runaway query before it
#: stacks up on PgBouncer or holds locks indefinitely.
STATEMENT_TIMEOUT_MS = 30_000


def get_engine() -> Engine:
    """Return the process-wide SQLAlchemy Engine, creating it if needed.

    The engine applies a server-side ``statement_timeout`` to every
    connection it hands out via a ``connect`` event, so both
    :func:`db_session` and :func:`db_readonly` inherit the same
    guardrail.

    Returns:
        A SQLAlchemy ``Engine`` configured with a pooled connection to
        Postgres via psycopg3.
    """
    global _engine
    if _engine is None:
        _engine = create_engine(
            _resolve_uri(),
            pool_size=10,
            max_overflow=20,
            pool_pre_ping=True,  # survive PgBouncer / idle-disconnect recycles
            pool_recycle=1800,
            future=True,
        )

        @event.listens_for(_engine, "connect")
        def _apply_session_guardrails(dbapi_conn, _record):
            # Apply as a SQL ``SET`` (not a libpq ``options=-c ...``
            # startup parameter) so the engine works behind
            # PgBouncer-style poolers — notably Neon's ``-pooler``
            # endpoint, which rejects startup options. Explicit
            # ``commit()`` so the session-level SET survives SA's
            # transaction resets on pool return.
            with dbapi_conn.cursor() as cur:
                cur.execute(f"SET statement_timeout = {STATEMENT_TIMEOUT_MS}")
            dbapi_conn.commit()

    return _engine


def dispose_engine() -> None:
    """Dispose the pooled connections and reset the singleton.

    Called from the Celery ``worker_process_init`` signal so each forked
    worker gets a fresh pool instead of sharing file descriptors with the
    parent process (which corrupts the pool on fork).
    """
    global _engine
    if _engine is not None:
        _engine.dispose()
        _engine = None
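For illustration (not part of the diff): a minimal usage sketch of the engine factory. The `SHOW` output format is what Postgres typically reports for a 30000 ms timeout.

```python
from sqlalchemy import text

from application.storage.db.engine import STATEMENT_TIMEOUT_MS, get_engine

with get_engine().connect() as conn:
    # The "connect" event listener has already issued SET statement_timeout,
    # so a runaway query on this pooled connection is cancelled server-side.
    print(conn.execute(text("SHOW statement_timeout")).scalar())  # typically "30s"
    print(STATEMENT_TIMEOUT_MS)  # 30000
```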
432  application/storage/db/models.py  Normal file
@@ -0,0 +1,432 @@
"""SQLAlchemy Core metadata for the user-data Postgres database.

Tables are added here one at a time as repositories are built during the
MongoDB→Postgres migration. The baseline schema in the Alembic migration
(``application/alembic/versions/0001_initial.py``) is the source of truth
for DDL; the ``Table`` definitions below must match it column-for-column.
If the two drift, migrations win — update this file to match.

Cross-table invariant not expressed in the Core ``Table`` definitions
below: every ``user_id`` column is FK-enforced against
``users(user_id)`` with ``ON DELETE RESTRICT``, and a
``BEFORE INSERT OR UPDATE OF user_id`` trigger on each child table
auto-creates the ``users`` row if it does not yet exist. See migration
``0015_user_id_fk``. The FKs are intentionally omitted from the Core
declarations to keep this file readable; the DB is the authority.
"""

from sqlalchemy import (
    BigInteger,
    Boolean,
    Column,
    DateTime,
    ForeignKey,
    ForeignKeyConstraint,
    Integer,
    MetaData,
    UniqueConstraint,
    Table,
    Text,
    func,
)
from sqlalchemy.dialects.postgresql import ARRAY, CITEXT, JSONB, UUID

metadata = MetaData()


# --- Phase 1, Tier 1 --------------------------------------------------------

users_table = Table(
    "users",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False, unique=True),
    Column(
        "agent_preferences",
        JSONB,
        nullable=False,
        server_default='{"pinned": [], "shared_with_me": []}',
    ),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)

prompts_table = Table(
    "prompts",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("content", Text, nullable=False),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

user_tools_table = Table(
    "user_tools",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("custom_name", Text),
    Column("display_name", Text),
    Column("description", Text),
    Column("config", JSONB, nullable=False, server_default="{}"),
    Column("config_requirements", JSONB, nullable=False, server_default="{}"),
    Column("actions", JSONB, nullable=False, server_default="[]"),
    Column("status", Boolean, nullable=False, server_default="true"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

token_usage_table = Table(
    "token_usage",
    metadata,
    Column("id", BigInteger, primary_key=True, autoincrement=True),
    Column("user_id", Text),
    Column("api_key", Text),
    Column("agent_id", UUID(as_uuid=True)),
    Column("prompt_tokens", Integer, nullable=False, server_default="0"),
    Column("generated_tokens", Integer, nullable=False, server_default="0"),
    Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)

user_logs_table = Table(
    "user_logs",
    metadata,
    Column("id", BigInteger, primary_key=True, autoincrement=True),
    Column("user_id", Text),
    Column("endpoint", Text),
    Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("data", JSONB),
)

stack_logs_table = Table(
    "stack_logs",
    metadata,
    Column("id", BigInteger, primary_key=True, autoincrement=True),
    Column("activity_id", Text, nullable=False),
    Column("endpoint", Text),
    Column("level", Text),
    Column("user_id", Text),
    Column("api_key", Text),
    Column("query", Text),
    Column("stacks", JSONB, nullable=False, server_default="[]"),
    Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
)


# --- Phase 2, Tier 2 --------------------------------------------------------

agent_folders_table = Table(
    "agent_folders",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("description", Text),
    Column("parent_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

sources_table = Table(
    "sources",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("language", Text),
    Column("date", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("model", Text),
    Column("type", Text),
    Column("metadata", JSONB, nullable=False, server_default="{}"),
    Column("retriever", Text),
    Column("sync_frequency", Text),
    Column("tokens", Text),
    Column("file_path", Text),
    Column("remote_data", JSONB),
    Column("directory_structure", JSONB),
    Column("file_name_map", JSONB),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

agents_table = Table(
    "agents",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("description", Text),
    Column("agent_type", Text),
    Column("status", Text, nullable=False),
    Column("key", CITEXT, unique=True),
    Column("image", Text),
    Column("source_id", UUID(as_uuid=True), ForeignKey("sources.id", ondelete="SET NULL")),
    Column("extra_source_ids", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
    Column("chunks", Integer),
    Column("retriever", Text),
    Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
    Column("tools", JSONB, nullable=False, server_default="[]"),
    Column("json_schema", JSONB),
    Column("models", JSONB),
    Column("default_model_id", Text),
    Column("folder_id", UUID(as_uuid=True), ForeignKey("agent_folders.id", ondelete="SET NULL")),
    Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="SET NULL")),
    Column("limited_token_mode", Boolean, nullable=False, server_default="false"),
    Column("token_limit", Integer),
    Column("limited_request_mode", Boolean, nullable=False, server_default="false"),
    Column("request_limit", Integer),
    Column("allow_system_prompt_override", Boolean, nullable=False, server_default="false"),
    Column("shared", Boolean, nullable=False, server_default="false"),
    Column("shared_token", CITEXT, unique=True),
    Column("shared_metadata", JSONB),
    Column("incoming_webhook_token", CITEXT, unique=True),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("last_used_at", DateTime(timezone=True)),
    Column("legacy_mongo_id", Text),
)

attachments_table = Table(
    "attachments",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("filename", Text, nullable=False),
    Column("upload_path", Text, nullable=False),
    Column("mime_type", Text),
    Column("size", BigInteger),
    Column("content", Text),
    Column("token_count", Integer),
    Column("openai_file_id", Text),
    Column("google_file_uri", Text),
    Column("metadata", JSONB),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

memories_table = Table(
    "memories",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
    Column("path", Text, nullable=False),
    Column("content", Text, nullable=False),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    UniqueConstraint("user_id", "tool_id", "path", name="memories_user_tool_path_uidx"),
)

todos_table = Table(
    "todos",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
    Column("todo_id", Integer),
    Column("title", Text, nullable=False),
    Column("completed", Boolean, nullable=False, server_default="false"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

notes_table = Table(
    "notes",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("tool_id", UUID(as_uuid=True), ForeignKey("user_tools.id", ondelete="CASCADE")),
    Column("title", Text, nullable=False),
    Column("content", Text, nullable=False),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    UniqueConstraint("user_id", "tool_id", name="notes_user_tool_uidx"),
)

connector_sessions_table = Table(
    "connector_sessions",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("provider", Text, nullable=False),
    Column("server_url", Text),
    Column("session_token", Text, unique=True),
    Column("user_email", Text),
    Column("status", Text),
    Column("token_info", JSONB),
    Column("session_data", JSONB, nullable=False, server_default="{}"),
    Column("expires_at", DateTime(timezone=True)),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)


# --- Phase 3, Tier 3 --------------------------------------------------------

conversations_table = Table(
    "conversations",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("agent_id", UUID(as_uuid=True), ForeignKey("agents.id", ondelete="SET NULL")),
    Column("name", Text),
    Column("api_key", Text),
    Column("is_shared_usage", Boolean, nullable=False, server_default="false"),
    Column("shared_token", Text),
    Column("shared_with", ARRAY(Text), nullable=False, server_default="{}"),
    Column("compression_metadata", JSONB),
    Column("date", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

conversation_messages_table = Table(
    "conversation_messages",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
    # Denormalised from conversations.user_id. Auto-filled on insert by a
    # BEFORE INSERT trigger when the caller omits it. See migration 0020.
    Column("user_id", Text, nullable=False),
    Column("position", Integer, nullable=False),
    Column("prompt", Text),
    Column("response", Text),
    Column("thought", Text),
    Column("sources", JSONB, nullable=False, server_default="[]"),
    Column("tool_calls", JSONB, nullable=False, server_default="[]"),
    # Postgres cannot FK-enforce array elements, so the referential
    # invariant is kept by an AFTER DELETE trigger on ``attachments``
    # that array_removes the id from every row that references it.
    # See migration 0017_cleanup_dangling_refs.
    Column("attachments", ARRAY(UUID(as_uuid=True)), nullable=False, server_default="{}"),
    Column("model_id", Text),
    # Renamed from ``metadata`` in migration 0016 to avoid SQLAlchemy's
    # reserved attribute collision on declarative models. The repository
    # translates this ↔ API dict key ``metadata`` so external callers
    # still see ``metadata``.
    Column("message_metadata", JSONB, nullable=False, server_default="{}"),
    Column("feedback", JSONB),
    Column("timestamp", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    UniqueConstraint("conversation_id", "position", name="conversation_messages_conv_pos_uidx"),
)

shared_conversations_table = Table(
    "shared_conversations",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("uuid", UUID(as_uuid=True), nullable=False, unique=True),
    Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
    Column("user_id", Text, nullable=False),
    Column("prompt_id", UUID(as_uuid=True), ForeignKey("prompts.id", ondelete="SET NULL")),
    Column("chunks", Integer),
    Column("is_promptable", Boolean, nullable=False, server_default="false"),
    Column("first_n_queries", Integer, nullable=False, server_default="0"),
    Column("api_key", Text),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
)

pending_tool_state_table = Table(
    "pending_tool_state",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("conversation_id", UUID(as_uuid=True), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False),
    Column("user_id", Text, nullable=False),
    Column("messages", JSONB, nullable=False),
    Column("pending_tool_calls", JSONB, nullable=False),
    Column("tools_dict", JSONB, nullable=False),
    Column("tool_schemas", JSONB, nullable=False),
    Column("agent_config", JSONB, nullable=False),
    Column("client_tools", JSONB),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("expires_at", DateTime(timezone=True), nullable=False),
    UniqueConstraint("conversation_id", "user_id", name="pending_tool_state_conv_user_uidx"),
)

workflows_table = Table(
    "workflows",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("user_id", Text, nullable=False),
    Column("name", Text, nullable=False),
    Column("description", Text),
    Column("current_graph_version", Integer, nullable=False, server_default="1"),
    Column("created_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("updated_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("legacy_mongo_id", Text),
)

workflow_nodes_table = Table(
    "workflow_nodes",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
    Column("graph_version", Integer, nullable=False),
    Column("node_id", Text, nullable=False),
    Column("node_type", Text, nullable=False),
    Column("title", Text),
    Column("description", Text),
    Column("position", JSONB, nullable=False, server_default='{"x": 0, "y": 0}'),
    Column("config", JSONB, nullable=False, server_default="{}"),
    Column("legacy_mongo_id", Text),
    # Composite UNIQUE so workflow_edges can use a composite FK that
    # enforces endpoint nodes belong to the same (workflow, version) as
    # the edge itself. See migration 0008.
    UniqueConstraint(
        "id", "workflow_id", "graph_version",
        name="workflow_nodes_id_wf_ver_key",
    ),
)

workflow_edges_table = Table(
    "workflow_edges",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
    Column("graph_version", Integer, nullable=False),
    Column("edge_id", Text, nullable=False),
    Column("from_node_id", UUID(as_uuid=True), nullable=False),
    Column("to_node_id", UUID(as_uuid=True), nullable=False),
    Column("source_handle", Text),
    Column("target_handle", Text),
    Column("config", JSONB, nullable=False, server_default="{}"),
    # Composite FKs: endpoints must belong to the same (workflow, version)
    # as the edge. Prevents cross-workflow / cross-version edges that the
    # single-column FKs couldn't catch. See migration 0008.
    ForeignKeyConstraint(
        ["from_node_id", "workflow_id", "graph_version"],
        ["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
        ondelete="CASCADE",
        name="workflow_edges_from_node_fk",
    ),
    ForeignKeyConstraint(
        ["to_node_id", "workflow_id", "graph_version"],
        ["workflow_nodes.id", "workflow_nodes.workflow_id", "workflow_nodes.graph_version"],
        ondelete="CASCADE",
        name="workflow_edges_to_node_fk",
    ),
)

workflow_runs_table = Table(
    "workflow_runs",
    metadata,
    Column("id", UUID(as_uuid=True), primary_key=True, server_default=func.gen_random_uuid()),
    Column("workflow_id", UUID(as_uuid=True), ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False),
    Column("user_id", Text, nullable=False),
    Column("status", Text, nullable=False),
    Column("inputs", JSONB),
    Column("result", JSONB),
    Column("steps", JSONB, nullable=False, server_default="[]"),
    Column("started_at", DateTime(timezone=True), nullable=False, server_default=func.now()),
    Column("ended_at", DateTime(timezone=True)),
    Column("legacy_mongo_id", Text),
)
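For illustration (not part of the diff): a minimal sketch of querying through the Core metadata above; the user id "local" is invented.

```python
from sqlalchemy import select

from application.storage.db.engine import get_engine
from application.storage.db.models import agents_table

with get_engine().connect() as conn:
    # Newest agents for one user, read straight off the Table definition.
    recent = conn.execute(
        select(agents_table.c.id, agents_table.c.name)
        .where(agents_table.c.user_id == "local")  # hypothetical user id
        .order_by(agents_table.c.created_at.desc())
        .limit(10)
    ).mappings().all()
```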
11  application/storage/db/repositories/__init__.py  Normal file
@@ -0,0 +1,11 @@
"""Repositories for the user-data Postgres database.

Each module in this package exposes exactly one repository class. Repository
methods take a ``Connection`` (either as a constructor argument or as a
method argument) and return plain ``dict`` rows via
``application.storage.db.base_repository.row_to_dict`` during the
MongoDB→Postgres cutover, so call sites don't have to change shape.

Repositories are added one collection at a time, matching the phased
rollout in ``migration-postgres.md``.
"""
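For illustration (not part of the diff): a sketch of the connection-per-unit-of-work pattern the package docstring describes, assuming callers open transactions via `Engine.begin()`; the user id and folder names are invented.

```python
from application.storage.db.engine import get_engine
from application.storage.db.repositories.agent_folders import AgentFoldersRepository

# One transaction per unit of work; the repository hands back plain dicts,
# so route code keeps its Mongo-era shape during the cutover.
with get_engine().begin() as conn:
    repo = AgentFoldersRepository(conn)
    folder = repo.create("local", "Research", description="Scratch space")
    children = repo.list_children(folder["id"], "local")
```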
140  application/storage/db/repositories/agent_folders.py  Normal file
@@ -0,0 +1,140 @@
"""Repository for the ``agent_folders`` table.

Folders are self-referential via ``parent_id`` to model nested folder
hierarchies — a folder can sit inside another folder, and on delete the
DB sets each child's ``parent_id`` to NULL (no cascade) so children
survive their parent's removal but flatten to the top level. The legacy
Mongo route used ``$unset: {parent_id: ""}`` against children before
deleting the parent; that pre-step is no longer needed because the FK
``ON DELETE SET NULL`` action does it automatically.
"""

from __future__ import annotations

from typing import Any, Optional

from sqlalchemy import Connection, func, text

from application.storage.db.base_repository import row_to_dict
from application.storage.db.models import agent_folders_table


_ALLOWED_UPDATE_COLUMNS = {"name", "description", "parent_id"}


class AgentFoldersRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        user_id: str,
        name: str,
        *,
        description: Optional[str] = None,
        parent_id: Optional[str] = None,
        legacy_mongo_id: Optional[str] = None,
    ) -> dict:
        result = self._conn.execute(
            text(
                """
                INSERT INTO agent_folders (
                    user_id, name, description, parent_id, legacy_mongo_id
                )
                VALUES (
                    :user_id, :name, :description,
                    CAST(:parent_id AS uuid), :legacy_mongo_id
                )
                RETURNING *
                """
            ),
            {
                "user_id": user_id,
                "name": name,
                "description": description,
                "parent_id": str(parent_id) if parent_id else None,
                "legacy_mongo_id": legacy_mongo_id,
            },
        )
        return row_to_dict(result.fetchone())

    def get(self, folder_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: Optional[str] = None
    ) -> Optional[dict]:
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        sql = "SELECT * FROM agent_folders WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        result = self._conn.execute(
            text("SELECT * FROM agent_folders WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def list_children(self, parent_id: str, user_id: str) -> list[dict]:
        """List immediate children of ``parent_id`` for nested-folder UIs."""
        result = self._conn.execute(
            text(
                "SELECT * FROM agent_folders "
                "WHERE parent_id = CAST(:parent_id AS uuid) AND user_id = :user_id "
                "ORDER BY created_at"
            ),
            {"parent_id": parent_id, "user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, folder_id: str, user_id: str, fields: dict[str, Any]) -> bool:
        """Partial update.

        The route validates that ``parent_id != folder_id`` (no self-parenting)
        before calling here; this layer does not re-check.
        """
        filtered = {k: v for k, v in fields.items() if k in _ALLOWED_UPDATE_COLUMNS}
        if not filtered:
            return False

        values: dict = {}
        for col, val in filtered.items():
            if col == "parent_id":
                values[col] = str(val) if val else None
            else:
                values[col] = val
        values["updated_at"] = func.now()

        t = agent_folders_table
        stmt = (
            t.update()
            .where(t.c.id == folder_id)
            .where(t.c.user_id == user_id)
            .values(**values)
        )
        result = self._conn.execute(stmt)
        return result.rowcount > 0

    def delete(self, folder_id: str, user_id: str) -> bool:
        """Delete a folder.

        The schema's ``ON DELETE SET NULL`` on the self-FK takes care of
        un-parenting any child folders, and the agents table's
        ``folder_id`` FK does the same for agents in the folder.
        """
        result = self._conn.execute(
            text("DELETE FROM agent_folders WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": folder_id, "user_id": user_id},
        )
        return result.rowcount > 0
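For illustration (not part of the diff): a sketch of the un-parenting behaviour the module docstring describes; the user id and folder names are invented.

```python
from application.storage.db.engine import get_engine
from application.storage.db.repositories.agent_folders import AgentFoldersRepository

with get_engine().begin() as conn:
    repo = AgentFoldersRepository(conn)
    parent = repo.create("local", "Parent")                       # hypothetical user id
    child = repo.create("local", "Child", parent_id=parent["id"])
    repo.delete(parent["id"], "local")
    # The self-FK's ON DELETE SET NULL un-parents the child automatically,
    # so it survives at the top level instead of being deleted.
    assert repo.get(child["id"], "local")["parent_id"] is None
```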
250  application/storage/db/repositories/agents.py  Normal file
@@ -0,0 +1,250 @@
"""Repository for the ``agents`` table.

This is the most complex Phase 2 repository. Covers every write operation
the legacy Mongo code performs on ``agents_collection``:

- create, update, delete
- find by key (API key lookup)
- find by webhook token
- list for user, list templates
- folder assignment
"""

from __future__ import annotations

from typing import Optional

from sqlalchemy import Connection, func, text
from sqlalchemy.dialects.postgresql import insert as pg_insert

from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.models import agents_table


class AgentsRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    @staticmethod
    def _normalize_unique_text(col: str, val):
        """Coerce blank strings for nullable unique text columns to NULL."""
        if col == "key" and val == "":
            return None
        return val

    def create(self, user_id: str, name: str, status: str, **kwargs) -> dict:
        values: dict = {"user_id": user_id, "name": name, "status": status}

        _ALLOWED = {
            "description", "agent_type", "key", "retriever",
            "default_model_id", "incoming_webhook_token",
            "source_id", "prompt_id", "folder_id", "workflow_id",
            "extra_source_ids", "image",
            "chunks", "token_limit", "request_limit",
            "limited_token_mode", "limited_request_mode",
            "allow_system_prompt_override",
            "shared", "shared_token", "shared_metadata",
            "tools", "json_schema", "models", "legacy_mongo_id",
            "created_at", "updated_at", "last_used_at",
        }

        for col, val in kwargs.items():
            if col not in _ALLOWED or val is None:
                continue
            if col in ("tools", "json_schema", "models", "shared_metadata"):
                # JSONB columns: pass the Python object directly. SQLAlchemy
                # Core's JSONB type processor json.dumps it once during
                # bind; pre-serialising would double-encode and the value
                # would round-trip as a JSON string instead of the dict.
                values[col] = val
            elif col in ("chunks", "token_limit", "request_limit"):
                values[col] = int(val)
            elif col in (
                "limited_token_mode", "limited_request_mode",
                "shared", "allow_system_prompt_override",
            ):
                values[col] = bool(val)
            elif col in ("source_id", "prompt_id", "folder_id", "workflow_id"):
                values[col] = str(val)
            elif col == "extra_source_ids":
                # ARRAY(UUID) — pass list of strings; psycopg adapts it.
                values[col] = [str(x) for x in val] if val else []
            else:
                values[col] = self._normalize_unique_text(col, val)

        stmt = pg_insert(agents_table).values(**values).returning(agents_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get(self, agent_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_any(self, agent_id: str, user_id: str) -> Optional[dict]:
        """Resolve an agent by either PG UUID or legacy Mongo ObjectId string.

        Cutover helper: URLs / bookmarks / old client state may still hold
        Mongo ObjectId-strings. Try the UUID path first (the post-cutover
        shape) and fall back to ``legacy_mongo_id`` — both are scoped by
        ``user_id`` so cross-user access is impossible.
        """
        if looks_like_uuid(agent_id):
            row = self.get(agent_id, user_id)
            if row is not None:
                return row
        return self.get_by_legacy_id(agent_id, user_id)

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch an agent by the original Mongo ObjectId string."""
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        sql = "SELECT * FROM agents WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_key(self, key: str) -> Optional[dict]:
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE key = :key"),
            {"key": key},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_shared_token(self, token: str) -> Optional[dict]:
        """Resolve a publicly-shared agent by its rotating share token.

        Only returns rows with ``shared = true`` so revoking a share
        (setting ``shared = false``) immediately stops token access even
        if the token value itself is still in the row.
        """
        result = self._conn.execute(
            text(
                "SELECT * FROM agents "
                "WHERE shared_token = :token AND shared = true"
            ),
            {"token": token},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_by_webhook_token(self, token: str) -> Optional[dict]:
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE incoming_webhook_token = :token"),
            {"token": token},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = :user_id ORDER BY created_at DESC"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def list_templates(self) -> list[dict]:
        result = self._conn.execute(
            text("SELECT * FROM agents WHERE user_id = 'system' ORDER BY name"),
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, agent_id: str, user_id: str, fields: dict) -> bool:
        allowed = {
            "name", "description", "agent_type", "status", "key", "source_id",
            "chunks", "retriever", "prompt_id", "tools", "json_schema", "models",
            "default_model_id", "folder_id", "workflow_id",
            "extra_source_ids", "image",
            "limited_token_mode", "token_limit",
            "limited_request_mode", "request_limit",
            "allow_system_prompt_override",
            "shared", "shared_token", "shared_metadata",
            "incoming_webhook_token", "last_used_at",
        }
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return False

        values: dict = {}
        for col, val in filtered.items():
            if col in ("tools", "json_schema", "models", "shared_metadata"):
                # See note in create(): JSONB columns receive Python
                # objects, the type processor handles serialisation.
                values[col] = val
            elif col in ("source_id", "prompt_id", "folder_id", "workflow_id"):
                values[col] = str(val) if val else None
            elif col == "extra_source_ids":
                values[col] = [str(x) for x in val] if val else []
            elif col in (
                "limited_token_mode", "limited_request_mode",
                "shared", "allow_system_prompt_override",
            ):
                values[col] = bool(val)
            else:
                values[col] = self._normalize_unique_text(col, val)
        values["updated_at"] = func.now()

        t = agents_table
        stmt = (
            t.update()
            .where(t.c.id == agent_id)
            .where(t.c.user_id == user_id)
            .values(**values)
        )
        result = self._conn.execute(stmt)
        return result.rowcount > 0

    def update_by_legacy_id(self, legacy_mongo_id: str, user_id: str, fields: dict) -> bool:
        """Update an agent addressed by the Mongo ObjectId string."""
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        agent = self.get_by_legacy_id(legacy_mongo_id, user_id)
        if agent is None:
            return False
        return self.update(agent["id"], user_id, fields)

    def delete(self, agent_id: str, user_id: str) -> bool:
        result = self._conn.execute(
            text("DELETE FROM agents WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": agent_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
        """Delete an agent addressed by the Mongo ObjectId string."""
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        result = self._conn.execute(
            text(
                "DELETE FROM agents "
                "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
            ),
            {"legacy_id": legacy_mongo_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def set_folder(self, agent_id: str, user_id: str, folder_id: Optional[str]) -> None:
        self._conn.execute(
            text(
                """
                UPDATE agents SET folder_id = CAST(:folder_id AS uuid), updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            {"id": agent_id, "user_id": user_id, "folder_id": folder_id},
        )

    def clear_folder_for_all(self, folder_id: str, user_id: str) -> None:
        """Remove folder assignment from all agents in a folder (used on folder delete)."""
        self._conn.execute(
            text(
                "UPDATE agents SET folder_id = NULL, updated_at = now() "
                "WHERE folder_id = CAST(:folder_id AS uuid) AND user_id = :user_id"
            ),
            {"folder_id": folder_id, "user_id": user_id},
        )
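For illustration (not part of the diff): a sketch of the cutover-era id resolution in `get_any()`; the ids and user id below are invented.

```python
from application.storage.db.engine import get_engine
from application.storage.db.repositories.agents import AgentsRepository

with get_engine().connect() as conn:
    repo = AgentsRepository(conn)
    # Both shapes resolve to the same row during the cutover window.
    a = repo.get_any("7f9a2c1e-9f6d-4a3b-8b1a-0e2d4c6a8b10", "local")  # PG UUID
    b = repo.get_any("64f1b2c3d4e5f6a7b8c9d0e1", "local")              # Mongo ObjectId string
```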
248
application/storage/db/repositories/attachments.py
Normal file
248
application/storage/db/repositories/attachments.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
"""Repository for the ``attachments`` table."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Connection, text
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid, row_to_dict
|
||||||
|
|
||||||
|
|
||||||
|
_UPDATABLE_SCALARS = {
|
||||||
|
"filename", "upload_path", "mime_type", "size",
|
||||||
|
"content", "token_count", "openai_file_id", "google_file_uri",
|
||||||
|
}
|
||||||
|
_UPDATABLE_JSONB = {"metadata"}
|
||||||
|
|
||||||
|
|
||||||
|
class AttachmentsRepository:
|
||||||
|
def __init__(self, conn: Connection) -> None:
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
filename: str,
|
||||||
|
upload_path: str,
|
||||||
|
*,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
|
size: Optional[int] = None,
|
||||||
|
content: Optional[str] = None,
|
||||||
|
token_count: Optional[int] = None,
|
||||||
|
openai_file_id: Optional[str] = None,
|
||||||
|
google_file_uri: Optional[str] = None,
|
||||||
|
metadata: Any = None,
|
||||||
|
legacy_mongo_id: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO attachments (
|
||||||
|
user_id, filename, upload_path, mime_type, size,
|
||||||
|
content, token_count, openai_file_id, google_file_uri,
|
||||||
|
metadata, legacy_mongo_id
|
||||||
|
)
|
||||||
|
VALUES (
|
||||||
|
:user_id, :filename, :upload_path, :mime_type, :size,
|
||||||
|
:content, :token_count, :openai_file_id, :google_file_uri,
|
||||||
|
CAST(:metadata AS jsonb), :legacy_mongo_id
|
||||||
|
)
|
||||||
|
RETURNING *
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"filename": filename,
|
||||||
|
"upload_path": upload_path,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"size": size,
|
||||||
|
"content": content,
|
||||||
|
"token_count": token_count,
|
||||||
|
"openai_file_id": openai_file_id,
|
||||||
|
"google_file_uri": google_file_uri,
|
||||||
|
"metadata": json.dumps(metadata) if metadata is not None else None,
|
||||||
|
"legacy_mongo_id": legacy_mongo_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return row_to_dict(result.fetchone())
|
||||||
|
|
||||||
|
def get(self, attachment_id: str, user_id: str) -> Optional[dict]:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT * FROM attachments WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
{"id": attachment_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
def get_any(self, attachment_id: str, user_id: str) -> Optional[dict]:
|
||||||
|
"""Resolve an attachment by either PG UUID or legacy Mongo ObjectId string."""
|
||||||
|
if looks_like_uuid(attachment_id):
|
||||||
|
row = self.get(attachment_id, user_id)
|
||||||
|
if row is not None:
|
||||||
|
return row
|
||||||
|
return self.get_by_legacy_id(attachment_id, user_id)
|
||||||
|
|
||||||
|
def resolve_ids(self, ids: list[str]) -> dict[str, str]:
|
||||||
|
"""Batch-resolve a list of attachment ids (PG UUID *or* Mongo
|
||||||
|
ObjectId or post-cutover route-minted UUID stored only in
|
||||||
|
``legacy_mongo_id``) to their canonical PG ``attachments.id``.
|
||||||
|
|
||||||
|
Returns a ``{input_id: pg_uuid}`` map. Inputs that don't match
|
||||||
|
any row are simply absent from the map (caller decides whether
|
||||||
|
to drop or keep). Single round-trip via ``= ANY(:ids)`` to
|
||||||
|
avoid N+1.
|
||||||
|
|
||||||
|
Resolution prefers ``legacy_mongo_id`` matches first, since
|
||||||
|
the post-cutover ``/store_attachment`` route mints a UUID that
|
||||||
|
is UUID-shaped but only ever lives in ``legacy_mongo_id``
|
||||||
|
(the row's own ``id`` is a fresh PG-generated UUID). A
|
||||||
|
UUID-shaped input that is *also* a real ``attachments.id``
|
||||||
|
falls back to the direct PK match.
|
||||||
|
"""
|
||||||
|
if not ids:
|
||||||
|
return {}
|
||||||
|
# Deduplicate while preserving order for stable output mapping.
|
||||||
|
unique_ids: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for raw in ids:
|
||||||
|
if raw is None:
|
||||||
|
continue
|
||||||
|
s = str(raw)
|
||||||
|
if s in seen:
|
||||||
|
continue
|
||||||
|
seen.add(s)
|
||||||
|
unique_ids.append(s)
|
||||||
|
if not unique_ids:
|
||||||
|
return {}
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT id::text AS id, legacy_mongo_id "
|
||||||
|
"FROM attachments "
|
||||||
|
"WHERE legacy_mongo_id = ANY(:ids) "
|
||||||
|
"OR id::text = ANY(:ids)"
|
||||||
|
),
|
||||||
|
{"ids": unique_ids},
|
||||||
|
)
|
||||||
|
rows = result.fetchall()
|
||||||
|
# Build two indexes so we can apply the legacy-first preference.
|
||||||
|
by_legacy: dict[str, str] = {}
|
||||||
|
by_pk: dict[str, str] = {}
|
||||||
|
for row in rows:
|
||||||
|
pg_id = str(row[0])
|
||||||
|
legacy = row[1]
|
||||||
|
by_pk[pg_id] = pg_id
|
||||||
|
if legacy is not None:
|
||||||
|
by_legacy[str(legacy)] = pg_id
|
||||||
|
out: dict[str, str] = {}
|
||||||
|
for s in unique_ids:
|
||||||
|
if s in by_legacy:
|
||||||
|
out[s] = by_legacy[s]
|
||||||
|
elif s in by_pk:
|
||||||
|
out[s] = by_pk[s]
|
||||||
|
return out
|
||||||
|
|
||||||
|
def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
|
||||||
|
"""Fetch an attachment by the original Mongo ObjectId string."""
|
||||||
|
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||||
|
sql = "SELECT * FROM attachments WHERE legacy_mongo_id = :legacy_id"
|
||||||
|
params: dict[str, str] = {"legacy_id": legacy_mongo_id}
|
||||||
|
if user_id is not None:
|
||||||
|
sql += " AND user_id = :user_id"
|
||||||
|
params["user_id"] = user_id
|
||||||
|
result = self._conn.execute(text(sql), params)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
def list_for_user(self, user_id: str) -> list[dict]:
|
||||||
|
result = self._conn.execute(
|
||||||
|
text("SELECT * FROM attachments WHERE user_id = :user_id ORDER BY created_at DESC"),
|
||||||
|
{"user_id": user_id},
|
||||||
|
)
|
||||||
|
return [row_to_dict(r) for r in result.fetchall()]
|
||||||
|
|
||||||
|
def update(self, attachment_id: str, user_id: str, fields: dict) -> bool:
|
||||||
|
"""Partial update. Used by the LLM providers to cache their
|
||||||
|
uploaded file IDs (``openai_file_id`` / ``google_file_uri``) so we
|
||||||
|
don't re-upload the same blob every call.
|
||||||
|
"""
|
||||||
|
filtered = {
|
||||||
|
k: v for k, v in fields.items()
|
||||||
|
if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
|
||||||
|
}
|
||||||
|
if not filtered:
|
||||||
|
return False
|
||||||
|
set_clauses: list[str] = []
|
||||||
|
params: dict = {"id": attachment_id, "user_id": user_id}
|
||||||
|
for col, val in filtered.items():
|
||||||
|
if col in _UPDATABLE_JSONB:
|
||||||
|
set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||||
|
params[col] = json.dumps(val) if val is not None else None
|
||||||
|
else:
|
||||||
|
set_clauses.append(f"{col} = :{col}")
|
||||||
|
params[col] = val
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
f"UPDATE attachments SET {', '.join(set_clauses)} "
|
||||||
|
"WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
return result.rowcount > 0
|
||||||
|
|
||||||
|
def update_any(self, attachment_id: str, user_id: str, fields: dict) -> bool:
|
||||||
|
"""Partial update addressed by either PG UUID or legacy Mongo ObjectId.
|
||||||
|
|
||||||
|
Cutover helper used by the LLM provider file-ID caching hot path:
|
||||||
|
the attachment dict in hand may carry a UUID (post-cutover shape)
|
||||||
|
or an ObjectId-string ``_id`` (legacy). Try the UUID path first
|
||||||
|
when the id looks like a UUID; otherwise fall back to the
|
||||||
|
``legacy_mongo_id`` update. Both branches are user-scoped: the
|
||||||
|
caller must pass the authenticated ``user_id`` so cross-tenant
|
||||||
|
writes are prevented even when the fallback legacy path fires.
|
||||||
|
"""
|
||||||
|
if looks_like_uuid(attachment_id):
|
||||||
|
if self.update(attachment_id, user_id, fields):
|
||||||
|
return True
|
||||||
|
return self.update_by_legacy_id(attachment_id, user_id, fields)
|
||||||
|
|
||||||
|
def update_by_legacy_id(
|
||||||
|
self, legacy_mongo_id: str, user_id: str, fields: dict
|
||||||
|
) -> bool:
|
||||||
|
"""Like ``update`` but addressed by the Mongo ObjectId string.
|
||||||
|
|
||||||
|
Used by the LLM file-ID caching path which, at dual-write time,
|
||||||
|
only has the Mongo ``_id`` in hand (the PG UUID hasn't been
|
||||||
|
looked up yet). Scoped by ``user_id`` so a caller that happens to
|
||||||
|
pass an id matching another user's ``legacy_mongo_id`` cannot
|
||||||
|
mutate the wrong row (IDOR).
|
||||||
|
"""
|
||||||
|
if user_id is None:
|
||||||
|
return False
|
||||||
|
legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
|
||||||
|
filtered = {
|
||||||
|
k: v for k, v in fields.items()
|
||||||
|
if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
|
||||||
|
}
|
||||||
|
if not filtered:
|
||||||
|
return False
|
||||||
|
set_clauses: list[str] = []
|
||||||
|
params: dict = {"legacy_id": legacy_mongo_id, "user_id": user_id}
|
||||||
|
for col, val in filtered.items():
|
||||||
|
if col in _UPDATABLE_JSONB:
|
||||||
|
set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
|
||||||
|
params[col] = json.dumps(val) if val is not None else None
|
||||||
|
else:
|
||||||
|
set_clauses.append(f"{col} = :{col}")
|
||||||
|
params[col] = val
|
||||||
|
result = self._conn.execute(
|
||||||
|
text(
|
||||||
|
f"UPDATE attachments SET {', '.join(set_clauses)} "
|
||||||
|
"WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
|
||||||
|
),
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
return result.rowcount > 0
|
||||||
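A minimal usage sketch of the file-ID caching path described above (editor's illustration, not part of this diff); the attachment dict shape and the `uploaded_file.id` value are assumptions:

# Sketch only — assumes `conn` is an open SQLAlchemy connection and that the
# attachment dict may carry either a PG UUID ("id") or a Mongo ObjectId ("_id").
attachment_id = attachment.get("id") or attachment.get("_id")
AttachmentsRepository(conn).update_any(
    str(attachment_id),
    user_id,
    {"openai_file_id": uploaded_file.id},  # hypothetical provider upload result
)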
application/storage/db/repositories/connector_sessions.py (new file, 316 lines)
@@ -0,0 +1,316 @@

"""Repository for the ``connector_sessions`` table.

Shape notes:

* OAuth connectors (Google Drive, SharePoint, Confluence) write one row
  per ``(user_id, provider)`` with ``server_url = NULL``. The primary
  lookup key post-callback is ``session_token`` (see
  ``complete_oauth`` style routes), so the table has a standalone
  unique constraint on ``session_token``.
* MCP sessions key off ``server_url`` instead — a single user may have
  multiple MCP servers, one row each. The composite unique index
  ``(user_id, COALESCE(server_url, ''), provider)`` makes both patterns
  coexist without collision.
* ``session_data`` remains a catch-all JSONB for driver-specific state
  (tokens that don't fit anywhere else, per-provider scratch data).
  Promoted columns (``session_token``, ``user_email``, ``status``,
  ``token_info``) are the ones route/auth code queries by.
"""

from __future__ import annotations

import json
from typing import Any, Optional

from sqlalchemy import Connection, text

from application.storage.db.base_repository import row_to_dict


_UPDATABLE_SCALARS = {
    "server_url", "session_token", "user_email", "status", "expires_at",
}
_UPDATABLE_JSONB = {"session_data", "token_info"}


def _jsonb(value: Any) -> Any:
    if value is None:
        return None
    return json.dumps(value, default=str)


class ConnectorSessionsRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def upsert(
        self,
        user_id: str,
        provider: str,
        session_data: Optional[dict] = None,
        *,
        server_url: Optional[str] = None,
        session_token: Optional[str] = None,
        user_email: Optional[str] = None,
        status: Optional[str] = None,
        token_info: Optional[dict] = None,
        expires_at: Any = None,
        legacy_mongo_id: Optional[str] = None,
    ) -> dict:
        """Insert or update a connector session row.

        Conflict key is ``(user_id, COALESCE(server_url, ''), provider)``
        so MCP rows (per-server) and OAuth rows (per-provider) both get
        idempotent upsert semantics.
        """
        result = self._conn.execute(
            text(
                """
                INSERT INTO connector_sessions (
                    user_id, provider, server_url, session_token, user_email,
                    status, token_info, session_data, expires_at, legacy_mongo_id
                )
                VALUES (
                    :user_id, :provider, :server_url, :session_token, :user_email,
                    :status, CAST(:token_info AS jsonb),
                    CAST(:session_data AS jsonb), :expires_at, :legacy_mongo_id
                )
                ON CONFLICT (user_id, COALESCE(server_url, ''), provider)
                DO UPDATE SET
                    session_token = COALESCE(EXCLUDED.session_token, connector_sessions.session_token),
                    user_email = COALESCE(EXCLUDED.user_email, connector_sessions.user_email),
                    status = COALESCE(EXCLUDED.status, connector_sessions.status),
                    token_info = COALESCE(EXCLUDED.token_info, connector_sessions.token_info),
                    session_data = EXCLUDED.session_data,
                    expires_at = COALESCE(EXCLUDED.expires_at, connector_sessions.expires_at)
                RETURNING *
                """
            ),
            {
                "user_id": user_id,
                "provider": provider,
                "server_url": server_url,
                "session_token": session_token,
                "user_email": user_email,
                "status": status,
                "token_info": _jsonb(token_info),
                "session_data": _jsonb(session_data or {}),
                "expires_at": expires_at,
                "legacy_mongo_id": legacy_mongo_id,
            },
        )
        return row_to_dict(result.fetchone())

    def get_by_user_provider(
        self, user_id: str, provider: str, *, server_url: Optional[str] = None,
    ) -> Optional[dict]:
        """Legacy (user_id, provider) lookup, optionally scoped by server_url.

        Kept for OAuth providers that only have one row per user — they
        pass ``server_url=None`` and get the single OAuth row.
        """
        sql = (
            "SELECT * FROM connector_sessions "
            "WHERE user_id = :user_id AND provider = :provider"
        )
        params: dict[str, Any] = {"user_id": user_id, "provider": provider}
        if server_url is not None:
            sql += " AND server_url = :server_url"
            params["server_url"] = server_url
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_session_token(self, session_token: str) -> Optional[dict]:
        """Post-OAuth-callback lookup.

        Every OAuth flow (Google Drive, SharePoint, Confluence) redirects
        back with the ``session_token`` as the only handle; the callback
        route resolves it to the full session row.
        """
        result = self._conn.execute(
            text("SELECT * FROM connector_sessions WHERE session_token = :token"),
            {"token": session_token},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_user_and_server_url(
        self, user_id: str, server_url: str,
    ) -> Optional[dict]:
        """MCP-tool lookup: resolve a session by the MCP server URL."""
        result = self._conn.execute(
            text(
                "SELECT * FROM connector_sessions "
                "WHERE user_id = :user_id AND server_url = :server_url "
                "LIMIT 1"
            ),
            {"user_id": user_id, "server_url": server_url},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: Optional[str] = None,
    ) -> Optional[dict]:
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        sql = "SELECT * FROM connector_sessions WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        result = self._conn.execute(
            text("SELECT * FROM connector_sessions WHERE user_id = :user_id"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, session_id: str, fields: dict) -> bool:
        """Partial update by PG UUID."""
        filtered = {
            k: v for k, v in fields.items()
            if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
        }
        if not filtered:
            return False
        set_clauses: list[str] = []
        params: dict = {"id": session_id}
        for col, val in filtered.items():
            if col in _UPDATABLE_JSONB:
                set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
                params[col] = _jsonb(val)
            else:
                set_clauses.append(f"{col} = :{col}")
                params[col] = val
        result = self._conn.execute(
            text(
                f"UPDATE connector_sessions SET {', '.join(set_clauses)} "
                "WHERE id = CAST(:id AS uuid)"
            ),
            params,
        )
        return result.rowcount > 0

    def update_by_legacy_id(self, legacy_mongo_id: str, fields: dict) -> bool:
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        filtered = {
            k: v for k, v in fields.items()
            if k in _UPDATABLE_SCALARS | _UPDATABLE_JSONB
        }
        if not filtered:
            return False
        set_clauses: list[str] = []
        params: dict = {"legacy_id": legacy_mongo_id}
        for col, val in filtered.items():
            if col in _UPDATABLE_JSONB:
                set_clauses.append(f"{col} = CAST(:{col} AS jsonb)")
                params[col] = _jsonb(val)
            else:
                set_clauses.append(f"{col} = :{col}")
                params[col] = val
        result = self._conn.execute(
            text(
                f"UPDATE connector_sessions SET {', '.join(set_clauses)} "
                "WHERE legacy_mongo_id = :legacy_id"
            ),
            params,
        )
        return result.rowcount > 0

    def merge_session_data(
        self,
        user_id: str,
        provider: str,
        server_url: Optional[str],
        patch: dict,
    ) -> dict:
        """Upsert by shallow-merging ``patch`` into ``session_data``.

        Writes ``server_url`` to the scalar column so downstream
        ``get_by_user_and_server_url`` lookups can find the row. If
        ``patch`` still carries a ``"server_url"`` key (legacy callers)
        it is stripped before merging so the scalar column stays the
        single source of truth and we don't duplicate it inside the
        JSONB blob.

        Args:
            user_id: Owner of the session.
            provider: Provider tag (e.g. ``"mcp:<base_url>"`` for MCP).
            server_url: Endpoint to pin the row to. ``None`` is valid
                for single-row-per-user OAuth providers.
            patch: Shallow-merge payload for ``session_data``. Keys
                mapped to ``None`` are *dropped* from the stored doc
                (used by the redirect-URI-mismatch clear path).

        Returns:
            The upserted row as a dict.

        Notes:
            The conflict target matches the table's composite unique
            constraint ``(user_id, COALESCE(server_url, ''), provider)``
            so MCP's per-URL rows and OAuth's single-row-per-user rows
            both upsert idempotently.
        """
        # Defensively strip ``server_url`` from ``patch`` — the scalar
        # column is authoritative now. Callers still pass it for
        # backwards compatibility during the transition.
        patch = {k: v for k, v in patch.items() if k != "server_url"}
        set_entries = {k: v for k, v in patch.items() if v is not None}
        drop_keys = [k for k, v in patch.items() if v is None]

        result = self._conn.execute(
            text(
                """
                INSERT INTO connector_sessions (
                    user_id, provider, server_url, session_data
                )
                VALUES (
                    :user_id, :provider, :server_url,
                    CAST(:patch AS jsonb)
                )
                ON CONFLICT (user_id, COALESCE(server_url, ''), provider)
                DO UPDATE SET
                    server_url = COALESCE(EXCLUDED.server_url, connector_sessions.server_url),
                    session_data =
                        (connector_sessions.session_data || EXCLUDED.session_data)
                        - CAST(:drop_keys AS text[])
                RETURNING *
                """
            ),
            {
                "user_id": user_id,
                "provider": provider,
                "server_url": server_url,
                "patch": json.dumps(set_entries),
                "drop_keys": "{" + ",".join(f'"{k}"' for k in drop_keys) + "}",
            },
        )
        return row_to_dict(result.fetchone())

    def delete(
        self, user_id: str, provider: str, *, server_url: Optional[str] = None,
    ) -> bool:
        sql = (
            "DELETE FROM connector_sessions "
            "WHERE user_id = :user_id AND provider = :provider"
        )
        params: dict[str, Any] = {"user_id": user_id, "provider": provider}
        if server_url is not None:
            sql += " AND server_url = :server_url"
            params["server_url"] = server_url
        result = self._conn.execute(text(sql), params)
        return result.rowcount > 0

    def delete_by_session_token(self, session_token: str) -> bool:
        result = self._conn.execute(
            text(
                "DELETE FROM connector_sessions WHERE session_token = :token"
            ),
            {"token": session_token},
        )
        return result.rowcount > 0
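A hedged usage sketch of ``merge_session_data`` for an MCP session (editor's illustration, not part of this diff); the provider tag, URLs, and ``verifier`` value are placeholder assumptions:

# Sketch only — persists PKCE scratch state for an MCP server and drops a stale key.
repo = ConnectorSessionsRepository(conn)
repo.merge_session_data(
    user_id=user_id,
    provider="mcp:https://mcp.example.com",
    server_url="https://mcp.example.com",
    patch={"code_verifier": verifier, "redirect_uri": None},  # None drops the key
)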
application/storage/db/repositories/conversations.py (new file, 686 lines)
@@ -0,0 +1,686 @@

"""Repository for the ``conversations`` and ``conversation_messages`` tables.

Covers every operation the legacy Mongo code performs on
``conversations_collection``:

- create / get / list / delete conversations
- append message (transactional position allocation)
- update message at index (overwrite + optional truncation)
- set / unset feedback on a message
- rename conversation
- update compression metadata
- shared_with access checks
"""

from __future__ import annotations

import json
from typing import Optional

from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert

from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.models import conversations_table, conversation_messages_table


def _message_row_to_dict(row) -> dict:
    """Like ``row_to_dict`` but renames the DB column ``message_metadata``
    back to the public API key ``metadata`` so callers keep the Mongo-era
    shape. See migration 0016 for the column rename rationale."""
    out = row_to_dict(row)
    if "message_metadata" in out:
        out["metadata"] = out.pop("message_metadata")
    return out


class ConversationsRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    # ------------------------------------------------------------------
    # Reference translation helpers
    # ------------------------------------------------------------------
    #
    # During the Mongo→Postgres dual-write window, callers routinely
    # hand us Mongo ObjectId strings (24-char hex) for fields that are
    # UUID FKs in Postgres (``agent_id``, ``attachments`` entries, ...).
    # Casting those straight to ``uuid`` raises and the outer dual-write
    # shim swallows the exception, so the write silently drops. These
    # helpers translate via the ``legacy_mongo_id`` columns we added
    # precisely for this purpose.

    def _resolve_agent_ref(self, agent_id_raw: str | None) -> str | None:
        """Translate ``agent_id_raw`` to a Postgres UUID string.

        - ``None``/empty → ``None`` (no agent).
        - Already-UUID-shaped → returned as-is.
        - Otherwise treated as a Mongo ObjectId and looked up via
          ``agents.legacy_mongo_id``. Returns ``None`` if no PG row
          exists yet (e.g. the agent was created before Phase 1
          backfill).
        """
        if not agent_id_raw:
            return None
        value = str(agent_id_raw)
        if looks_like_uuid(value):
            return value
        result = self._conn.execute(
            text("SELECT id FROM agents WHERE legacy_mongo_id = :lid LIMIT 1"),
            {"lid": value},
        )
        row = result.fetchone()
        return str(row[0]) if row is not None else None

    def _resolve_attachment_refs(
        self, ids: list[str] | None,
    ) -> list[str]:
        """Translate a list of attachment ids to canonical PG
        ``attachments.id`` UUIDs.

        Inputs may be:

        - A Mongo ObjectId string (24-hex), legacy dual-write era —
          must be looked up via ``attachments.legacy_mongo_id``.
        - A UUID string that is a real ``attachments.id`` PK.
        - A UUID string that is *only* present as
          ``attachments.legacy_mongo_id`` — this is the post-cutover
          shape: ``/store_attachment`` mints a UUID, hands it to the
          worker, and the worker stashes it in ``legacy_mongo_id``
          while the row gets a freshly-generated PK. Trusting the
          input UUID as a PK here orphans the array entry: the column
          is ``uuid[]`` (no FK), so PG accepts the bad value and all
          downstream reads via ``AttachmentsRepository.get_any`` miss.

        Resolution therefore tries ``legacy_mongo_id`` first for every
        id (UUID-shaped or not), then falls back to the direct PK
        match. Unknown ids are dropped — they'd have failed the
        ``uuid[]`` cast otherwise and the whole row would have vanished
        via dual-write's exception swallow.
        """
        if not ids:
            return []
        # Defer to AttachmentsRepository for the batched lookup so the
        # legacy-first semantics live in one place.
        from application.storage.db.repositories.attachments import (
            AttachmentsRepository,
        )

        clean: list[str] = [str(raw) for raw in ids if raw is not None]
        if not clean:
            return []
        repo = AttachmentsRepository(self._conn)
        mapping = repo.resolve_ids(clean)
        out: list[str] = []
        for value in clean:
            mapped = mapping.get(value)
            if mapped is not None:
                out.append(mapped)
        return out

    # ------------------------------------------------------------------
    # Conversation CRUD
    # ------------------------------------------------------------------

    def create(
        self,
        user_id: str,
        name: str | None = None,
        *,
        agent_id: str | None = None,
        api_key: str | None = None,
        is_shared_usage: bool = False,
        shared_token: str | None = None,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        """Create a new conversation.

        ``legacy_mongo_id`` is used by the dual-write shim so that a
        Postgres row inserted *after* a successful Mongo insert carries
        the Mongo ``_id`` as a lookup key. Subsequent appends/updates
        can then resolve the PG row by that id via
        :meth:`get_by_legacy_id`.
        """
        values: dict = {
            "user_id": user_id,
            "name": name,
        }
        # ``agent_id`` may arrive as a Mongo ObjectId during the dual-write
        # window; resolve to a UUID (or drop silently if not yet backfilled).
        resolved_agent_id = self._resolve_agent_ref(agent_id)
        if resolved_agent_id:
            values["agent_id"] = resolved_agent_id
        if api_key:
            values["api_key"] = api_key
        if is_shared_usage:
            values["is_shared_usage"] = True
        if shared_token:
            values["shared_token"] = shared_token
        if legacy_mongo_id:
            values["legacy_mongo_id"] = legacy_mongo_id

        stmt = pg_insert(conversations_table).values(**values).returning(conversations_table)
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: str | None = None,
    ) -> Optional[dict]:
        """Look up a conversation by the original Mongo ObjectId string.

        Used by the dual-write helpers to translate a Mongo ``_id`` into
        the Postgres UUID for follow-up writes. When ``user_id`` is
        provided, the lookup is scoped to rows owned by that user so
        callers can't accidentally resolve another user's conversation.
        """
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        if user_id is not None:
            result = self._conn.execute(
                text(
                    "SELECT * FROM conversations "
                    "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
                ),
                {"legacy_id": legacy_mongo_id, "user_id": user_id},
            )
        else:
            result = self._conn.execute(
                text(
                    "SELECT * FROM conversations WHERE legacy_mongo_id = :legacy_id"
                ),
                {"legacy_id": legacy_mongo_id},
            )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get(self, conversation_id: str, user_id: str) -> Optional[dict]:
        """Fetch a conversation the user owns or has shared access to."""
        result = self._conn.execute(
            text(
                "SELECT * FROM conversations "
                "WHERE id = CAST(:id AS uuid) "
                "AND (user_id = :user_id OR :user_id = ANY(shared_with))"
            ),
            {"id": conversation_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_any(self, conversation_id: str, user_id: str) -> Optional[dict]:
        """Resolve a conversation by either PG UUID or legacy Mongo ObjectId string.

        Returns a conversation the user owns or has shared access to.
        """
        if looks_like_uuid(conversation_id):
            row = self.get(conversation_id, user_id)
            if row is not None:
                return row
        return self.get_by_legacy_id(conversation_id, user_id)

    def get_owned(self, conversation_id: str, user_id: str) -> Optional[dict]:
        """Fetch a conversation owned by the user (no shared access)."""
        result = self._conn.execute(
            text(
                "SELECT * FROM conversations "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str, limit: int = 30) -> list[dict]:
        """List conversations for a user, most recent first.

        Mirrors the Mongo query: either no api_key or agent_id exists.
        """
        result = self._conn.execute(
            text(
                "SELECT * FROM conversations "
                "WHERE user_id = :user_id "
                "AND (api_key IS NULL OR agent_id IS NOT NULL) "
                "ORDER BY date DESC LIMIT :limit"
            ),
            {"user_id": user_id, "limit": limit},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def rename(self, conversation_id: str, user_id: str, name: str) -> bool:
        # Shape-gate so a non-UUID id (legacy Mongo ObjectId still floating
        # around in client-side state during the cutover) never reaches the
        # ``CAST(:id AS uuid)`` — that cast raises on the server and poisons
        # the enclosing transaction, making every subsequent query on the
        # same connection fail.
        if not looks_like_uuid(conversation_id):
            return False
        result = self._conn.execute(
            text(
                "UPDATE conversations SET name = :name, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id, "name": name},
        )
        return result.rowcount > 0

    def add_shared_user(self, conversation_id: str, user_to_add: str) -> bool:
        """Idempotently append ``user_to_add`` to ``shared_with``.

        Accepts either a PG UUID or a legacy Mongo ObjectId as the
        conversation id. Mirrors Mongo ``$addToSet`` semantics via the
        ``NOT (:user = ANY(shared_with))`` guard.
        """
        if not user_to_add:
            return False
        if looks_like_uuid(conversation_id):
            sql = (
                "UPDATE conversations "
                "SET shared_with = array_append(shared_with, :user), "
                "    updated_at = now() "
                "WHERE id = CAST(:id AS uuid) "
                "AND NOT (:user = ANY(shared_with))"
            )
        else:
            sql = (
                "UPDATE conversations "
                "SET shared_with = array_append(shared_with, :user), "
                "    updated_at = now() "
                "WHERE legacy_mongo_id = :id "
                "AND NOT (:user = ANY(shared_with))"
            )
        result = self._conn.execute(
            text(sql), {"id": conversation_id, "user": user_to_add},
        )
        return result.rowcount > 0

    def remove_shared_user(self, conversation_id: str, user_to_remove: str) -> bool:
        """Remove ``user_to_remove`` from ``shared_with``. Mirror of Mongo ``$pull``."""
        if not user_to_remove:
            return False
        if looks_like_uuid(conversation_id):
            sql = (
                "UPDATE conversations "
                "SET shared_with = array_remove(shared_with, :user), "
                "    updated_at = now() "
                "WHERE id = CAST(:id AS uuid) "
                "AND :user = ANY(shared_with)"
            )
        else:
            sql = (
                "UPDATE conversations "
                "SET shared_with = array_remove(shared_with, :user), "
                "    updated_at = now() "
                "WHERE legacy_mongo_id = :id "
                "AND :user = ANY(shared_with)"
            )
        result = self._conn.execute(
            text(sql), {"id": conversation_id, "user": user_to_remove},
        )
        return result.rowcount > 0

    def set_shared_token(self, conversation_id: str, user_id: str, token: str) -> bool:
        # Shape-gate: see ``rename`` — prevents transaction poisoning when
        # a non-UUID id reaches this code path.
        if not looks_like_uuid(conversation_id):
            return False
        result = self._conn.execute(
            text(
                "UPDATE conversations SET shared_token = :token, updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id, "token": token},
        )
        return result.rowcount > 0

    def update_compression_metadata(
        self, conversation_id: str, user_id: str, metadata: dict,
    ) -> bool:
        """Replace the entire ``compression_metadata`` JSONB blob.

        Prefer :meth:`append_compression_point` + :meth:`set_compression_flags`
        to match the Mongo service semantics exactly (those two mirror
        ``$set`` + ``$push $slice``). This method is retained for callers
        that already compute the full merged blob client-side.
        """
        # Shape-gate: see ``rename`` — prevents transaction poisoning when
        # a non-UUID id reaches this code path.
        if not looks_like_uuid(conversation_id):
            return False
        result = self._conn.execute(
            text(
                "UPDATE conversations "
                "SET compression_metadata = CAST(:meta AS jsonb), updated_at = now() "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id, "meta": json.dumps(metadata)},
        )
        return result.rowcount > 0

    def set_compression_flags(
        self,
        conversation_id: str,
        *,
        is_compressed: bool,
        last_compression_at,
    ) -> bool:
        """Update ``compression_metadata.is_compressed`` and
        ``compression_metadata.last_compression_at`` without touching
        ``compression_points``.

        Mirrors the Mongo ``$set`` on those two subfields in
        ``ConversationService.update_compression_metadata``. Initialises
        the surrounding object when the row has no ``compression_metadata``
        yet.
        """
        # Shape-gate: the streaming pipeline may pass through a legacy id
        # that ``get_by_legacy_id`` couldn't resolve; in that case the id
        # remains a non-UUID string and the CAST would poison the txn.
        if not looks_like_uuid(conversation_id):
            return False
        result = self._conn.execute(
            text(
                """
                UPDATE conversations SET
                    compression_metadata = jsonb_set(
                        jsonb_set(
                            COALESCE(compression_metadata, '{}'::jsonb),
                            '{is_compressed}',
                            to_jsonb(CAST(:is_compressed AS boolean)), true
                        ),
                        '{last_compression_at}',
                        to_jsonb(CAST(:last_compression_at AS text)), true
                    ),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid)
                """
            ),
            {
                "id": conversation_id,
                "is_compressed": bool(is_compressed),
                "last_compression_at": (
                    str(last_compression_at) if last_compression_at is not None else None
                ),
            },
        )
        return result.rowcount > 0

    def append_compression_point(
        self,
        conversation_id: str,
        point: dict,
        *,
        max_points: int,
    ) -> bool:
        """Append one compression point, keeping at most ``max_points``.

        Mirrors Mongo's ``$push {"$each": [point], "$slice": -max_points}``
        on ``compression_metadata.compression_points``. Preserves the
        other top-level keys in ``compression_metadata``.
        """
        # Shape-gate: see ``set_compression_flags``.
        if not looks_like_uuid(conversation_id):
            return False
        result = self._conn.execute(
            text(
                """
                UPDATE conversations SET
                    compression_metadata = jsonb_set(
                        COALESCE(compression_metadata, '{}'::jsonb),
                        '{compression_points}',
                        COALESCE(
                            (
                                SELECT jsonb_agg(elem ORDER BY rn)
                                FROM (
                                    SELECT
                                        elem,
                                        row_number() OVER () AS rn,
                                        count(*) OVER () AS cnt
                                    FROM jsonb_array_elements(
                                        COALESCE(
                                            compression_metadata -> 'compression_points',
                                            '[]'::jsonb
                                        ) || jsonb_build_array(CAST(:point AS jsonb))
                                    ) AS elem
                                ) ranked
                                WHERE rn > cnt - :max_points
                            ),
                            '[]'::jsonb
                        ),
                        true
                    ),
                    updated_at = now()
                WHERE id = CAST(:id AS uuid)
                """
            ),
            {
                "id": conversation_id,
                "point": json.dumps(point, default=str),
                "max_points": int(max_points),
            },
        )
        return result.rowcount > 0

    def delete(self, conversation_id: str, user_id: str) -> bool:
        # Shape-gate: see ``rename`` — prevents transaction poisoning when
        # a non-UUID id reaches this code path.
        if not looks_like_uuid(conversation_id):
            return False
        result = self._conn.execute(
            text(
                "DELETE FROM conversations "
                "WHERE id = CAST(:id AS uuid) AND user_id = :user_id"
            ),
            {"id": conversation_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def delete_all_for_user(self, user_id: str) -> int:
        result = self._conn.execute(
            text("DELETE FROM conversations WHERE user_id = :user_id"),
            {"user_id": user_id},
        )
        return result.rowcount

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    def get_messages(self, conversation_id: str) -> list[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "ORDER BY position ASC"
            ),
            {"conv_id": conversation_id},
        )
        return [_message_row_to_dict(r) for r in result.fetchall()]

    def get_message_at(self, conversation_id: str, position: int) -> Optional[dict]:
        # Shape-gate: see ``rename``. Callers today always pass a resolved
        # UUID (via ``get_any`` first), but the guard costs nothing and
        # keeps future callers safe from txn-poisoning.
        if not looks_like_uuid(conversation_id):
            return None
        result = self._conn.execute(
            text(
                "SELECT * FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND position = :pos"
            ),
            {"conv_id": conversation_id, "pos": position},
        )
        row = result.fetchone()
        return _message_row_to_dict(row) if row is not None else None

    def append_message(self, conversation_id: str, message: dict) -> dict:
        """Append a message to a conversation.

        Uses ``SELECT ... FOR UPDATE`` to allocate the next position
        atomically. The caller must be inside a transaction.

        Mirrors Mongo's ``$push`` on the ``queries`` array.
        """
        # Lock the parent conversation row to serialize concurrent appends.
        self._conn.execute(
            text(
                "SELECT id FROM conversations "
                "WHERE id = CAST(:conv_id AS uuid) FOR UPDATE"
            ),
            {"conv_id": conversation_id},
        )
        next_pos_result = self._conn.execute(
            text(
                "SELECT COALESCE(MAX(position), -1) + 1 AS next_pos "
                "FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid)"
            ),
            {"conv_id": conversation_id},
        )
        next_pos = next_pos_result.scalar()

        values = {
            "conversation_id": conversation_id,
            "position": next_pos,
            "prompt": message.get("prompt"),
            "response": message.get("response"),
            "thought": message.get("thought"),
            "sources": message.get("sources") or [],
            "tool_calls": message.get("tool_calls") or [],
            "model_id": message.get("model_id"),
            "message_metadata": message.get("metadata") or {},
        }
        if message.get("timestamp") is not None:
            values["timestamp"] = message["timestamp"]

        attachments = message.get("attachments")
        if attachments:
            # Attachment ids may arrive as Mongo ObjectIds during the
            # dual-write window — resolve each to a PG UUID or drop it.
            resolved = self._resolve_attachment_refs(
                [str(a) for a in attachments],
            )
            if resolved:
                values["attachments"] = resolved

        stmt = (
            pg_insert(conversation_messages_table)
            .values(**values)
            .returning(conversation_messages_table)
        )
        result = self._conn.execute(stmt)
        # Touch the parent conversation's updated_at.
        self._conn.execute(
            text(
                "UPDATE conversations SET updated_at = now() "
                "WHERE id = CAST(:id AS uuid)"
            ),
            {"id": conversation_id},
        )
        return _message_row_to_dict(result.fetchone())

    def update_message_at(
        self, conversation_id: str, position: int, fields: dict,
    ) -> bool:
        """Update specific fields on a message at a given position.

        Mirrors Mongo's ``$set`` on ``queries.{index}.*``.
        """
        allowed = {
            "prompt", "response", "thought", "sources", "tool_calls",
            "attachments", "model_id", "metadata", "timestamp",
            # Feedback can be re-set in rare continuation flows; without
            # it in the whitelist an upstream re-append that happens to
            # carry feedback would silently lose it. Mirrors
            # ``set_feedback`` — column is JSONB.
            "feedback", "feedback_timestamp",
        }
        filtered = {k: v for k, v in fields.items() if k in allowed}
        if not filtered:
            return False

        # Map public API key ``metadata`` → DB column ``message_metadata``.
        api_to_col = {"metadata": "message_metadata"}

        set_parts = []
        params: dict = {"conv_id": conversation_id, "pos": position}
        for key, val in filtered.items():
            col = api_to_col.get(key, key)
            if key in ("sources", "tool_calls", "metadata", "feedback"):
                set_parts.append(f"{col} = CAST(:{col} AS jsonb)")
                if val is None:
                    params[col] = None
                else:
                    params[col] = (
                        json.dumps(val) if not isinstance(val, str) else val
                    )
            elif key == "attachments":
                # Attachment ids may be Mongo ObjectIds during the
                # dual-write window; translate via attachments.legacy_mongo_id.
                set_parts.append(f"{col} = CAST(:{col} AS uuid[])")
                params[col] = self._resolve_attachment_refs(
                    [str(a) for a in val] if val else [],
                )
            else:
                set_parts.append(f"{col} = :{col}")
                params[col] = val

        if "timestamp" not in filtered:
            set_parts.append("timestamp = now()")
        sql = (
            f"UPDATE conversation_messages SET {', '.join(set_parts)} "
            "WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
        )
        result = self._conn.execute(text(sql), params)
        return result.rowcount > 0

    def truncate_after(self, conversation_id: str, keep_up_to: int) -> int:
        """Delete messages with position > keep_up_to.

        Mirrors Mongo's ``$push`` + ``$slice`` that trims queries after an
        index-based update.
        """
        # Shape-gate: see ``rename`` — prevents transaction poisoning when
        # a non-UUID id reaches this code path.
        if not looks_like_uuid(conversation_id):
            return 0
        result = self._conn.execute(
            text(
                "DELETE FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND position > :pos"
            ),
            {"conv_id": conversation_id, "pos": keep_up_to},
        )
        return result.rowcount

    def set_feedback(
        self, conversation_id: str, position: int, feedback: dict | None,
    ) -> bool:
        """Set or unset feedback on a message.

        ``feedback`` is a JSONB value, e.g. ``{"text": "thumbs_up",
        "timestamp": "..."}`` or ``None`` to unset.
        """
        # Shape-gate: see ``rename`` — prevents transaction poisoning when
        # a non-UUID id reaches this code path.
        if not looks_like_uuid(conversation_id):
            return False
        fb_json = json.dumps(feedback) if feedback is not None else None
        result = self._conn.execute(
            text(
                "UPDATE conversation_messages "
                "SET feedback = CAST(:fb AS jsonb) "
                "WHERE conversation_id = CAST(:conv_id AS uuid) AND position = :pos"
            ),
            {"conv_id": conversation_id, "pos": position, "fb": fb_json},
        )
        return result.rowcount > 0

    def message_count(self, conversation_id: str) -> int:
        result = self._conn.execute(
            text(
                "SELECT COUNT(*) FROM conversation_messages "
                "WHERE conversation_id = CAST(:conv_id AS uuid)"
            ),
            {"conv_id": conversation_id},
        )
        return result.scalar() or 0
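A minimal sketch of the transactional append that ``append_message`` requires (editor's illustration, not part of this diff); it assumes a SQLAlchemy ``engine`` bound to the user-data Postgres database:

# Sketch only — append_message needs an enclosing transaction for SELECT ... FOR UPDATE.
with engine.begin() as conn:
    repo = ConversationsRepository(conn)
    conv = repo.get_any(conversation_id, user_id)
    if conv is not None:
        repo.append_message(
            conv["id"],
            {"prompt": "Hello", "response": "Hi!", "sources": [], "tool_calls": []},
        )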
application/storage/db/repositories/memories.py (new file, 97 lines)
@@ -0,0 +1,97 @@

"""Repository for the ``memories`` table.

Covers the operations in ``application/agents/tools/memory.py``:
- upsert (create/overwrite file)
- find by path (view file)
- find by path prefix (view directory, regex scan)
- delete by path / path prefix
- rename (update path)
"""

from __future__ import annotations

from typing import Optional

from sqlalchemy import Connection, text

from application.storage.db.base_repository import row_to_dict


class MemoriesRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def upsert(self, user_id: str, tool_id: str, path: str, content: str) -> dict:
        result = self._conn.execute(
            text(
                """
                INSERT INTO memories (user_id, tool_id, path, content)
                VALUES (:user_id, CAST(:tool_id AS uuid), :path, :content)
                ON CONFLICT (user_id, tool_id, path)
                DO UPDATE SET content = EXCLUDED.content, updated_at = now()
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path, "content": content},
        )
        return row_to_dict(result.fetchone())

    def get_by_path(self, user_id: str, tool_id: str, path: str) -> Optional[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> list[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix"
            ),
            {"user_id": user_id, "tool_id": tool_id, "prefix": prefix + "%"},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def delete_by_path(self, user_id: str, tool_id: str, path: str) -> int:
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path = :path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path},
        )
        return result.rowcount

    def delete_by_prefix(self, user_id: str, tool_id: str, prefix: str) -> int:
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id "
                "AND tool_id = CAST(:tool_id AS uuid) AND path LIKE :prefix"
            ),
            {"user_id": user_id, "tool_id": tool_id, "prefix": prefix + "%"},
        )
        return result.rowcount

    def delete_all(self, user_id: str, tool_id: str) -> int:
        result = self._conn.execute(
            text(
                "DELETE FROM memories WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return result.rowcount

    def update_path(self, user_id: str, tool_id: str, old_path: str, new_path: str) -> bool:
        result = self._conn.execute(
            text(
                "UPDATE memories SET path = :new_path, updated_at = now() "
                "WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid) AND path = :old_path"
            ),
            {"user_id": user_id, "tool_id": tool_id, "old_path": old_path, "new_path": new_path},
        )
        return result.rowcount > 0
application/storage/db/repositories/notes.py (new file, 88 lines)
@@ -0,0 +1,88 @@

"""Repository for the ``notes`` table.

Covers the operations in ``application/agents/tools/notes.py``.
Note: the Mongo schema stores a single ``note`` text field per (user_id, tool_id),
while the Postgres schema has ``title`` + ``content``. During dual-write,
title is set to a default and content holds the note text.
"""

from __future__ import annotations

from typing import Optional

from sqlalchemy import Connection, text

from application.storage.db.base_repository import looks_like_uuid, row_to_dict


class NotesRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def upsert(self, user_id: str, tool_id: str, title: str, content: str) -> dict:
        result = self._conn.execute(
            text(
                """
                INSERT INTO notes (user_id, tool_id, title, content)
                VALUES (:user_id, CAST(:tool_id AS uuid), :title, :content)
                ON CONFLICT (user_id, tool_id)
                DO UPDATE SET content = EXCLUDED.content, title = EXCLUDED.title, updated_at = now()
                RETURNING *
                """
            ),
            {"user_id": user_id, "tool_id": tool_id, "title": title, "content": content},
        )
        return row_to_dict(result.fetchone())

    def get_for_user_tool(self, user_id: str, tool_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get(self, note_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text("SELECT * FROM notes WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": note_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def delete(self, user_id: str, tool_id: str) -> bool:
        result = self._conn.execute(
            text(
                "DELETE FROM notes WHERE user_id = :user_id AND tool_id = CAST(:tool_id AS uuid)"
            ),
            {"user_id": user_id, "tool_id": tool_id},
        )
        return result.rowcount > 0

    def get_by_legacy_id(self, legacy_mongo_id: str) -> Optional[dict]:
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        result = self._conn.execute(
            text("SELECT * FROM notes WHERE legacy_mongo_id = :legacy"),
            {"legacy": legacy_mongo_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_any(self, identifier: str, user_id: str) -> Optional[dict]:
        """Resolve a note by PG UUID or legacy Mongo ObjectId.

        Picks the lookup path from the id shape so non-UUID input never
        reaches ``CAST(:id AS uuid)`` — that cast raises on the server
        and poisons the enclosing transaction, making any subsequent
        query on the same connection fail.
        """
        if looks_like_uuid(identifier):
            doc = self.get(identifier, user_id)
            if doc is not None:
                return doc
        legacy = self.get_by_legacy_id(identifier)
        if legacy and legacy.get("user_id") == user_id:
            return legacy
        return None
application/storage/db/repositories/pending_tool_state.py (new file, 128 lines)
@@ -0,0 +1,128 @@

"""Repository for the ``pending_tool_state`` table.

Mirrors the continuation service's three operations on
``pending_tool_state`` in Mongo:

- save_state → upsert (INSERT ... ON CONFLICT DO UPDATE)
- load_state → find_one by (conversation_id, user_id)
- delete_state → delete_one by (conversation_id, user_id)

Plus a cleanup method for the Celery beat task that replaces Mongo's
TTL index.
"""

from __future__ import annotations

import json
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import Connection, text

from application.storage.db.base_repository import row_to_dict


PENDING_STATE_TTL_SECONDS = 30 * 60  # 1800 seconds


class PendingToolStateRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def save_state(
        self,
        conversation_id: str,
        user_id: str,
        *,
        messages: list,
        pending_tool_calls: list,
        tools_dict: dict,
        tool_schemas: list,
        agent_config: dict,
        client_tools: list | None = None,
        ttl_seconds: int = PENDING_STATE_TTL_SECONDS,
    ) -> dict:
        """Upsert pending tool state.

        Mirrors Mongo's ``replace_one(..., upsert=True)``.
        """
        now = datetime.now(timezone.utc)
        expires = datetime.fromtimestamp(
            now.timestamp() + ttl_seconds, tz=timezone.utc,
        )

        result = self._conn.execute(
            text(
                """
                INSERT INTO pending_tool_state
                    (conversation_id, user_id, messages, pending_tool_calls,
                     tools_dict, tool_schemas, agent_config, client_tools,
                     created_at, expires_at)
                VALUES
                    (CAST(:conv_id AS uuid), :user_id,
                     CAST(:messages AS jsonb), CAST(:pending AS jsonb),
                     CAST(:tools_dict AS jsonb), CAST(:schemas AS jsonb),
                     CAST(:agent_config AS jsonb), CAST(:client_tools AS jsonb),
                     :created_at, :expires_at)
                ON CONFLICT (conversation_id, user_id) DO UPDATE SET
                    messages = EXCLUDED.messages,
                    pending_tool_calls = EXCLUDED.pending_tool_calls,
                    tools_dict = EXCLUDED.tools_dict,
                    tool_schemas = EXCLUDED.tool_schemas,
                    agent_config = EXCLUDED.agent_config,
                    client_tools = EXCLUDED.client_tools,
                    created_at = EXCLUDED.created_at,
                    expires_at = EXCLUDED.expires_at
                RETURNING *
                """
            ),
            {
                "conv_id": conversation_id,
                "user_id": user_id,
                "messages": json.dumps(messages),
                "pending": json.dumps(pending_tool_calls),
                "tools_dict": json.dumps(tools_dict),
                "schemas": json.dumps(tool_schemas),
                "agent_config": json.dumps(agent_config),
                "client_tools": json.dumps(client_tools) if client_tools is not None else None,
                "created_at": now,
                "expires_at": expires,
            },
        )
        return row_to_dict(result.fetchone())

    def load_state(self, conversation_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM pending_tool_state "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id"
            ),
            {"conv_id": conversation_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def delete_state(self, conversation_id: str, user_id: str) -> bool:
        result = self._conn.execute(
            text(
                "DELETE FROM pending_tool_state "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "AND user_id = :user_id"
            ),
            {"conv_id": conversation_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def cleanup_expired(self) -> int:
        """Delete rows where ``expires_at < now()``.

        Replaces Mongo's ``expireAfterSeconds=0`` TTL index. Intended to
        be called from a Celery beat task every 60 seconds.
        """
        # clock_timestamp() — not now() — since the latter is frozen to the
        # start of the transaction, which would let state that has just
        # expired survive one more cleanup tick.
        result = self._conn.execute(
            text("DELETE FROM pending_tool_state WHERE expires_at < clock_timestamp()")
        )
        return result.rowcount
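One possible beat-schedule entry for driving ``cleanup_expired`` in place of Mongo's TTL index (editor's illustration; the task path and interval are assumptions, not part of this diff):

# Sketch only — a Celery beat entry calling a task that wraps cleanup_expired().
CELERY_BEAT_SCHEDULE = {
    "cleanup-pending-tool-state": {
        "task": "application.celery_tasks.cleanup_pending_tool_state",  # hypothetical task
        "schedule": 60.0,
    },
}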
178
application/storage/db/repositories/prompts.py
Normal file
178
application/storage/db/repositories/prompts.py
Normal file
@@ -0,0 +1,178 @@
"""Repository for the ``prompts`` table.

Covers every operation the legacy Mongo code performs on
``prompts_collection``:

1. ``insert_one`` in prompts/routes.py (create)
2. ``find`` by user in prompts/routes.py (list)
3. ``find_one`` by id+user in prompts/routes.py (get single)
4. ``find_one`` by id only in stream_processor.py (get content for rendering)
5. ``update_one`` in prompts/routes.py (update name+content)
6. ``delete_one`` in prompts/routes.py (delete)
7. ``find_one`` + ``insert_one`` in seeder.py (upsert by user+name+content)
"""

from __future__ import annotations

from typing import Optional

from sqlalchemy import Connection, text

from application.storage.db.base_repository import looks_like_uuid, row_to_dict


class PromptsRepository:
    """Postgres-backed replacement for Mongo ``prompts_collection``."""

    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        user_id: str,
        name: str,
        content: str,
        *,
        legacy_mongo_id: str | None = None,
    ) -> dict:
        sql = """
            INSERT INTO prompts (user_id, name, content, legacy_mongo_id)
            VALUES (:user_id, :name, :content, :legacy_mongo_id)
            RETURNING *
        """
        result = self._conn.execute(
            text(sql),
            {
                "user_id": user_id,
                "name": name,
                "content": content,
                "legacy_mongo_id": legacy_mongo_id,
            },
        )
        return row_to_dict(result.fetchone())

    def get(self, prompt_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": prompt_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_by_legacy_id(self, legacy_mongo_id: str, user_id: str | None = None) -> Optional[dict]:
        """Fetch a prompt by the original Mongo ObjectId string."""
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        sql = "SELECT * FROM prompts WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_any(self, identifier: str, user_id: str) -> Optional[dict]:
        """Resolve a prompt by PG UUID or legacy Mongo ObjectId.

        Picks the lookup path from the id shape so non-UUID input never
        reaches ``CAST(:id AS uuid)`` — that cast raises on the server
        and poisons the enclosing transaction, making any subsequent
        query on the same connection fail.
        """
        if looks_like_uuid(identifier):
            doc = self.get(identifier, user_id)
            if doc is not None:
                return doc
        return self.get_by_legacy_id(identifier, user_id)

    def get_for_rendering(self, prompt_id: str) -> Optional[dict]:
        """Fetch prompt content by ID without user scoping.

        Used only by stream_processor to render a prompt whose owner is
        not known at call time. Do NOT use in user-facing routes.
        """
        result = self._conn.execute(
            text("SELECT * FROM prompts WHERE id = CAST(:id AS uuid)"),
            {"id": prompt_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_user(self, user_id: str) -> list[dict]:
        result = self._conn.execute(
            text("SELECT * FROM prompts WHERE user_id = :user_id ORDER BY created_at"),
            {"user_id": user_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]

    def update(self, prompt_id: str, user_id: str, name: str, content: str) -> None:
        self._conn.execute(
            text(
                """
                UPDATE prompts
                SET name = :name, content = :content, updated_at = now()
                WHERE id = CAST(:id AS uuid) AND user_id = :user_id
                """
            ),
            {"id": prompt_id, "user_id": user_id, "name": name, "content": content},
        )

    def update_by_legacy_id(
        self,
        legacy_mongo_id: str,
        user_id: str,
        name: str,
        content: str,
    ) -> bool:
        """Update a prompt addressed by the Mongo ObjectId string."""
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        result = self._conn.execute(
            text(
                """
                UPDATE prompts
                SET name = :name, content = :content, updated_at = now()
                WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id
                """
            ),
            {
                "legacy_id": legacy_mongo_id,
                "user_id": user_id,
                "name": name,
                "content": content,
            },
        )
        return result.rowcount > 0

    def delete(self, prompt_id: str, user_id: str) -> None:
        self._conn.execute(
            text("DELETE FROM prompts WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": prompt_id, "user_id": user_id},
        )

    def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
        """Delete a prompt addressed by the Mongo ObjectId string."""
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        result = self._conn.execute(
            text(
                "DELETE FROM prompts "
                "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
            ),
            {"legacy_id": legacy_mongo_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def find_or_create(self, user_id: str, name: str, content: str) -> dict:
        """Return existing prompt matching (user, name, content), or create one.

        Used by the seeder to avoid duplicating template prompts.
        """
        result = self._conn.execute(
            text(
                "SELECT * FROM prompts WHERE user_id = :user_id AND name = :name AND content = :content"
            ),
            {"user_id": user_id, "name": name, "content": content},
        )
        row = result.fetchone()
        if row is not None:
            return row_to_dict(row)
        return self.create(user_id, name, content)
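A short usage sketch of the dual-id lookup above: during the migration window a caller may hold either a Postgres UUID or a legacy Mongo ObjectId string, and ``get_any`` keeps the failed-CAST case off the connection. The engine DSN and the wrapper function are assumptions; only ``PromptsRepository`` and its methods come from this file.

from sqlalchemy import create_engine

from application.storage.db.repositories.prompts import PromptsRepository

engine = create_engine("postgresql://localhost/docsgpt")  # assumed DSN


def fetch_prompt(identifier: str, user_id: str) -> dict | None:
    # Works for both a PG UUID and a 24-hex Mongo ObjectId such as
    # "507f1f77bcf86cd799439011"; the shape check routes the lookup.
    with engine.connect() as conn:
        return PromptsRepository(conn).get_any(identifier, user_id)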
213
application/storage/db/repositories/shared_conversations.py
Normal file
@@ -0,0 +1,213 @@
"""Repository for the ``shared_conversations`` table.

Covers the sharing operations from ``shared_conversations_collections``
in Mongo:

- create a share record (with UUID, conversation_id, user, visibility flags)
- look up by uuid (public access)
- look up by conversation_id + user + flags (dedup check)
"""

from __future__ import annotations

import uuid as uuid_mod
from typing import Optional

from sqlalchemy import Connection, text
from sqlalchemy.dialects.postgresql import insert as pg_insert

from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.models import shared_conversations_table


class SharedConversationsRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        conversation_id: str,
        user_id: str,
        *,
        is_promptable: bool = False,
        first_n_queries: int = 0,
        api_key: str | None = None,
        prompt_id: str | None = None,
        chunks: int | None = None,
        share_uuid: str | None = None,
    ) -> dict:
        """Create a share record.

        ``share_uuid`` allows the dual-write caller to supply the same
        UUID that Mongo received, so public ``/shared/{uuid}`` links
        keep resolving from both stores during the dual-write window.

        Callers that need race-free dedup on the logical share key
        should use :meth:`get_or_create` instead — it relies on the
        composite partial unique index added in migration 0008 to
        collapse concurrent requests to a single row.
        """
        final_uuid = share_uuid or str(uuid_mod.uuid4())
        values: dict = {
            "uuid": final_uuid,
            "conversation_id": conversation_id,
            "user_id": user_id,
            "is_promptable": is_promptable,
            "first_n_queries": first_n_queries,
        }
        if api_key:
            values["api_key"] = api_key
        if prompt_id:
            values["prompt_id"] = prompt_id
        if chunks is not None:
            values["chunks"] = chunks

        stmt = (
            pg_insert(shared_conversations_table)
            .values(**values)
            .returning(shared_conversations_table)
        )
        result = self._conn.execute(stmt)
        return row_to_dict(result.fetchone())

    def get_or_create(
        self,
        conversation_id: str,
        user_id: str,
        *,
        is_promptable: bool = False,
        first_n_queries: int = 0,
        api_key: str | None = None,
        prompt_id: str | None = None,
        chunks: int | None = None,
        share_uuid: str | None = None,
    ) -> dict:
        """Race-free share create/lookup keyed on the logical dedup tuple.

        Leverages the partial unique index on
        ``(conversation_id, user_id, is_promptable, first_n_queries,
        COALESCE(api_key, ''))`` added in migration 0008. Concurrent
        requests for the same logical share converge on one row. The
        returned dict's ``uuid`` is the canonical public identifier.

        Dedup key rationale — ``prompt_id`` and ``chunks`` are
        deliberately *not* part of the uniqueness key. A share row is
        identified by "who shared what conversation under which
        visibility rules"; ``prompt_id`` / ``chunks`` are mutable
        properties of that share and are last-write-wins on re-share.
        This preserves existing public ``/shared/{uuid}`` URLs when a
        user updates the prompt or chunk count, matching the Mongo
        ``find_one`` + ``update`` semantics.
        """
        final_uuid = share_uuid or str(uuid_mod.uuid4())
        result = self._conn.execute(
            text(
                """
                INSERT INTO shared_conversations
                    (uuid, conversation_id, user_id, is_promptable,
                     first_n_queries, api_key, prompt_id, chunks)
                VALUES
                    (CAST(:uuid AS uuid), CAST(:conversation_id AS uuid),
                     :user_id, :is_promptable, :first_n_queries,
                     :api_key, CAST(:prompt_id AS uuid), :chunks)
                ON CONFLICT (conversation_id, user_id, is_promptable,
                             first_n_queries, COALESCE(api_key, ''))
                DO UPDATE SET prompt_id = EXCLUDED.prompt_id,
                              chunks = EXCLUDED.chunks
                RETURNING *
                """
            ),
            {
                "uuid": final_uuid,
                "conversation_id": conversation_id,
                "user_id": user_id,
                "is_promptable": is_promptable,
                "first_n_queries": first_n_queries,
                "api_key": api_key,
                "prompt_id": prompt_id,
                "chunks": chunks,
            },
        )
        return row_to_dict(result.fetchone())

    def find_by_uuid(self, share_uuid: str) -> Optional[dict]:
        # Shape-gate: the public ``/api/shared_conversation/<identifier>``
        # endpoint threads the URL path segment straight here. A non-UUID
        # (e.g. a legacy Mongo ObjectId still embedded in an old link or
        # an outright garbage path) must resolve to ``None`` rather than
        # raise — the CAST would otherwise poison the txn and mask the
        # real "not found" response behind a generic 400.
        if not looks_like_uuid(share_uuid):
            return None
        result = self._conn.execute(
            text(
                "SELECT * FROM shared_conversations "
                "WHERE uuid = CAST(:uuid AS uuid)"
            ),
            {"uuid": share_uuid},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def find_existing(
        self,
        conversation_id: str,
        user_id: str,
        is_promptable: bool,
        first_n_queries: int,
        api_key: str | None = None,
    ) -> Optional[dict]:
        """Check for an existing share with matching parameters.

        Mirrors the Mongo ``find_one`` dedup check before creating a share.
        """
        if api_key:
            result = self._conn.execute(
                text(
                    "SELECT * FROM shared_conversations "
                    "WHERE conversation_id = CAST(:conv_id AS uuid) "
                    "AND user_id = :user_id "
                    "AND is_promptable = :is_promptable "
                    "AND first_n_queries = :fnq "
                    "AND api_key = :api_key "
                    "LIMIT 1"
                ),
                {
                    "conv_id": conversation_id,
                    "user_id": user_id,
                    "is_promptable": is_promptable,
                    "fnq": first_n_queries,
                    "api_key": api_key,
                },
            )
        else:
            result = self._conn.execute(
                text(
                    "SELECT * FROM shared_conversations "
                    "WHERE conversation_id = CAST(:conv_id AS uuid) "
                    "AND user_id = :user_id "
                    "AND is_promptable = :is_promptable "
                    "AND first_n_queries = :fnq "
                    "AND api_key IS NULL "
                    "LIMIT 1"
                ),
                {
                    "conv_id": conversation_id,
                    "user_id": user_id,
                    "is_promptable": is_promptable,
                    "fnq": first_n_queries,
                },
            )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def list_for_conversation(self, conversation_id: str) -> list[dict]:
        result = self._conn.execute(
            text(
                "SELECT * FROM shared_conversations "
                "WHERE conversation_id = CAST(:conv_id AS uuid) "
                "ORDER BY created_at DESC"
            ),
            {"conv_id": conversation_id},
        )
        return [row_to_dict(r) for r in result.fetchall()]
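A usage sketch of the upsert semantics described in the ``get_or_create`` docstring: two requests to share the same conversation under the same visibility rules converge on one row and one public ``uuid``, while ``prompt_id``/``chunks`` take the latest values. This assumes the migration-0008 arbiter index already exists in the target database; the DSN, user id, and conversation id below are illustrative only.

from sqlalchemy import create_engine

from application.storage.db.repositories.shared_conversations import (
    SharedConversationsRepository,
)

engine = create_engine("postgresql://localhost/docsgpt")  # assumed DSN
conversation_id = "5e9f3f2a-8f4e-4e3b-9a9a-1c2d3e4f5a6b"  # illustrative existing conversation

with engine.begin() as conn:
    repo = SharedConversationsRepository(conn)
    first = repo.get_or_create(conversation_id, "local", is_promptable=True, first_n_queries=2)
    second = repo.get_or_create(conversation_id, "local", is_promptable=True, first_n_queries=2, chunks=4)
    # Same logical share key, same row: the public link keeps its uuid while
    # the mutable properties are overwritten last-write-wins.
    assert first["uuid"] == second["uuid"]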
324
application/storage/db/repositories/sources.py
Normal file
@@ -0,0 +1,324 @@
"""Repository for the ``sources`` table."""

from __future__ import annotations

import json
from typing import Any, Optional

from sqlalchemy import Connection, func, select, text

from application.storage.db.base_repository import looks_like_uuid, row_to_dict
from application.storage.db.models import sources_table


_SCALAR_COLUMNS = {
    "name", "type", "retriever", "sync_frequency", "tokens", "file_path",
    "language", "model", "date",
}
_JSONB_COLUMNS = {"metadata", "remote_data", "directory_structure", "file_name_map"}
_ALLOWED_COLUMNS = _SCALAR_COLUMNS | _JSONB_COLUMNS

# Whitelist for sort columns exposed via ``list_for_user``. Anything not in
# this set falls back to ``date`` so user-supplied sort params can't be
# interpolated into SQL unchecked.
_SORTABLE_COLUMNS = {"date", "name", "tokens", "type", "created_at", "updated_at"}


def _escape_like(pattern: str) -> str:
    """Escape wildcards so a user-supplied substring is matched literally.

    We use ``LIKE ESCAPE '\\'`` on the query side so backslash, percent, and
    underscore in the input don't accidentally turn into regex-like wildcards.
    """
    return (
        pattern
        .replace("\\", "\\\\")
        .replace("%", "\\%")
        .replace("_", "\\_")
    )


def _coerce_jsonb(value: Any) -> Any:
    """Normalize incoming JSONB values for the Core ``Table.update()`` path.

    ``remote_data`` in particular arrives as either a dict or a JSON string
    (the legacy Mongo docs stored both shapes). Strings are parsed so the
    stored representation is always structured JSONB; dicts/lists pass
    through untouched for the SQLAlchemy JSONB type processor.
    """
    if value is None:
        return None
    if isinstance(value, (dict, list)):
        return value
    if isinstance(value, str):
        stripped = value.strip()
        if not stripped:
            return None
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            return {"raw": value}
    return value


class SourcesRepository:
    def __init__(self, conn: Connection) -> None:
        self._conn = conn

    def create(
        self,
        name: str,
        *,
        source_id: Optional[str] = None,
        user_id: str,
        type: Optional[str] = None,
        metadata: Optional[dict] = None,
        retriever: Optional[str] = None,
        sync_frequency: Optional[str] = None,
        tokens: Optional[str] = None,
        file_path: Optional[str] = None,
        remote_data: Any = None,
        directory_structure: Any = None,
        file_name_map: Any = None,
        language: Optional[str] = None,
        model: Optional[str] = None,
        date: Any = None,
        legacy_mongo_id: Optional[str] = None,
    ) -> dict:
        result = self._conn.execute(
            text(
                """
                INSERT INTO sources (
                    id, user_id, name, type, metadata,
                    retriever, sync_frequency, tokens, file_path,
                    remote_data, directory_structure, file_name_map,
                    language, model, date, legacy_mongo_id
                )
                VALUES (
                    COALESCE(CAST(:source_id AS uuid), gen_random_uuid()),
                    :user_id, :name, :type, CAST(:metadata AS jsonb),
                    :retriever, :sync_frequency, :tokens, :file_path,
                    CAST(:remote_data AS jsonb),
                    CAST(:directory_structure AS jsonb),
                    CAST(:file_name_map AS jsonb),
                    :language, :model,
                    COALESCE(:date, now()),
                    :legacy_mongo_id
                )
                RETURNING *
                """
            ),
            {
                "source_id": source_id,
                "user_id": user_id,
                "name": name,
                "type": type,
                "metadata": json.dumps(metadata or {}),
                "retriever": retriever,
                "sync_frequency": sync_frequency,
                "tokens": tokens,
                "file_path": file_path,
                "remote_data": (
                    None if remote_data is None
                    else json.dumps(_coerce_jsonb(remote_data))
                ),
                "directory_structure": (
                    None if directory_structure is None
                    else json.dumps(_coerce_jsonb(directory_structure))
                ),
                "file_name_map": (
                    None if file_name_map is None
                    else json.dumps(_coerce_jsonb(file_name_map))
                ),
                "language": language,
                "model": model,
                "date": date,
                "legacy_mongo_id": legacy_mongo_id,
            },
        )
        return row_to_dict(result.fetchone())

    def get(self, source_id: str, user_id: str) -> Optional[dict]:
        result = self._conn.execute(
            text("SELECT * FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        )
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def get_any(self, source_id: str, user_id: str) -> Optional[dict]:
        """Resolve a source by either PG UUID or legacy Mongo ObjectId string.

        Cutover helper: URLs / bookmarks may still hold Mongo ObjectIds.
        Tries the UUID path first, then falls back to ``legacy_mongo_id``.
        Both paths are scoped by ``user_id``.
        """
        if looks_like_uuid(source_id):
            row = self.get(source_id, user_id)
            if row is not None:
                return row
        return self.get_by_legacy_id(source_id, user_id)

    def list_for_user(
        self,
        user_id: str,
        *,
        limit: Optional[int] = None,
        offset: int = 0,
        search_term: Optional[str] = None,
        sort_field: str = "created_at",
        sort_order: str = "desc",
    ) -> list[dict]:
        """Return sources owned by ``user_id``, paginated and optionally filtered.

        All pagination, filtering, and sorting are pushed into SQL so large
        accounts don't materialize their full source list in Python for every
        page. See ``PaginatedSources`` in the sources routes for the matching
        call site.

        Args:
            user_id: Scope rows to this owner.
            limit: Page size. ``None`` returns every matching row (legacy
                full-list path used by ``CombinedJson``).
            offset: Rows to skip before collecting ``limit`` results.
            search_term: Case-insensitive substring filter on ``name``.
                ``%`` and ``_`` in the input are escaped so they match
                literally rather than as LIKE wildcards.
            sort_field: Column to sort by. Unknown values fall back to
                ``date``. Resolved against ``sources_table.c`` so the
                column identity is bound by SQLAlchemy — user input never
                reaches the emitted SQL as a string.
            sort_order: ``"asc"`` or ``"desc"``; anything else is treated
                as ``"desc"``.

        Returns:
            A list of source rows as plain dicts (via ``row_to_dict``).
        """
        column_name = sort_field if sort_field in _SORTABLE_COLUMNS else "date"
        sort_column = sources_table.c[column_name]
        ascending = sort_order.lower() == "asc"

        stmt = select(sources_table).where(sources_table.c.user_id == user_id)
        if search_term:
            stmt = stmt.where(
                sources_table.c.name.ilike(
                    f"%{_escape_like(search_term)}%",
                    escape="\\",
                )
            )

        # ``id`` is appended as a stable tiebreaker so paginated windows
        # are deterministic across equal sort keys.
        id_column = sources_table.c.id
        if ascending:
            stmt = stmt.order_by(sort_column.asc(), id_column.asc())
        else:
            stmt = stmt.order_by(sort_column.desc(), id_column.desc())

        if limit is not None:
            stmt = stmt.limit(limit).offset(offset)

        result = self._conn.execute(stmt)
        return [row_to_dict(r) for r in result.fetchall()]

    def count_for_user(
        self,
        user_id: str,
        *,
        search_term: Optional[str] = None,
    ) -> int:
        """Return the count of rows that ``list_for_user`` would produce.

        The filter mirrors ``list_for_user`` exactly so ``total`` and the
        paginated window stay consistent page-to-page.

        Args:
            user_id: Scope rows to this owner.
            search_term: Same substring filter semantics as
                ``list_for_user``; ``None``/empty disables the filter.

        Returns:
            The total number of matching rows.
        """
        stmt = (
            select(func.count())
            .select_from(sources_table)
            .where(sources_table.c.user_id == user_id)
        )
        if search_term:
            stmt = stmt.where(
                sources_table.c.name.ilike(
                    f"%{_escape_like(search_term)}%",
                    escape="\\",
                )
            )
        result = self._conn.execute(stmt)
        row = result.fetchone()
        return int(row[0]) if row is not None else 0

    def update(self, source_id: str, user_id: str, fields: dict) -> None:
        filtered = {k: v for k, v in fields.items() if k in _ALLOWED_COLUMNS}
        if not filtered:
            return

        values: dict = {}
        for col, val in filtered.items():
            values[col] = _coerce_jsonb(val) if col in _JSONB_COLUMNS else val
        values["updated_at"] = func.now()

        t = sources_table
        stmt = (
            t.update()
            .where(t.c.id == source_id)
            .where(t.c.user_id == user_id)
            .values(**values)
        )
        self._conn.execute(stmt)

    def get_by_legacy_id(
        self, legacy_mongo_id: str, user_id: Optional[str] = None,
    ) -> Optional[dict]:
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        sql = "SELECT * FROM sources WHERE legacy_mongo_id = :legacy_id"
        params: dict[str, str] = {"legacy_id": legacy_mongo_id}
        if user_id is not None:
            sql += " AND user_id = :user_id"
            params["user_id"] = user_id
        result = self._conn.execute(text(sql), params)
        row = result.fetchone()
        return row_to_dict(row) if row is not None else None

    def update_by_legacy_id(
        self, legacy_mongo_id: str, user_id: str, fields: dict,
    ) -> bool:
        """Update a source addressed by the Mongo ObjectId string.

        Used by dual_write call sites that hold the Mongo ``_id`` but
        haven't resolved the PG UUID yet. Returns ``True`` if a row was
        updated (i.e. the legacy id was found).
        """
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        row = self.get_by_legacy_id(legacy_mongo_id, user_id)
        if row is None:
            return False
        self.update(str(row["id"]), user_id, fields)
        return True

    def delete_by_legacy_id(self, legacy_mongo_id: str, user_id: str) -> bool:
        """Delete by Mongo ObjectId. Used by dual_write in DeleteOldIndexes."""
        legacy_mongo_id = str(legacy_mongo_id) if legacy_mongo_id is not None else None
        result = self._conn.execute(
            text(
                "DELETE FROM sources "
                "WHERE legacy_mongo_id = :legacy_id AND user_id = :user_id"
            ),
            {"legacy_id": legacy_mongo_id, "user_id": user_id},
        )
        return result.rowcount > 0

    def delete(self, source_id: str, user_id: str) -> bool:
        result = self._conn.execute(
            text("DELETE FROM sources WHERE id = CAST(:id AS uuid) AND user_id = :user_id"),
            {"id": source_id, "user_id": user_id},
        )
        return result.rowcount > 0
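The two methods above are designed to be consumed together by a paginated listing route; a minimal sketch of that consumption follows. The DSN, page-size defaults, and the wrapper shape are assumptions, not code from this diff.

from typing import Optional

from sqlalchemy import create_engine

from application.storage.db.repositories.sources import SourcesRepository

engine = create_engine("postgresql://localhost/docsgpt")  # assumed DSN


def sources_page(user_id: str, page: int = 1, per_page: int = 20,
                 q: Optional[str] = None) -> dict:
    # Both calls pass the same search_term so the window and the total agree.
    with engine.connect() as conn:
        repo = SourcesRepository(conn)
        rows = repo.list_for_user(
            user_id,
            limit=per_page,
            offset=(page - 1) * per_page,
            search_term=q,
            sort_field="date",
            sort_order="desc",
        )
        total = repo.count_for_user(user_id, search_term=q)
    return {"rows": rows, "total": total, "page": page, "per_page": per_page}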
Some files were not shown because too many files have changed in this diff.