Mirror of https://github.com/arc53/DocsGPT.git (synced 2026-05-15 17:03:47 +00:00)

Compare commits: 0.15.0...dependabot (330 commits)
[Commit list: the comparison spans 330 commits, rendered here only as bare SHA1 hashes (acfa972c40 through 66870279d3); the Author, Date, and commit-message columns were not captured.]
@@ -1,9 +1,42 @@
 API_KEY=<LLM api key (for example, open ai key)>
 LLM_NAME=docsgpt
 VITE_API_STREAMING=true
+INTERNAL_KEY=<internal key for worker-to-backend authentication>
+
+# Provider-specific API keys (optional - use these to enable multiple providers)
+# OPENAI_API_KEY=<your-openai-api-key>
+# ANTHROPIC_API_KEY=<your-anthropic-api-key>
+# GOOGLE_API_KEY=<your-google-api-key>
+# GROQ_API_KEY=<your-groq-api-key>
+# NOVITA_API_KEY=<your-novita-api-key>
+# OPEN_ROUTER_API_KEY=<your-openrouter-api-key>
+
+# Remote Embeddings (Optional - for using a remote embeddings API instead of local SentenceTransformer)
+# When set, the app will use the remote API and won't load SentenceTransformer (saves RAM)
+EMBEDDINGS_BASE_URL=
+EMBEDDINGS_KEY=
+
 #For Azure (you can delete it if you don't use Azure)
 OPENAI_API_BASE=
 OPENAI_API_VERSION=
 AZURE_DEPLOYMENT_NAME=
 AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
+
+#Azure AD Application (client) ID
+MICROSOFT_CLIENT_ID=your-azure-ad-client-id
+#Azure AD Application client secret
+MICROSOFT_CLIENT_SECRET=your-azure-ad-client-secret
+#Azure AD Tenant ID (or 'common' for multi-tenant)
+MICROSOFT_TENANT_ID=your-azure-ad-tenant-id
+#If you are using a Microsoft Entra ID tenant,
+#configure the AUTHORITY variable as
+#"https://login.microsoftonline.com/TENANT_GUID"
+#or "https://login.microsoftonline.com/contoso.onmicrosoft.com".
+#Alternatively, use "https://login.microsoftonline.com/common" for multi-tenant app.
+MICROSOFT_AUTHORITY=https://{tenantId}.ciamlogin.com/{tenantId}
+
+# User-data Postgres DB (Phase 0 of the MongoDB→Postgres migration).
+# Standard Postgres URI — `postgres://` and `postgresql://` both work.
+# Leave unset while the migration is still being rolled out; the app will
+# fall back to MongoDB for user data until POSTGRES_URI is configured.
+# POSTGRES_URI=postgresql://docsgpt:docsgpt@localhost:5432/docsgpt
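The comments above describe a fallback: user data moves to Postgres only once `POSTGRES_URI` is set, and MongoDB remains the default otherwise. A minimal, self-contained sketch of that selection rule, written as an illustration of the documented behavior rather than code taken from the application:

```python
import os


def pick_user_data_backend(env: dict | None = None) -> str:
    """Illustrative only: POSTGRES_URI switches user data to Postgres.

    Both postgres:// and postgresql:// URIs are accepted; an unset or
    empty value falls back to MongoDB, mirroring the env-file comment.
    """
    env = env if env is not None else dict(os.environ)
    uri = env.get("POSTGRES_URI", "").strip()
    if uri.startswith(("postgres://", "postgresql://")):
        return "postgres"
    return "mongodb"


print(pick_user_data_backend({"POSTGRES_URI": "postgresql://docsgpt:docsgpt@localhost:5432/docsgpt"}))  # postgres
print(pick_user_data_backend({}))  # mongodb
```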
.github/INCIDENT_RESPONSE.md (vendored, new file, 99 lines)
@@ -0,0 +1,99 @@

# DocsGPT Incident Response Plan (IRP)

This playbook describes how maintainers respond to confirmed or suspected security incidents.

- Vulnerability reporting: [`SECURITY.md`](../SECURITY.md)
- Non-security bugs/features: [`CONTRIBUTING.md`](../CONTRIBUTING.md)

## Severity

| Severity | Definition | Typical examples |
|---|---|---|
| **Critical** | Active exploitation, supply-chain compromise, or confirmed data breach requiring immediate user action. | Compromised release artifact/image; remote execution. |
| **High** | Serious undisclosed vulnerability with no practical workaround, or CVSS >= 7.0. | Key leakage; prompt injection enabling cross-tenant access. |
| **Medium** | Material impact but constrained by preconditions/scope, or a practical workaround exists. | Auth-required exploit; dependency CVE with limited reachability. |
| **Low** | Defense-in-depth or narrow availability impact with no confirmed data exposure. | Missing rate limiting; hardening gap without exploit evidence. |

## Response workflow

### 1) Triage (target: initial response within 48 hours)

1. Acknowledge report.
2. Validate on latest release and `main`.
3. Confirm in-scope security issue vs. hardening item (per `SECURITY.md`).
4. Assign severity and open a **draft GitHub Security Advisory (GHSA)** (no public issue).
5. Determine whether root cause is DocsGPT code or upstream dependency/provider.

### 2) Investigation

1. Identify affected components, versions, and deployment scope (self-hosted, cloud, or both).
2. For AI issues, explicitly evaluate prompt injection, document isolation, and output leakage.
3. Request a CVE through GHSA for **Medium+** issues.

### 3) Containment, fix, and disclosure

1. Implement and test fix in private security workflow (GHSA private fork/branch).
2. Merge fix to `main`, cut patched release, and verify published artifacts/images.
3. Patch managed cloud deployment (`app.docsgpt.cloud`) and other deployments as soon as validated.
4. Publish GHSA with CVE (if assigned), affected/fixed versions, CVSS, mitigations, and upgrade guidance.
5. **Critical/High:** coordinate disclosure timing with reporter (goal: <= 90 days) and publish a notice.
6. **Medium/Low:** include in next scheduled release unless risk requires immediate out-of-band patching.

### 4) Post-incident

1. Monitor support channels (GitHub/Discord) for regressions or exploitation reports.
2. Run a short retrospective (root cause, detection, response gaps, prevention work).
3. Track follow-up hardening actions with owners/dates.
4. Update this IRP and related runbooks as needed.

## Scenario playbooks

### Supply-chain compromise

1. Freeze releases and investigate blast radius.
2. Rotate credentials in order: Docker Hub -> GitHub tokens -> LLM provider keys -> DB credentials -> `JWT_SECRET_KEY` -> `ENCRYPTION_SECRET_KEY` -> `INTERNAL_KEY`.
3. Replace compromised artifacts/tags with clean releases and revoke/remove bad tags where possible.
4. Publish advisory with exact affected versions and required user actions.

### Data exposure

1. Determine scope (users, documents, keys, logs, time window).
2. Disable affected path or hotfix immediately for managed cloud.
3. Notify affected users with concrete remediation steps (for example, rotate keys).
4. Continue through standard fix/disclosure workflow.

### Critical regression with security impact

1. Identify introducing change (`git bisect` if needed).
2. Publish workaround within 24 hours (for example, pin to known-good version).
3. Ship patch release with regression test and close incident with public summary.

## AI-specific guidance

Treat confirmed AI-specific abuse as security incidents:

- Prompt injection causing sensitive data exfiltration (from tools that don't belong to the agent) -> **High**
- Cross-tenant retrieval/isolation failure -> **High**
- API key disclosure in output -> **High**

## Secret rotation quick reference

| Secret | Standard rotation action |
|---|---|
| Docker Hub credentials | Revoke/replace in Docker Hub; update CI/CD secrets |
| GitHub tokens/PATs | Revoke/replace in GitHub; update automation secrets |
| LLM provider API keys | Rotate in provider console; update runtime/deploy secrets |
| Database credentials | Rotate in DB platform; redeploy with new secrets |
| `JWT_SECRET_KEY` | Rotate and redeploy (invalidates all active user sessions/tokens) |
| `ENCRYPTION_SECRET_KEY` | Rotate and redeploy (re-encrypt stored data if possible; existing encrypted data may become inaccessible) |
| `INTERNAL_KEY` | Rotate and redeploy (invalidates worker-to-backend authentication) |

## Maintenance

Review this document:

- after every **Critical/High** incident, and
- at least annually.

Changes should be proposed via pull request to `main`.
.github/THREAT_MODEL.md (vendored, new file, 144 lines)
@@ -0,0 +1,144 @@

# DocsGPT Public Threat Model

**Classification:** Public
**Last updated:** 2026-04-15
**Applies to:** Open-source and self-hosted DocsGPT deployments

## 1) Overview

DocsGPT ingests content (files/URLs/connectors), indexes it, and answers queries via LLM-backed APIs and optional tools.

Core components:
- Backend API (`application/`)
- Workers/ingestion (`application/worker.py` and related modules)
- Datastores (MongoDB/Redis/vector stores)
- Frontend (`frontend/`)
- Optional extensions/integrations (`extensions/`)

## 2) Scope and assumptions

In scope:
- Application-level threats in this repository.
- Local and internet-exposed self-hosted deployments.

Assumptions:
- Internet-facing instances enable auth and use strong secrets.
- Datastores/internal services are not publicly exposed.

Out of scope:
- Cloud hardware/provider compromise.
- Security guarantees of external LLM vendors.
- Full security audits of third-party systems targeted by tools (external DBs/MCP servers/code-exec APIs).

## 3) Security objectives

- Protect document/conversation confidentiality.
- Preserve integrity of prompts, agents, tools, and indexed data.
- Maintain API/worker availability.
- Enforce tenant isolation in authenticated deployments.

## 4) Assets

- Documents, attachments, chunks/embeddings, summaries.
- Conversations, agents, workflows, prompt templates.
- Secrets (JWT secret, `INTERNAL_KEY`, provider/API/OAuth credentials).
- Operational capacity (worker throughput, queue depth, model quota/cost).

## 5) Trust boundaries and untrusted input

Trust boundaries:
- Internet ↔ Frontend
- Frontend ↔ Backend API
- Backend ↔ Workers/internal APIs
- Backend/workers ↔ Datastores
- Backend ↔ External LLM/connectors/remote URLs

Untrusted input includes API payloads, file uploads, remote URLs, OAuth/webhook data, retrieved content, and LLM/tool arguments.

## 6) Main attack surfaces

1. Auth/authz paths and sharing tokens.
2. File upload + parsing pipeline.
3. Remote URL fetching and connectors (SSRF risk).
4. Agent/tool execution from LLM output.
5. Template/workflow rendering.
6. Frontend rendering + token storage.
7. Internal service endpoints (`INTERNAL_KEY`).
8. High-impact integrations (SQL tool, generic API tool, remote MCP tools).

## 7) Key threats and expected mitigations

### A. Auth/authz misconfiguration
- Threat: weak/no auth or leaked tokens leads to broad data access.
- Mitigations: require auth for public deployments, short-lived tokens, rotation/revocation, least-privilege sharing.

### B. Untrusted file ingestion
- Threat: malicious files/archives trigger traversal, parser exploits, or resource exhaustion.
- Mitigations: strict path checks, archive safeguards, file limits, patched parser dependencies.

### C. SSRF/outbound abuse
- Threat: URL loaders/tools access private/internal/metadata endpoints.
- Mitigations: validate URLs + redirects, block private/link-local ranges, apply egress controls/allowlists.

### D. Prompt injection + tool abuse
- Threat: retrieved text manipulates model behavior and causes unsafe tool calls.
- Note: never rely on the model to "choose correctly" under adversarial input.
- Mitigations: treat retrieved/model output as untrusted, enforce tool policies, only expose tools explicitly assigned by the user/admin to that agent, separate system instructions from retrieved content, audit tool calls.

### E. Dangerous tool capability chaining (SQL/API/MCP)
- Threat: write-capable SQL credentials allow destructive queries.
- Threat: API tool can trigger side effects (infra/payment/webhook/code-exec endpoints).
- Threat: remote MCP tools may expose privileged operations.
- Mitigations: read-only-by-default credentials, destination allowlists, explicit approval for write/exec actions, per-tool policy enforcement + logging.

### F. Frontend/XSS + token theft
- Threat: XSS can steal local tokens and call APIs.
- Mitigations: reduce unsafe rendering paths, strong CSP, scoped short-lived credentials.

### G. Internal endpoint exposure
- Threat: weak/unset `INTERNAL_KEY` enables internal API abuse.
- Mitigations: fail closed, require strong random keys, keep internal APIs private.

### H. DoS and cost abuse
- Threat: request floods, large ingestion jobs, expensive prompts/crawls.
- Mitigations: rate limits, quotas, timeouts, queue backpressure, usage budgets.

## 8) Example attacker stories

- Internet-exposed deployment runs with weak/no auth and receives unauthorized data access/abuse.
- Intranet deployment intentionally using weak/no auth is vulnerable to insider misuse and lateral-movement abuse.
- Crafted archive attempts path traversal during extraction.
- Malicious URL/redirect chain targets internal services.
- Poisoned document causes data exfiltration through tool calls.
- Over-privileged SQL/API/MCP tool performs destructive side effects.

## 9) Severity calibration

- **Critical:** unauthenticated public data access; prompt-injection-driven exfiltration; SSRF to sensitive internal endpoints.
- **High:** cross-tenant leakage, persistent token compromise, over-privileged destructive tools.
- **Medium:** DoS/cost amplification and non-critical information disclosure.
- **Low:** minor hardening gaps with limited impact.

## 10) Baseline controls for public deployments

1. Enforce authentication and secure defaults.
2. Set/rotate strong secrets (`JWT`, `INTERNAL_KEY`, encryption keys).
3. Restrict CORS and front API with a hardened proxy.
4. Add rate limiting/quotas for answer/upload/crawl/token endpoints.
5. Enforce URL+redirect SSRF protections and egress restrictions.
6. Apply upload/archive/parsing hardening.
7. Require least-privilege tool credentials and auditable tool execution.
8. Monitor auth failures, tool anomalies, ingestion spikes, and cost anomalies.
9. Keep dependencies/images patched and scanned.
10. Validate multi-tenant isolation with explicit tests.

## 11) Maintenance

Review this model after major auth, ingestion, connector, tool, or workflow changes.

## References

- [OWASP Top 10 for LLM Applications](https://owasp.org/www-project-top-10-for-large-language-model-applications/)
- [OWASP ASVS](https://owasp.org/www-project-application-security-verification-standard/)
- [STRIDE overview](https://learn.microsoft.com/azure/security/develop/threat-modeling-tool-threats)
- [DocsGPT SECURITY.md](../SECURITY.md)
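To make the SSRF mitigation described under threat C above concrete, here is a minimal sketch of a pre-fetch check that rejects private and link-local targets. It is an illustration of the control named in the threat model, not code from the DocsGPT codebase; a production version would also pin DNS results and re-check every redirect hop:

```python
import ipaddress
import socket
from urllib.parse import urlparse


def is_safe_public_url(url: str) -> bool:
    """Return True only for http(s) URLs whose host resolves to public addresses."""
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.hostname:
        return False
    try:
        infos = socket.getaddrinfo(parsed.hostname, parsed.port or 443)
    except socket.gaierror:
        return False
    for info in infos:
        addr = ipaddress.ip_address(info[4][0])
        # Block RFC1918/private, loopback, link-local (including the
        # 169.254.169.254 cloud metadata endpoint), and reserved ranges.
        if addr.is_private or addr.is_loopback or addr.is_link_local or addr.is_reserved:
            return False
    return True


print(is_safe_public_url("http://169.254.169.254/latest/meta-data/"))  # False
print(is_safe_public_url("https://example.com/docs"))                  # True
```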
@@ -1,46 +1,80 @@
The accepted-vocabulary list for the Vale style (stored one term per line) is re-sorted alphabetically and expanded from 46 to 80 entries; the previous entries are retained, with lowercase "agentic" becoming "Agentic". The new list is:

Agentic, Anthropic's, api, APIs, Atlassian, automations, autoescaping, Autoescaping, backfill, backfills, bool, boolean, brave_web_search, chatbot, Chatwoot, config, configs, CSVs, dev, diarization, Docling, docsgpt, docstrings, Entra, env, enqueues, EOL, ESLint, feedbacks, Figma, GPUs, Groq, hardcode, hardcoding, Idempotency, JSONPath, kubectl, Lightsail, llama_cpp, llm, LLM, LLMs, LMDeploy, Milvus, Mixtral, namespace, namespaces, needs_auth, Nextra, Novita, npm, OAuth, Ollama, opencode, parsable, passthrough, PDFs, pgvector, Postgres, Premade, Pydantic, pytest, Qdrant, qdrant, Repo, repo, Sanitization, SDKs, SGLang, Shareability, Signup, Supabase, UIs, uncomment, URl, vectorstore, Vite, VSCode, VSCode's, widget's
.github/workflows/bandit.yaml (vendored, 2 changed lines)
@@ -21,7 +21,7 @@ jobs:
        uses: actions/checkout@v4

      - name: Set up Python
-       uses: actions/setup-python@v5
+       uses: actions/setup-python@v6
        with:
          python-version: '3.12'
      - name: Install dependencies
.github/workflows/lint.yml (vendored, 3 added lines)
@@ -7,6 +7,9 @@ on:
  pull_request:
    types: [ opened, synchronize ]

+permissions:
+  contents: read
+
jobs:
  ruff:
    runs-on: ubuntu-latest
.github/workflows/npm-publish.yml (vendored, new file, 114 lines)
@@ -0,0 +1,114 @@
name: Publish npm libraries

on:
  workflow_dispatch:
    inputs:
      version:
        description: >
          Version bump type (patch | minor | major) or explicit semver (e.g. 1.2.3).
          Applies to both docsgpt and docsgpt-react.
        required: true
        default: patch

permissions:
  contents: write
  pull-requests: write

jobs:
  publish:
    runs-on: ubuntu-latest
    environment: npm-release
    defaults:
      run:
        working-directory: extensions/react-widget

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-node@v4
        with:
          node-version: 20
          registry-url: https://registry.npmjs.org

      - name: Install dependencies
        run: npm ci

      # ── docsgpt (HTML embedding bundle) ──────────────────────────────────
      # Uses the `build` script (parcel build src/browser.tsx) and keeps
      # the `targets` field so Parcel produces browser-optimised bundles.

      - name: Set package name → docsgpt
        run: jq --arg n "docsgpt" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json

      - name: Bump version (docsgpt)
        id: version_docsgpt
        run: |
          VERSION="${{ github.event.inputs.version }}"
          NEW_VER=$(npm version "${VERSION:-patch}" --no-git-tag-version)
          echo "version=${NEW_VER#v}" >> "$GITHUB_OUTPUT"

      - name: Build docsgpt
        run: npm run build

      - name: Publish docsgpt
        run: npm publish --verbose
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

      # ── docsgpt-react (React library bundle) ─────────────────────────────
      # Uses `build:react` script (parcel build src/index.ts) and strips
      # the `targets` field so Parcel treats the output as a plain library
      # without browser-specific target resolution, producing a smaller bundle.

      - name: Reset package.json from source control
        run: git checkout -- package.json

      - name: Set package name → docsgpt-react
        run: jq --arg n "docsgpt-react" '.name=$n' package.json > _tmp.json && mv _tmp.json package.json

      - name: Remove targets field (react library build)
        run: jq 'del(.targets)' package.json > _tmp.json && mv _tmp.json package.json

      - name: Bump version (docsgpt-react) to match docsgpt
        run: npm version "${{ steps.version_docsgpt.outputs.version }}" --no-git-tag-version

      - name: Clean dist before react build
        run: rm -rf dist

      - name: Build docsgpt-react
        run: npm run build:react

      - name: Publish docsgpt-react
        run: npm publish --verbose
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

      # ── Commit the bumped version back to the repository ─────────────────

      - name: Reset package.json and write final version
        run: |
          git checkout -- package.json
          jq --arg v "${{ steps.version_docsgpt.outputs.version }}" '.version=$v' \
            package.json > _tmp.json && mv _tmp.json package.json
          npm install --package-lock-only

      - name: Commit version bump and create PR
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          BRANCH="chore/bump-npm-v${{ steps.version_docsgpt.outputs.version }}"
          git checkout -b "$BRANCH"
          git add package.json package-lock.json
          git commit -m "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}"
          git push origin "$BRANCH"
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Create PR
        run: |
          gh pr create \
            --title "chore: bump npm libraries to v${{ steps.version_docsgpt.outputs.version }}" \
            --body "Automated version bump after npm publish." \
            --base main
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/pytest.yml (vendored, 6 changed lines)
@@ -1,5 +1,9 @@
name: Run python tests with pytest
on: [push, pull_request]

+permissions:
+  contents: read
+
jobs:
  pytest_and_coverage:
    name: Run tests and count coverage
@@ -10,7 +14,7 @@ jobs:
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
-       uses: actions/setup-python@v5
+       uses: actions/setup-python@v6
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
.github/workflows/react-widget-build.yml (vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
name: React Widget Build

on:
  push:
    paths:
      - 'extensions/react-widget/**'
  pull_request:
    paths:
      - 'extensions/react-widget/**'

permissions:
  contents: read

jobs:
  build:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: extensions/react-widget

    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-node@v4
        with:
          node-version: 20
          cache: npm
          cache-dependency-path: extensions/react-widget/package-lock.json

      - name: Install dependencies
        run: npm ci

      - name: Build
        run: npm run build
.github/workflows/vale.yml (vendored, 24 changed lines)
@@ -9,6 +9,9 @@ on:
      - '.vale.ini'
      - '.github/styles/**'

+permissions:
+  contents: read
+
jobs:
  vale:
    runs-on: ubuntu-latest
@@ -16,11 +19,16 @@ jobs:
      - name: Checkout code
        uses: actions/checkout@v4

-     - name: Vale linter
-       uses: errata-ai/vale-action@v2
-       with:
-         files: docs
-         fail_on_error: false
-         version: 3.0.5
-       env:
-         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+     - name: Install Vale
+       run: |
+         curl -fsSL -o vale.tar.gz \
+           https://github.com/errata-ai/vale/releases/download/v3.0.5/vale_3.0.5_Linux_64-bit.tar.gz
+         tar -xzf vale.tar.gz
+         sudo mv vale /usr/local/bin/vale
+         vale --version
+
+     - name: Sync Vale packages
+       run: vale sync
+
+     - name: Run Vale
+       run: vale --minAlertLevel=error docs
.github/workflows/zizmor.yml (vendored, new file, 25 lines)
@@ -0,0 +1,25 @@
name: GitHub Actions Security Analysis

on:
  push:
    branches: ["master"]
  pull_request:
    branches: ["**"]

permissions: {}

jobs:
  zizmor:
    runs-on: ubuntu-latest

    permissions:
      security-events: write # Required for upload-sarif (used by zizmor-action) to upload SARIF files.

    steps:
      - name: Checkout repository
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false

      - name: Run zizmor 🌈
        uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
.gitignore
vendored
17
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
results.txt
|
||||||
experiments/
|
experiments/
|
||||||
|
|
||||||
experiments
|
experiments
|
||||||
@@ -71,6 +72,7 @@ instance/
|
|||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
docs/_build/
|
docs/_build/
|
||||||
|
docs/public/_pagefind/
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
target/
|
target/
|
||||||
@@ -106,6 +108,8 @@ celerybeat.pid
|
|||||||
# Environments
|
# Environments
|
||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
|
# Machine-specific Claude Code guidance (see CLAUDE.md preamble)
|
||||||
|
CLAUDE.md
|
||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
ENV/
|
ENV/
|
||||||
@@ -147,6 +151,10 @@ frontend/yarn-error.log*
|
|||||||
frontend/pnpm-debug.log*
|
frontend/pnpm-debug.log*
|
||||||
frontend/lerna-debug.log*
|
frontend/lerna-debug.log*
|
||||||
|
|
||||||
|
# Keep frontend utility helpers tracked (overrides global lib/ ignore)
|
||||||
|
!frontend/src/lib/
|
||||||
|
!frontend/src/lib/**
|
||||||
|
|
||||||
frontend/node_modules
|
frontend/node_modules
|
||||||
frontend/dist
|
frontend/dist
|
||||||
frontend/dist-ssr
|
frontend/dist-ssr
|
||||||
@@ -175,5 +183,14 @@ application/vectors/
|
|||||||
|
|
||||||
node_modules/
|
node_modules/
|
||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
|
.vscode/sftp.json
|
||||||
/models/
|
/models/
|
||||||
model/
|
model/
|
||||||
|
|
||||||
|
# E2E test artifacts
|
||||||
|
.e2e-tmp/
|
||||||
|
/tmp/docsgpt-e2e/
|
||||||
|
tests/e2e/node_modules/
|
||||||
|
tests/e2e/playwright-report/
|
||||||
|
tests/e2e/test-results/
|
||||||
|
tests/e2e/.e2e-last-run.json
|
||||||
|
|||||||
@@ -1,2 +1,6 @@
# Allow lines to be as long as 120 characters.
line-length = 120
+
+[lint.per-file-ignores]
+# Integration tests use sys.path.insert() before imports for standalone execution
+"tests/integration/*.py" = ["E402"]
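The `E402` exemption covers the import-after-code pattern that standalone integration tests rely on. A minimal sketch of what such a test file looks like (the path handling and module choice are illustrative, not taken from the repository):

```python
import os
import sys

# Make the repository root importable when the test is run directly,
# e.g. `python tests/integration/test_example.py`.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

# Imports placed after the sys.path manipulation are exactly what Ruff's
# E402 ("module level import not at top of file") would otherwise flag;
# json stands in here for project imports that need the adjusted path.
import json


def test_roundtrip():
    assert json.loads(json.dumps({"ok": True})) == {"ok": True}


if __name__ == "__main__":
    test_roundtrip()
    print("ok")
```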
@@ -1,5 +1,7 @@
MinAlertLevel = warning
StylesPath = .github/styles
+Vocab = DocsGPT
+
[*.{md,mdx}]
BasedOnStyles = DocsGPT
AGENTS.md (new file, 140 lines)
@@ -0,0 +1,140 @@

# AGENTS.md

- Read `CONTRIBUTING.md` before making non-trivial changes.
- For day-to-day development and feature work, follow the development-environment workflow rather than defaulting to `setup.sh` / `setup.ps1`.
- Avoid using the setup scripts during normal feature work unless the user explicitly asks for them. Users configure `.env` usually.
- Try to follow red/green TDD

### Check existing dev prerequisites first

For feature work, do **not** assume the environment needs to be recreated.

- Check whether the user already has a Python virtual environment such as `venv/` or `.venv/`.
- Check whether Postgres is already running and reachable via `POSTGRES_URI` (the canonical user-data store).
- Check whether Redis is already running.
- Reuse what is already working. Do not stop or recreate Postgres, Redis, or the Python environment unless the task is environment setup or troubleshooting.

> MongoDB is **not** required for the default install. It is only needed if
> the user opts into the Mongo vector-store backend (`VECTOR_STORE=mongodb`)
> or is running the one-shot `scripts/db/backfill.py` to migrate existing
> user data from the legacy Mongo-based install. In those cases, `pymongo`
> is available as an optional extra, not a core dependency.

## Normal local development commands

Use these commands once the dev prerequisites above are satisfied.

### Backend

```bash
source .venv/bin/activate  # macOS/Linux
uv pip install -r application/requirements.txt  # or: pip install -r application/requirements.txt
```

Run the Flask API (if needed):

```bash
flask --app application/app.py run --host=0.0.0.0 --port=7091
```

Run the Celery worker in a separate terminal (if needed):

```bash
celery -A application.app.celery worker -l INFO
```

On macOS, prefer the solo pool for Celery:

```bash
python -m celery -A application.app.celery worker -l INFO --pool=solo
```

### Frontend

Install dependencies only when needed, then run the dev server:

```bash
cd frontend
npm install --include=dev
npm run dev
```

### Docs site

```bash
cd docs
npm install
```

### Python / backend changes validation

```bash
ruff check .
python -m pytest
```

### Frontend changes

```bash
cd frontend && npm run lint
cd frontend && npm run build
```

### Documentation changes

```bash
cd docs && npm run build
```

If Vale is installed locally and you edited prose, also run:

```bash
vale .
```

## Repository map

- `application/`: Flask backend, API routes, agent logic, retrieval, parsing, security, storage, Celery worker, and WSGI entrypoints.
- `tests/`: backend unit/integration tests and test-only Python dependencies.
- `frontend/`: Vite + React + TypeScript application.
- `frontend/src/`: main UI code, including `components`, `conversation`, `hooks`, `locale`, `settings`, `upload`, and Redux store wiring in `store.ts`.
- `docs/`: separate documentation site built with Next.js/Nextra.
- `extensions/`: integrations and widgets such as Chatwoot, Chrome, Discord, React widget, Slack bot, and web widget.
- `deployment/`: Docker Compose variants and Kubernetes manifests.

## Coding rules

### Backend

- Follow PEP 8 and keep Python line length at or under 120 characters.
- Use type hints for function arguments and return values.
- Add Google-style docstrings to new or substantially changed functions and classes.
- Add or update tests under `tests/` for backend behavior changes.
- Keep changes narrow in `api`, `auth`, `security`, `parser`, `retriever`, and `storage` areas.

### Backend Abstractions

- LLM providers implement a common interface in `application/llm/` (add new providers by extending the base class).
- Vector stores are abstracted in `application/vectorstore/`.
- Parsers live in `application/parser/` and handle different document formats in the ingestion stage.
- Agents and tools are in `application/agents/` and `application/agents/tools/`.
- Celery setup/config lives in `application/celery_init.py` and `application/celeryconfig.py`.
- Settings and env vars are managed via Pydantic in `application/core/settings.py`.

### Frontend

- Follow the existing ESLint + Prettier setup.
- Prefer small, reusable functional components and hooks.
- If shared state must be added, use Redux rather than introducing a new global state library.
- Avoid broad UI refactors unless the task explicitly asks for them.
- Do not re-create components if we already have some in the app.

## PR readiness

Before opening a PR:

- run the relevant validation commands above
- confirm backend changes still work end-to-end after ingesting sample data when applicable
- clearly summarize user-visible behavior changes
- mention any config, dependency, or deployment implications
- Ask your user to attach a screenshot or a video to it
@@ -22,6 +22,11 @@ Thank you for choosing to contribute to DocsGPT! We are all very grateful!

 - We have a frontend built on React (Vite) and a backend in Python.

+> **Required for every PR:** Please attach screenshots or a short screen
+> recording that shows the working version of your changes. This makes the
+> requirement visible to reviewers and helps them quickly verify what you are
+> submitting.
+
 Before creating issues, please check out how the latest version of our app looks and works by launching it via [Quickstart](https://github.com/arc53/DocsGPT#quickstart) the version on our live demo is slightly modified with login. Your issues should relate to the version you can launch via [Quickstart](https://github.com/arc53/DocsGPT#quickstart).
@@ -125,7 +130,7 @@ Here's a step-by-step guide on how to contribute to DocsGPT:

 9. **Submit a Pull Request (PR):**
-    - Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes and reference any related issues.
+    - Create a Pull Request from your branch to the main repository. Make sure to include a detailed description of your changes, reference any related issues, and attach screenshots or a screen recording showing the working version.

 10. **Collaborate:**
     - Be responsive to comments and feedback on your PR.
README.md (39 changed lines)
@@ -7,7 +7,7 @@
 </p>

 <p align="left">
-  <strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
+  <strong><a href="https://www.docsgpt.cloud/">DocsGPT</a></strong> is an open-source AI platform for building intelligent agents and assistants. Features Agent Builder, deep research tools, document analysis (PDF, Office, web content, and audio), Multi-model support (choose your provider or run locally), and rich API connectivity for agents with actionable tools and integrations. Deploy anywhere with complete privacy control.
 </p>

 <div align="center">
@@ -29,13 +29,14 @@
 <div align="center">
 <br>
-  <img src="https://d3dg1063dc54p9.cloudfront.net/videos/demov7.gif" alt="video-example-of-docs-gpt" width="800" height="450">
+  <img src="https://d3dg1063dc54p9.cloudfront.net/videos/demo-26.gif" alt="video-example-of-docs-gpt" width="800" height="480">
 </div>
 <h3 align="left">
 <strong>Key Features:</strong>
 </h3>
 <ul align="left">
-  <li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, and images.</li>
+  <li><strong>🗂️ Wide Format Support:</strong> Reads PDF, DOCX, CSV, XLSX, EPUB, MD, RST, HTML, MDX, JSON, PPTX, images, and audio files such as MP3, WAV, M4A, OGG, and WebM.</li>
+  <li><strong>🎙️ Speech Workflows:</strong> Record voice input into chat, transcribe audio on the backend, and ingest meeting recordings or voice notes as searchable knowledge.</li>
 <li><strong>🌐 Web & Data Integration:</strong> Ingests from URLs, sitemaps, Reddit, GitHub and web crawlers.</li>
 <li><strong>✅ Reliable Answers:</strong> Get accurate, hallucination-free responses with source citations viewable in a clean UI.</li>
 <li><strong>🔑 Streamlined API Keys:</strong> Generate keys linked to your settings, documents, and models, simplifying chatbot and integration setup.</li>
@@ -46,24 +47,11 @@
 </ul>

 ## Roadmap
-- [x] Full GoogleAI compatibility (Jan 2025)
-- [x] Add tools (Jan 2025)
-- [x] Manually updating chunks in the app UI (Feb 2025)
-- [x] Devcontainer for easy development (Feb 2025)
-- [x] ReACT agent (March 2025)
-- [x] Chatbots menu re-design to handle tools, agent types, and more (April 2025)
-- [x] New input box in the conversation menu (April 2025)
-- [x] Add triggerable actions / tools (webhook) (April 2025)
-- [x] Agent optimisations (May 2025)
-- [x] Filesystem sources update (July 2025)
-- [x] Json Responses (August 2025)
-- [x] MCP support (August 2025)
-- [x] Google Drive integration (September 2025)
-- [x] Add OAuth 2.0 authentication for MCP (September 2025)
-- [ ] SharePoint integration (October 2025)
-- [ ] Deep Agents (October 2025)
-- [ ] Agent scheduling
+- [x] Add OAuth 2.0 authentication for MCP ( September 2025 )
+- [x] Deep Agents ( October 2025 )
+- [x] Prompt Templating ( October 2025 )
+- [x] Full api tooling ( Dec 2025 )
+- [ ] Agent scheduling ( Jan 2026 )

 You can find our full roadmap [here](https://github.com/orgs/arc53/projects/2). Please don't hesitate to contribute or create issues, it helps us improve DocsGPT!
@@ -158,9 +146,16 @@ We as members, contributors, and leaders, pledge to make participation in our co

 The source code license is [MIT](https://opensource.org/license/mit/), as described in the [LICENSE](LICENSE) file.

-<p>This project is supported by:</p>
+## This project is supported by:
+
 <p>
   <a href="https://www.digitalocean.com/?utm_medium=opensource&utm_source=DocsGPT">
     <img src="https://opensource.nyc3.cdn.digitaloceanspaces.com/attribution/assets/SVG/DO_Logo_horizontal_blue.svg" width="201px">
   </a>
 </p>
+<p>
+  <a href="https://get.neon.com/docsgpt">
+    <img width="201" alt="color" src="https://github.com/user-attachments/assets/7d9813b7-0e6d-403f-b5af-68af066b326f" />
+  </a>
+</p>
SECURITY.md (18 changed lines)
@@ -2,13 +2,21 @@

 ## Supported Versions

-Supported Versions:
-
-Currently, we support security patches by committing changes and bumping the version published on Github.
+Security patches target the latest release and the `main` branch. We recommend always running the most recent version.

 ## Reporting a Vulnerability

-Found a vulnerability? Please email us:
-
-security@arc53.com
+Preferred method: use GitHub's private vulnerability reporting flow:
+https://github.com/arc53/DocsGPT/security
+
+Then click **Report a vulnerability**.
+
+Alternatively, email us at: security@arc53.com
+
+We aim to acknowledge reports within 48 hours.
+
+## Incident Handling
+
+For the public incident response process, see [`INCIDENT_RESPONSE.md`](./.github/INCIDENT_RESPONSE.md). If you believe an active exploit is occurring, include **URGENT** in your report subject line.
@@ -1,12 +0,0 @@ (file removed)
-API_KEY=your_api_key
-EMBEDDINGS_KEY=your_api_key
-API_URL=http://localhost:7091
-INTERNAL_KEY=your_internal_key
-FLASK_APP=application/app.py
-FLASK_DEBUG=true
-
-#For OPENAI on Azure
-OPENAI_API_BASE=
-OPENAI_API_VERSION=
-AZURE_DEPLOYMENT_NAME=
-AZURE_EMBEDDINGS_DEPLOYMENT_NAME=
@@ -7,7 +7,7 @@ RUN apt-get update && \
    apt-get install -y software-properties-common && \
    add-apt-repository ppa:deadsnakes/ppa && \
    apt-get update && \
-   apt-get install -y --no-install-recommends gcc wget unzip libc6-dev python3.12 python3.12-venv && \
+   apt-get install -y --no-install-recommends gcc g++ wget unzip libc6-dev python3.12 python3.12-venv python3.12-dev && \
    rm -rf /var/lib/apt/lists/*

# Verify Python installation and setup symlink
@@ -48,7 +48,12 @@ FROM ubuntu:24.04 as final
RUN apt-get update && \
    apt-get install -y software-properties-common && \
    add-apt-repository ppa:deadsnakes/ppa && \
-   apt-get update && apt-get install -y --no-install-recommends python3.12 && \
+   apt-get update && apt-get install -y --no-install-recommends \
+       python3.12 \
+       libgl1 \
+       libglib2.0-0 \
+       poppler-utils \
+       && \
    ln -s /usr/bin/python3.12 /usr/bin/python && \
    rm -rf /var/lib/apt/lists/*
@@ -1,14 +1,20 @@
-from application.agents.classic_agent import ClassicAgent
-from application.agents.react_agent import ReActAgent
 import logging
+
+from application.agents.agentic_agent import AgenticAgent
+from application.agents.classic_agent import ClassicAgent
+from application.agents.research_agent import ResearchAgent
+from application.agents.workflow_agent import WorkflowAgent

 logger = logging.getLogger(__name__)


 class AgentCreator:
     agents = {
         "classic": ClassicAgent,
-        "react": ReActAgent,
+        "react": ClassicAgent,  # backwards compat: react falls back to classic
+        "agentic": AgenticAgent,
+        "research": ResearchAgent,
+        "workflow": WorkflowAgent,
     }

     @classmethod
@@ -16,5 +22,4 @@ class AgentCreator:
         agent_class = cls.agents.get(type.lower())
         if not agent_class:
             raise ValueError(f"No agent class found for type {type}")
-
         return agent_class(*args, **kwargs)
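A quick way to see the backwards-compatibility change in effect: requesting the "react" agent type now resolves to `ClassicAgent`. The sketch below only relies on the `agents` dictionary visible in the hunk; the import path assumes the class lives in `application.agents.agent_creator`, which this extract does not show, and the factory classmethod itself is not called because its name is hidden here:

```python
# Module path is an assumption; the hunk above does not name the file.
from application.agents.agent_creator import AgentCreator
from application.agents.classic_agent import ClassicAgent

# "react" stays an accepted type string, but it now maps to ClassicAgent.
assert AgentCreator.agents["react"] is ClassicAgent
assert AgentCreator.agents["classic"] is ClassicAgent

# Unknown types are absent from the mapping, which is what makes the
# factory raise ValueError for them.
assert AgentCreator.agents.get("does-not-exist") is None
print(sorted(AgentCreator.agents))  # ['agentic', 'classic', 'react', 'research', 'workflow']
```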
63
application/agents/agentic_agent.py
Normal file
63
application/agents/agentic_agent.py
Normal file
@@ -0,0 +1,63 @@
import logging
from typing import Dict, Generator, Optional

from application.agents.base import BaseAgent
from application.agents.tools.internal_search import (
    INTERNAL_TOOL_ID,
    add_internal_search_tool,
)
from application.logging import LogContext

logger = logging.getLogger(__name__)


class AgenticAgent(BaseAgent):
    """Agent where the LLM controls retrieval via tools.

    Unlike ClassicAgent which pre-fetches docs into the prompt,
    AgenticAgent gives the LLM an internal_search tool so it can
    decide when, what, and whether to search.
    """

    def __init__(
        self,
        retriever_config: Optional[Dict] = None,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.retriever_config = retriever_config or {}

    def _gen_inner(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        tools_dict = self.tool_executor.get_tools()
        add_internal_search_tool(tools_dict, self.retriever_config)
        self._prepare_tools(tools_dict)

        # 4. Build messages (prompt has NO pre-fetched docs)
        messages = self._build_messages(self.prompt, query)

        # 5. Call LLM — the handler manages the tool loop
        llm_response = self._llm_gen(messages, log_context)

        yield from self._handle_response(
            llm_response, tools_dict, messages, log_context
        )

        # 6. Collect sources from internal search tool results
        self._collect_internal_sources()

        yield {"sources": self.retrieved_docs}
        yield {"tool_calls": self._get_truncated_tool_calls()}

        log_context.stacks.append(
            {"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
        )

    def _collect_internal_sources(self):
        """Collect retrieved docs from the cached InternalSearchTool instance."""
        cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
        tool = self.tool_executor._loaded_tools.get(cache_key)
        if tool and hasattr(tool, "retrieved_docs") and tool.retrieved_docs:
            self.retrieved_docs = tool.retrieved_docs
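A minimal consumption sketch for the new agent type. The public streaming entry point (gen) and the event shapes beyond the yields visible above are assumptions, not confirmed by this diff:

# Sketch: consuming an AgenticAgent stream; names marked below are assumed.
agent = AgentCreator.create_agent(          # factory method name assumed
    "agentic",
    retriever_config={"chunks": 4},          # illustrative retriever settings
    llm_name="openai",
    model_id="gpt-4o",
    api_key="...",
    decoded_token={"sub": "local"},
)

for event in agent.gen(query="What does our refund policy say?"):  # "gen" assumed
    if "answer" in event:
        print(event["answer"], end="")                       # streamed answer text
    elif "sources" in event:
        print("\nsources:", len(event["sources"]))           # docs found via internal_search
    elif "tool_calls" in event:
        print("tool calls made:", len(event["tool_calls"]))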
application/agents/base.py
@@ -1,14 +1,16 @@
+import json
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import Dict, Generator, List, Optional
+from typing import Any, Dict, Generator, List, Optional

-from bson.objectid import ObjectId
-from application.agents.tools.tool_action_parser import ToolActionParser
-from application.agents.tools.tool_manager import ToolManager
-from application.core.mongo_db import MongoDB
+from application.agents.tool_executor import ToolExecutor
+from application.core.json_schema_utils import (
+    JsonSchemaValidationError,
+    normalize_json_schema_payload,
+)
 from application.core.settings import settings
+from application.llm.handlers.base import ToolCall
 from application.llm.handlers.handler_creator import LLMHandlerCreator
 from application.llm.llm_creator import LLMCreator
 from application.logging import build_stack_data, log_activity, LogContext
@@ -23,6 +25,7 @@ class BaseAgent(ABC):
         llm_name: str,
         model_id: str,
         api_key: str,
+        agent_id: Optional[str] = None,
         user_api_key: Optional[str] = None,
         prompt: str = "",
         chat_history: Optional[List[Dict]] = None,
@@ -35,32 +38,63 @@ class BaseAgent(ABC):
         limited_request_mode: Optional[bool] = False,
         request_limit: Optional[int] = settings.DEFAULT_AGENT_LIMITS["request_limit"],
         compressed_summary: Optional[str] = None,
+        llm=None,
+        llm_handler=None,
+        tool_executor: Optional[ToolExecutor] = None,
+        backup_models: Optional[List[str]] = None,
     ):
         self.endpoint = endpoint
         self.llm_name = llm_name
         self.model_id = model_id
         self.api_key = api_key
+        self.agent_id = agent_id
         self.user_api_key = user_api_key
         self.prompt = prompt
         self.decoded_token = decoded_token or {}
         self.user: str = self.decoded_token.get("sub")
-        self.tool_config: Dict = {}
         self.tools: List[Dict] = []
-        self.tool_calls: List[Dict] = []
         self.chat_history: List[Dict] = chat_history if chat_history is not None else []
-        self.llm = LLMCreator.create_llm(
-            llm_name,
-            api_key=api_key,
-            user_api_key=user_api_key,
-            decoded_token=decoded_token,
-            model_id=model_id,
-        )
+        # Dependency injection for LLM — fall back to creating if not provided
+        if llm is not None:
+            self.llm = llm
+        else:
+            self.llm = LLMCreator.create_llm(
+                llm_name,
+                api_key=api_key,
+                user_api_key=user_api_key,
+                decoded_token=decoded_token,
+                model_id=model_id,
+                agent_id=agent_id,
+                backup_models=backup_models,
+            )
+
         self.retrieved_docs = retrieved_docs or []
-        self.llm_handler = LLMHandlerCreator.create_handler(
-            llm_name if llm_name else "default"
-        )
+        if llm_handler is not None:
+            self.llm_handler = llm_handler
+        else:
+            self.llm_handler = LLMHandlerCreator.create_handler(
+                llm_name if llm_name else "default"
+            )
+
+        # Tool executor — injected or created
+        if tool_executor is not None:
+            self.tool_executor = tool_executor
+        else:
+            self.tool_executor = ToolExecutor(
+                user_api_key=user_api_key,
+                user=self.user,
+                decoded_token=decoded_token,
+            )
+
         self.attachments = attachments or []
-        self.json_schema = json_schema
+        self.json_schema = None
+        if json_schema is not None:
+            try:
+                self.json_schema = normalize_json_schema_payload(json_schema)
+            except JsonSchemaValidationError as exc:
+                logger.warning("Ignoring invalid JSON schema payload: %s", exc)
         self.limited_token_mode = limited_token_mode
         self.token_limit = token_limit
         self.limited_request_mode = limited_request_mode
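The new constructor injection points make agents unit-testable without LLM or database access. A rough sketch under assumed names: the required constructor arguments are abbreviated to those visible in this diff, and the fakes only cover what BaseAgent is shown to touch:

# Sketch: constructing an agent with injected collaborators for a test.
from unittest.mock import MagicMock

from application.agents.classic_agent import ClassicAgent

fake_llm = MagicMock()
fake_executor = MagicMock()
fake_executor.get_tools.return_value = {}             # no DB lookup for user tools
fake_executor.prepare_tools_for_llm.return_value = []

agent = ClassicAgent(
    endpoint=None,                 # illustrative values for the remaining required args
    llm_name="openai",
    model_id="gpt-4o",
    api_key="test",
    decoded_token={"sub": "tester"},
    llm=fake_llm,                  # injected, so LLMCreator.create_llm is never called
    llm_handler=MagicMock(),
    tool_executor=fake_executor,   # injected, so ToolExecutor is never constructed
)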
@@ -81,266 +115,268 @@ class BaseAgent(ABC):
|
|||||||
) -> Generator[Dict, None, None]:
|
) -> Generator[Dict, None, None]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
|
def gen_continuation(
|
||||||
mongo = MongoDB.get_client()
|
self,
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
messages: List[Dict],
|
||||||
agents_collection = db["agents"]
|
tools_dict: Dict,
|
||||||
tools_collection = db["user_tools"]
|
pending_tool_calls: List[Dict],
|
||||||
|
tool_actions: List[Dict],
|
||||||
|
) -> Generator[Dict, None, None]:
|
||||||
|
"""Resume generation after tool actions are resolved.
|
||||||
|
|
||||||
agent_data = agents_collection.find_one({"key": api_key or self.user_api_key})
|
Processes the client-provided *tool_actions* (approvals, denials,
|
||||||
tool_ids = agent_data.get("tools", []) if agent_data else []
|
or client-side results), appends the resulting messages, then
|
||||||
|
hands back to the LLM to continue the conversation.
|
||||||
tools = (
|
|
||||||
tools_collection.find(
|
|
||||||
{"_id": {"$in": [ObjectId(tool_id) for tool_id in tool_ids]}}
|
|
||||||
)
|
|
||||||
if tool_ids
|
|
||||||
else []
|
|
||||||
)
|
|
||||||
tools = list(tools)
|
|
||||||
tools_by_id = {str(tool["_id"]): tool for tool in tools} if tools else {}
|
|
||||||
|
|
||||||
return tools_by_id
|
|
||||||
|
|
||||||
def _get_user_tools(self, user="local"):
|
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
user_tools_collection = db["user_tools"]
|
|
||||||
user_tools = user_tools_collection.find({"user": user, "status": True})
|
|
||||||
user_tools = list(user_tools)
|
|
||||||
|
|
||||||
return {str(i): tool for i, tool in enumerate(user_tools)}
|
|
||||||
|
|
||||||
def _build_tool_parameters(self, action):
|
|
||||||
params = {"type": "object", "properties": {}, "required": []}
|
|
||||||
for param_type in ["query_params", "headers", "body", "parameters"]:
|
|
||||||
if param_type in action and action[param_type].get("properties"):
|
|
||||||
for k, v in action[param_type]["properties"].items():
|
|
||||||
if v.get("filled_by_llm", True):
|
|
||||||
params["properties"][k] = {
|
|
||||||
key: value
|
|
||||||
for key, value in v.items()
|
|
||||||
if key != "filled_by_llm" and key != "value"
|
|
||||||
}
|
|
||||||
|
|
||||||
params["required"].append(k)
|
|
||||||
return params
|
|
||||||
|
|
||||||
def _prepare_tools(self, tools_dict):
|
|
||||||
self.tools = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": f"{action['name']}_{tool_id}",
|
|
||||||
"description": action["description"],
|
|
||||||
"parameters": self._build_tool_parameters(action),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tool_id, tool in tools_dict.items()
|
|
||||||
if (
|
|
||||||
(tool["name"] == "api_tool" and "actions" in tool.get("config", {}))
|
|
||||||
or (tool["name"] != "api_tool" and "actions" in tool)
|
|
||||||
)
|
|
||||||
for action in (
|
|
||||||
tool["config"]["actions"].values()
|
|
||||||
if tool["name"] == "api_tool"
|
|
||||||
else tool["actions"]
|
|
||||||
)
|
|
||||||
if action.get("active", True)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _execute_tool_action(self, tools_dict, call):
|
|
||||||
parser = ToolActionParser(self.llm.__class__.__name__)
|
|
||||||
tool_id, action_name, call_args = parser.parse_args(call)
|
|
||||||
|
|
||||||
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
|
||||||
|
|
||||||
# Check if parsing failed
|
|
||||||
|
|
||||||
if tool_id is None or action_name is None:
|
|
||||||
error_message = f"Error: Failed to parse LLM tool call. Tool name: {getattr(call, 'name', 'unknown')}"
|
|
||||||
logger.error(error_message)
|
|
||||||
|
|
||||||
tool_call_data = {
|
|
||||||
"tool_name": "unknown",
|
|
||||||
"call_id": call_id,
|
|
||||||
"action_name": getattr(call, "name", "unknown"),
|
|
||||||
"arguments": call_args or {},
|
|
||||||
"result": f"Failed to parse tool call. Invalid tool name format: {getattr(call, 'name', 'unknown')}",
|
|
||||||
}
|
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
|
||||||
self.tool_calls.append(tool_call_data)
|
|
||||||
return "Failed to parse tool call.", call_id
|
|
||||||
# Check if tool_id exists in available tools
|
|
||||||
|
|
||||||
if tool_id not in tools_dict:
|
|
||||||
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
|
||||||
logger.error(error_message)
|
|
||||||
|
|
||||||
# Return error result
|
|
||||||
|
|
||||||
tool_call_data = {
|
|
||||||
"tool_name": "unknown",
|
|
||||||
"call_id": call_id,
|
|
||||||
"action_name": f"{action_name}_{tool_id}",
|
|
||||||
"arguments": call_args,
|
|
||||||
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
|
||||||
}
|
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
|
||||||
self.tool_calls.append(tool_call_data)
|
|
||||||
return f"Tool with ID {tool_id} not found.", call_id
|
|
||||||
tool_call_data = {
|
|
||||||
"tool_name": tools_dict[tool_id]["name"],
|
|
||||||
"call_id": call_id,
|
|
||||||
"action_name": f"{action_name}_{tool_id}",
|
|
||||||
"arguments": call_args,
|
|
||||||
}
|
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
|
||||||
|
|
||||||
tool_data = tools_dict[tool_id]
|
|
||||||
action_data = (
|
|
||||||
tool_data["config"]["actions"][action_name]
|
|
||||||
if tool_data["name"] == "api_tool"
|
|
||||||
else next(
|
|
||||||
action
|
|
||||||
for action in tool_data["actions"]
|
|
||||||
if action["name"] == action_name
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
query_params, headers, body, parameters = {}, {}, {}, {}
|
|
||||||
param_types = {
|
|
||||||
"query_params": query_params,
|
|
||||||
"headers": headers,
|
|
||||||
"body": body,
|
|
||||||
"parameters": parameters,
|
|
||||||
}
|
|
||||||
|
|
||||||
for param_type, target_dict in param_types.items():
|
|
||||||
if param_type in action_data and action_data[param_type].get("properties"):
|
|
||||||
for param, details in action_data[param_type]["properties"].items():
|
|
||||||
if param not in call_args and "value" in details:
|
|
||||||
target_dict[param] = details["value"]
|
|
||||||
for param, value in call_args.items():
|
|
||||||
for param_type, target_dict in param_types.items():
|
|
||||||
if param_type in action_data and param in action_data[param_type].get(
|
|
||||||
"properties", {}
|
|
||||||
):
|
|
||||||
target_dict[param] = value
|
|
||||||
tm = ToolManager(config={})
|
|
||||||
|
|
||||||
# Prepare tool_config and add tool_id for memory tools
|
|
||||||
|
|
||||||
if tool_data["name"] == "api_tool":
|
|
||||||
tool_config = {
|
|
||||||
"url": tool_data["config"]["actions"][action_name]["url"],
|
|
||||||
"method": tool_data["config"]["actions"][action_name]["method"],
|
|
||||||
"headers": headers,
|
|
||||||
"query_params": query_params,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
|
||||||
# Add tool_id from MongoDB _id for tools that need instance isolation (like memory tool)
|
|
||||||
# Use MongoDB _id if available, otherwise fall back to enumerated tool_id
|
|
||||||
|
|
||||||
tool_config["tool_id"] = str(tool_data.get("_id", tool_id))
|
|
||||||
tool = tm.load_tool(
|
|
||||||
tool_data["name"],
|
|
||||||
tool_config=tool_config,
|
|
||||||
user_id=self.user, # Pass user ID for MCP tools credential decryption
|
|
||||||
)
|
|
||||||
if tool_data["name"] == "api_tool":
|
|
||||||
print(
|
|
||||||
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
|
||||||
)
|
|
||||||
result = tool.execute_action(action_name, **body)
|
|
||||||
else:
|
|
||||||
print(f"Executing tool: {action_name} with args: {call_args}")
|
|
||||||
result = tool.execute_action(action_name, **parameters)
|
|
||||||
tool_call_data["result"] = (
|
|
||||||
f"{str(result)[:50]}..." if len(str(result)) > 50 else result
|
|
||||||
)
|
|
||||||
|
|
||||||
yield {"type": "tool_call", "data": {**tool_call_data, "status": "completed"}}
|
|
||||||
self.tool_calls.append(tool_call_data)
|
|
||||||
|
|
||||||
return result, call_id
|
|
||||||
|
|
||||||
def _get_truncated_tool_calls(self):
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
**tool_call,
|
|
||||||
"result": (
|
|
||||||
f"{str(tool_call['result'])[:50]}..."
|
|
||||||
if len(str(tool_call["result"])) > 50
|
|
||||||
else tool_call["result"]
|
|
||||||
),
|
|
||||||
"status": "completed",
|
|
||||||
}
|
|
||||||
for tool_call in self.tool_calls
|
|
||||||
]
|
|
||||||
|
|
||||||
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
|
|
||||||
"""
|
|
||||||
Calculate total tokens in current context (messages).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of message dicts
|
messages: The saved messages array from the pause point.
|
||||||
|
tools_dict: The saved tools dictionary.
|
||||||
Returns:
|
pending_tool_calls: The pending tool call descriptors from the pause.
|
||||||
Total token count
|
tool_actions: Client-provided actions resolving the pending calls.
|
||||||
"""
|
"""
|
||||||
|
self._prepare_tools(tools_dict)
|
||||||
|
|
||||||
|
actions_by_id = {a["call_id"]: a for a in tool_actions}
|
||||||
|
|
||||||
|
# Build a single assistant message containing all tool calls so
|
||||||
|
# the message history matches the format LLM providers expect
|
||||||
|
# (one assistant message with N tool_calls, followed by N tool results).
|
||||||
|
tc_objects: List[Dict[str, Any]] = []
|
||||||
|
for pending in pending_tool_calls:
|
||||||
|
call_id = pending["call_id"]
|
||||||
|
args = pending["arguments"]
|
||||||
|
args_str = (
|
||||||
|
json.dumps(args) if isinstance(args, dict) else (args or "{}")
|
||||||
|
)
|
||||||
|
tc_obj: Dict[str, Any] = {
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": pending["name"],
|
||||||
|
"arguments": args_str,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if pending.get("thought_signature"):
|
||||||
|
tc_obj["thought_signature"] = pending["thought_signature"]
|
||||||
|
tc_objects.append(tc_obj)
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": tc_objects,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Now process each pending call and append tool result messages
|
||||||
|
for pending in pending_tool_calls:
|
||||||
|
call_id = pending["call_id"]
|
||||||
|
args = pending["arguments"]
|
||||||
|
action = actions_by_id.get(call_id)
|
||||||
|
if not action:
|
||||||
|
action = {
|
||||||
|
"call_id": call_id,
|
||||||
|
"decision": "denied",
|
||||||
|
"comment": "No response provided",
|
||||||
|
}
|
||||||
|
|
||||||
|
if action.get("decision") == "approved":
|
||||||
|
# Execute the tool server-side
|
||||||
|
tc = ToolCall(
|
||||||
|
id=call_id,
|
||||||
|
name=pending["name"],
|
||||||
|
arguments=(
|
||||||
|
json.dumps(args) if isinstance(args, dict) else args
|
||||||
|
),
|
||||||
|
)
|
||||||
|
tool_gen = self._execute_tool_action(tools_dict, tc)
|
||||||
|
tool_response = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event = next(tool_gen)
|
||||||
|
yield event
|
||||||
|
except StopIteration as e:
|
||||||
|
tool_response, _ = e.value
|
||||||
|
break
|
||||||
|
messages.append(
|
||||||
|
self.llm_handler.create_tool_message(tc, tool_response)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action.get("decision") == "denied":
|
||||||
|
comment = action.get("comment", "")
|
||||||
|
denial = (
|
||||||
|
f"Tool execution denied by user. Reason: {comment}"
|
||||||
|
if comment
|
||||||
|
else "Tool execution denied by user."
|
||||||
|
)
|
||||||
|
tc = ToolCall(
|
||||||
|
id=call_id, name=pending["name"], arguments=args
|
||||||
|
)
|
||||||
|
messages.append(
|
||||||
|
self.llm_handler.create_tool_message(tc, denial)
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"type": "tool_call",
|
||||||
|
"data": {
|
||||||
|
"tool_name": pending.get("tool_name", "unknown"),
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": pending.get("llm_name", pending["name"]),
|
||||||
|
"arguments": args,
|
||||||
|
"status": "denied",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
elif "result" in action:
|
||||||
|
result = action["result"]
|
||||||
|
result_str = (
|
||||||
|
json.dumps(result)
|
||||||
|
if not isinstance(result, str)
|
||||||
|
else result
|
||||||
|
)
|
||||||
|
tc = ToolCall(
|
||||||
|
id=call_id, name=pending["name"], arguments=args
|
||||||
|
)
|
||||||
|
messages.append(
|
||||||
|
self.llm_handler.create_tool_message(tc, result_str)
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"type": "tool_call",
|
||||||
|
"data": {
|
||||||
|
"tool_name": pending.get("tool_name", "unknown"),
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": pending.get("llm_name", pending["name"]),
|
||||||
|
"arguments": args,
|
||||||
|
"result": (
|
||||||
|
result_str[:50] + "..."
|
||||||
|
if len(result_str) > 50
|
||||||
|
else result_str
|
||||||
|
),
|
||||||
|
"status": "completed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Resume the LLM loop with the updated messages
|
||||||
|
llm_response = self._llm_gen(messages)
|
||||||
|
yield from self._handle_response(
|
||||||
|
llm_response, tools_dict, messages, None
|
||||||
|
)
|
||||||
|
|
||||||
|
yield {"sources": self.retrieved_docs}
|
||||||
|
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||||
|
|
||||||
|
# ---- Tool delegation (thin wrappers around ToolExecutor) ----
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_calls(self) -> List[Dict]:
|
||||||
|
return self.tool_executor.tool_calls
|
||||||
|
|
||||||
|
@tool_calls.setter
|
||||||
|
def tool_calls(self, value: List[Dict]):
|
||||||
|
self.tool_executor.tool_calls = value
|
||||||
|
|
||||||
|
def _get_tools(self, api_key: str = None) -> Dict[str, Dict]:
|
||||||
|
return self.tool_executor._get_tools_by_api_key(api_key or self.user_api_key)
|
||||||
|
|
||||||
|
def _get_user_tools(self, user="local"):
|
||||||
|
return self.tool_executor._get_user_tools(user)
|
||||||
|
|
||||||
|
def _build_tool_parameters(self, action):
|
||||||
|
return self.tool_executor._build_tool_parameters(action)
|
||||||
|
|
||||||
|
def _prepare_tools(self, tools_dict):
|
||||||
|
self.tools = self.tool_executor.prepare_tools_for_llm(tools_dict)
|
||||||
|
|
||||||
|
def _execute_tool_action(self, tools_dict, call):
|
||||||
|
return self.tool_executor.execute(
|
||||||
|
tools_dict, call, self.llm.__class__.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_truncated_tool_calls(self):
|
||||||
|
return self.tool_executor.get_truncated_tool_calls()
|
||||||
|
|
||||||
|
# ---- Context / token management ----
|
||||||
|
|
||||||
|
def _calculate_current_context_tokens(self, messages: List[Dict]) -> int:
|
||||||
from application.api.answer.services.compression.token_counter import (
|
from application.api.answer.services.compression.token_counter import (
|
||||||
TokenCounter,
|
TokenCounter,
|
||||||
)
|
)
|
||||||
|
|
||||||
return TokenCounter.count_message_tokens(messages)
|
return TokenCounter.count_message_tokens(messages)
|
||||||
|
|
||||||
def _check_context_limit(self, messages: List[Dict]) -> bool:
|
def _check_context_limit(self, messages: List[Dict]) -> bool:
|
||||||
"""
|
|
||||||
Check if we're approaching context limit (80%).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Current message list
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if at or above 80% of context limit
|
|
||||||
"""
|
|
||||||
from application.core.model_utils import get_token_limit
|
from application.core.model_utils import get_token_limit
|
||||||
from application.core.settings import settings
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Calculate current tokens
|
|
||||||
current_tokens = self._calculate_current_context_tokens(messages)
|
current_tokens = self._calculate_current_context_tokens(messages)
|
||||||
self.current_token_count = current_tokens
|
self.current_token_count = current_tokens
|
||||||
|
|
||||||
# Get context limit for model
|
|
||||||
context_limit = get_token_limit(self.model_id)
|
context_limit = get_token_limit(self.model_id)
|
||||||
|
|
||||||
# Calculate threshold (80%)
|
|
||||||
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
threshold = int(context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE)
|
||||||
|
|
||||||
# Check if we've reached the limit
|
|
||||||
if current_tokens >= threshold:
|
if current_tokens >= threshold:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
|
f"Context limit approaching: {current_tokens}/{context_limit} tokens "
|
||||||
f"({(current_tokens/context_limit)*100:.1f}%)"
|
f"({(current_tokens/context_limit)*100:.1f}%)"
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
|
logger.error(f"Error checking context limit: {str(e)}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _validate_context_size(self, messages: List[Dict]) -> None:
|
||||||
|
from application.core.model_utils import get_token_limit
|
||||||
|
|
||||||
|
current_tokens = self._calculate_current_context_tokens(messages)
|
||||||
|
self.current_token_count = current_tokens
|
||||||
|
context_limit = get_token_limit(self.model_id)
|
||||||
|
percentage = (current_tokens / context_limit) * 100
|
||||||
|
|
||||||
|
if current_tokens >= context_limit:
|
||||||
|
logger.warning(
|
||||||
|
f"Context at limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||||
|
f"({percentage:.1f}%). Model: {self.model_id}"
|
||||||
|
)
|
||||||
|
elif current_tokens >= int(
|
||||||
|
context_limit * settings.COMPRESSION_THRESHOLD_PERCENTAGE
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"Context approaching limit: {current_tokens:,}/{context_limit:,} tokens "
|
||||||
|
f"({percentage:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _truncate_text_middle(self, text: str, max_tokens: int) -> str:
|
||||||
|
from application.utils import num_tokens_from_string
|
||||||
|
|
||||||
|
current_tokens = num_tokens_from_string(text)
|
||||||
|
if current_tokens <= max_tokens:
|
||||||
|
return text
|
||||||
|
|
||||||
|
chars_per_token = len(text) / current_tokens if current_tokens > 0 else 4
|
||||||
|
target_chars = int(max_tokens * chars_per_token * 0.95)
|
||||||
|
|
||||||
|
if target_chars <= 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
start_chars = int(target_chars * 0.4)
|
||||||
|
end_chars = int(target_chars * 0.4)
|
||||||
|
|
||||||
|
truncation_marker = "\n\n[... content truncated to fit context limit ...]\n\n"
|
||||||
|
truncated = text[:start_chars] + truncation_marker + text[-end_chars:]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Truncated text from {current_tokens:,} to ~{max_tokens:,} tokens "
|
||||||
|
f"(removed middle section)"
|
||||||
|
)
|
||||||
|
return truncated
|
||||||
|
|
||||||
|
# ---- Message building ----
|
||||||
|
|
||||||
def _build_messages(
|
def _build_messages(
|
||||||
self,
|
self,
|
||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
query: str,
|
query: str,
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""Build messages using pre-rendered system prompt"""
|
"""Build messages using pre-rendered system prompt"""
|
||||||
# Append compression summary to system prompt if present
|
from application.core.model_utils import get_token_limit
|
||||||
|
from application.utils import num_tokens_from_string
|
||||||
|
|
||||||
if self.compressed_summary:
|
if self.compressed_summary:
|
||||||
compression_context = (
|
compression_context = (
|
||||||
"\n\n---\n\n"
|
"\n\n---\n\n"
|
||||||
@@ -351,42 +387,119 @@ class BaseAgent(ABC):
             )
             system_prompt = system_prompt + compression_context

+        context_limit = get_token_limit(self.model_id)
+        system_tokens = num_tokens_from_string(system_prompt)
+
+        safety_buffer = int(context_limit * 0.1)
+        available_after_system = context_limit - system_tokens - safety_buffer
+
+        max_query_tokens = int(available_after_system * 0.8)
+        query_tokens = num_tokens_from_string(query)
+
+        if query_tokens > max_query_tokens:
+            query = self._truncate_text_middle(query, max_query_tokens)
+            query_tokens = num_tokens_from_string(query)
+
+        available_for_history = max(available_after_system - query_tokens, 0)
+
+        working_history = self._truncate_history_to_fit(
+            self.chat_history,
+            available_for_history,
+        )
+
         messages = [{"role": "system", "content": system_prompt}]

-        for i in self.chat_history:
+        for i in working_history:
             if "prompt" in i and "response" in i:
                 messages.append({"role": "user", "content": i["prompt"]})
                 messages.append({"role": "assistant", "content": i["response"]})
             if "tool_calls" in i:
                 for tool_call in i["tool_calls"]:
                     call_id = tool_call.get("call_id") or str(uuid.uuid4())
-                    function_call_dict = {
-                        "function_call": {
-                            "name": tool_call.get("action_name"),
-                            "args": tool_call.get("arguments"),
-                            "call_id": call_id,
-                        }
-                    }
-                    function_response_dict = {
-                        "function_response": {
-                            "name": tool_call.get("action_name"),
-                            "response": {"result": tool_call.get("result")},
-                            "call_id": call_id,
-                        }
-                    }
-
-                    messages.append(
-                        {"role": "assistant", "content": [function_call_dict]}
-                    )
-                    messages.append(
-                        {"role": "tool", "content": [function_response_dict]}
-                    )
+                    args = tool_call.get("arguments")
+                    args_str = (
+                        json.dumps(args)
+                        if isinstance(args, dict)
+                        else (args or "{}")
+                    )
+                    messages.append({
+                        "role": "assistant",
+                        "content": None,
+                        "tool_calls": [{
+                            "id": call_id,
+                            "type": "function",
+                            "function": {
+                                "name": tool_call.get("action_name", ""),
+                                "arguments": args_str,
+                            },
+                        }],
+                    })
+                    result = tool_call.get("result")
+                    result_str = (
+                        json.dumps(result)
+                        if not isinstance(result, str)
+                        else (result or "")
+                    )
+                    messages.append({
+                        "role": "tool",
+                        "tool_call_id": call_id,
+                        "content": result_str,
+                    })
         messages.append({"role": "user", "content": query})
         return messages

+    def _truncate_history_to_fit(
+        self,
+        history: List[Dict],
+        max_tokens: int,
+    ) -> List[Dict]:
+        from application.utils import num_tokens_from_string
+
+        if not history or max_tokens <= 0:
+            return []
+
+        truncated = []
+        current_tokens = 0
+
+        for message in reversed(history):
+            message_tokens = 0
+
+            if "prompt" in message and "response" in message:
+                message_tokens += num_tokens_from_string(message["prompt"])
+                message_tokens += num_tokens_from_string(message["response"])
+
+            if "tool_calls" in message:
+                for tool_call in message["tool_calls"]:
+                    tool_str = (
+                        f"Tool: {tool_call.get('tool_name')} | "
+                        f"Action: {tool_call.get('action_name')} | "
+                        f"Args: {tool_call.get('arguments')} | "
+                        f"Response: {tool_call.get('result')}"
+                    )
+                    message_tokens += num_tokens_from_string(tool_str)
+
+            if current_tokens + message_tokens <= max_tokens:
+                current_tokens += message_tokens
+                truncated.insert(0, message)
+            else:
+                break
+
+        if len(truncated) < len(history):
+            logger.info(
+                f"Truncated chat history from {len(history)} to {len(truncated)} messages "
+                f"to fit within {max_tokens:,} token budget"
+            )
+
+        return truncated
+
+    # ---- LLM generation ----
+
     def _llm_gen(self, messages: List[Dict], log_context: Optional[LogContext] = None):
+        self._validate_context_size(messages)
+
         gen_kwargs = {"model": self.model_id, "messages": messages}
+        if self.attachments:
+            gen_kwargs["_usage_attachments"] = self.attachments
+
         if (
             hasattr(self.llm, "_supports_tools")
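To make the budgeting arithmetic in _build_messages concrete, here is a worked example with made-up numbers; the 128,000-token model limit and the prompt and query sizes are illustrative, not values from this diff:

# Worked example of the message-building token budget.
context_limit = 128_000          # assumed result of get_token_limit(model_id)
system_tokens = 2_000            # assumed rendered system prompt size

safety_buffer = int(context_limit * 0.1)                                  # 12_800
available_after_system = context_limit - system_tokens - safety_buffer   # 113_200

max_query_tokens = int(available_after_system * 0.8)                     # 90_560
query_tokens = 1_500             # assumed; below the cap, so no middle-truncation

available_for_history = max(available_after_system - query_tokens, 0)    # 111_700
print(available_for_history)     # chat history is dropped oldest-first to fit this budget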
application/agents/classic_agent.py
@@ -15,11 +15,7 @@ class ClassicAgent(BaseAgent):
     ) -> Generator[Dict, None, None]:
         """Core generator function for ClassicAgent execution flow"""

-        tools_dict = (
-            self._get_user_tools(self.user)
-            if not self.user_api_key
-            else self._get_tools(self.user_api_key)
-        )
+        tools_dict = self.tool_executor.get_tools()
         self._prepare_tools(tools_dict)

         messages = self._build_messages(self.prompt, query)
application/agents/react_agent.py (deleted file)
@@ -1,238 +0,0 @@
import logging
import os
from typing import Any, Dict, Generator, List

from application.agents.base import BaseAgent
from application.logging import build_stack_data, LogContext

logger = logging.getLogger(__name__)

MAX_ITERATIONS_REASONING = 10

current_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
with open(
    os.path.join(current_dir, "application/prompts", "react_planning_prompt.txt"), "r"
) as f:
    PLANNING_PROMPT_TEMPLATE = f.read()
with open(
    os.path.join(current_dir, "application/prompts", "react_final_prompt.txt"), "r"
) as f:
    FINAL_PROMPT_TEMPLATE = f.read()


class ReActAgent(BaseAgent):
    """
    Research and Action (ReAct) Agent - Advanced reasoning agent with iterative planning.

    Implements a think-act-observe loop for complex problem-solving:
    1. Creates a strategic plan based on the query
    2. Executes tools and gathers observations
    3. Iteratively refines approach until satisfied
    4. Synthesizes final answer from all observations
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.plan: str = ""
        self.observations: List[str] = []

    def _gen_inner(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Execute ReAct reasoning loop with planning, action, and observation cycles"""

        self._reset_state()

        tools_dict = (
            self._get_tools(self.user_api_key)
            if self.user_api_key
            else self._get_user_tools(self.user)
        )
        self._prepare_tools(tools_dict)

        for iteration in range(1, MAX_ITERATIONS_REASONING + 1):
            yield {"thought": f"Reasoning... (iteration {iteration})\n\n"}

            yield from self._planning_phase(query, log_context)

            if not self.plan:
                logger.warning(
                    f"ReActAgent: No plan generated in iteration {iteration}"
                )
                break
            self.observations.append(f"Plan (iteration {iteration}): {self.plan}")

            satisfied = yield from self._execution_phase(query, tools_dict, log_context)

            if satisfied:
                logger.info("ReActAgent: Goal satisfied, stopping reasoning loop")
                break
        yield from self._synthesis_phase(query, log_context)

    def _reset_state(self):
        """Reset agent state for new query"""
        self.plan = ""
        self.observations = []

    def _planning_phase(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Generate strategic plan for query"""
        logger.info("ReActAgent: Creating plan...")

        plan_prompt = self._build_planning_prompt(query)
        messages = [{"role": "user", "content": plan_prompt}]

        plan_stream = self.llm.gen_stream(
            model=self.model_id,
            messages=messages,
            tools=self.tools if self.tools else None,
        )

        if log_context:
            log_context.stacks.append(
                {"component": "planning_llm", "data": build_stack_data(self.llm)}
            )
        plan_parts = []
        for chunk in plan_stream:
            content = self._extract_content(chunk)
            if content:
                plan_parts.append(content)
                yield {"thought": content}
        self.plan = "".join(plan_parts)

    def _execution_phase(
        self, query: str, tools_dict: Dict, log_context: LogContext
    ) -> Generator[bool, None, None]:
        """Execute plan with tool calls and observations"""
        execution_prompt = self._build_execution_prompt(query)
        messages = self._build_messages(execution_prompt, query)

        llm_response = self._llm_gen(messages, log_context)
        initial_content = self._extract_content(llm_response)

        if initial_content:
            self.observations.append(f"Initial response: {initial_content}")
        processed_response = self._llm_handler(
            llm_response, tools_dict, messages, log_context
        )

        for tool_call in self.tool_calls:
            observation = (
                f"Executed: {tool_call.get('tool_name', 'Unknown')} "
                f"with args {tool_call.get('arguments', {})}. "
                f"Result: {str(tool_call.get('result', ''))[:200]}"
            )
            self.observations.append(observation)
        final_content = self._extract_content(processed_response)
        if final_content:
            self.observations.append(f"Response after tools: {final_content}")
        if log_context:
            log_context.stacks.append(
                {
                    "component": "agent_tool_calls",
                    "data": {"tool_calls": self.tool_calls.copy()},
                }
            )
        yield {"sources": self.retrieved_docs}
        yield {"tool_calls": self._get_truncated_tool_calls()}

        return "SATISFIED" in (final_content or "")

    def _synthesis_phase(
        self, query: str, log_context: LogContext
    ) -> Generator[Dict, None, None]:
        """Synthesize final answer from all observations"""
        logger.info("ReActAgent: Generating final answer...")

        final_prompt = self._build_final_answer_prompt(query)
        messages = [{"role": "user", "content": final_prompt}]

        final_stream = self.llm.gen_stream(
            model=self.model_id, messages=messages, tools=None
        )

        if log_context:
            log_context.stacks.append(
                {"component": "final_answer_llm", "data": build_stack_data(self.llm)}
            )
        for chunk in final_stream:
            content = self._extract_content(chunk)
            if content:
                yield {"answer": content}

    def _build_planning_prompt(self, query: str) -> str:
        """Build planning phase prompt"""
        prompt = PLANNING_PROMPT_TEMPLATE.replace("{query}", query)
        prompt = prompt.replace("{prompt}", self.prompt or "")
        prompt = prompt.replace("{summaries}", "")
        prompt = prompt.replace("{observations}", "\n".join(self.observations))
        return prompt

    def _build_execution_prompt(self, query: str) -> str:
        """Build execution phase prompt with plan and observations"""
        observations_str = "\n".join(self.observations)

        if len(observations_str) > 20000:
            observations_str = observations_str[:20000] + "\n...[truncated]"
        return (
            f"{self.prompt or ''}\n\n"
            f"Follow this plan:\n{self.plan}\n\n"
            f"Observations:\n{observations_str}\n\n"
            f"If sufficient data exists to answer '{query}', respond with 'SATISFIED'. "
            f"Otherwise, continue executing the plan."
        )

    def _build_final_answer_prompt(self, query: str) -> str:
        """Build final synthesis prompt"""
        observations_str = "\n".join(self.observations)

        if len(observations_str) > 10000:
            observations_str = observations_str[:10000] + "\n...[truncated]"
            logger.warning("ReActAgent: Observations truncated for final answer")
        return FINAL_PROMPT_TEMPLATE.format(query=query, observations=observations_str)

    def _extract_content(self, response: Any) -> str:
        """Extract text content from various LLM response formats"""
        if not response:
            return ""
        collected = []

        if isinstance(response, str):
            return response
        if hasattr(response, "message") and hasattr(response.message, "content"):
            if response.message.content:
                return response.message.content
        if hasattr(response, "choices") and response.choices:
            if hasattr(response.choices[0], "message"):
                content = response.choices[0].message.content
                if content:
                    return content
        if hasattr(response, "content") and isinstance(response.content, list):
            if response.content and hasattr(response.content[0], "text"):
                return response.content[0].text
        try:
            for chunk in response:
                content_piece = ""

                if hasattr(chunk, "choices") and chunk.choices:
                    if hasattr(chunk.choices[0], "delta"):
                        delta_content = chunk.choices[0].delta.content
                        if delta_content:
                            content_piece = delta_content
                elif hasattr(chunk, "type") and chunk.type == "content_block_delta":
                    if hasattr(chunk, "delta") and hasattr(chunk.delta, "text"):
                        content_piece = chunk.delta.text
                elif isinstance(chunk, str):
                    content_piece = chunk
                if content_piece:
                    collected.append(content_piece)
        except (TypeError, AttributeError):
            logger.debug(
                f"Response not iterable or unexpected format: {type(response)}"
            )
        except Exception as e:
            logger.error(f"Error extracting content: {e}")
        return "".join(collected)
application/agents/research_agent.py (new file, 698 lines; shown up to the planning phase)
@@ -0,0 +1,698 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Dict, Generator, List, Optional
|
||||||
|
|
||||||
|
from application.agents.base import BaseAgent
|
||||||
|
from application.agents.tool_executor import ToolExecutor
|
||||||
|
from application.agents.tools.internal_search import (
|
||||||
|
INTERNAL_TOOL_ID,
|
||||||
|
add_internal_search_tool,
|
||||||
|
)
|
||||||
|
from application.agents.tools.think import THINK_TOOL_ENTRY, THINK_TOOL_ID
|
||||||
|
from application.logging import LogContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Defaults (can be overridden via constructor)
|
||||||
|
DEFAULT_MAX_STEPS = 6
|
||||||
|
DEFAULT_MAX_SUB_ITERATIONS = 5
|
||||||
|
DEFAULT_TIMEOUT_SECONDS = 300 # 5 minutes
|
||||||
|
DEFAULT_TOKEN_BUDGET = 100_000
|
||||||
|
DEFAULT_PARALLEL_WORKERS = 3
|
||||||
|
|
||||||
|
# Adaptive depth caps per complexity level
|
||||||
|
COMPLEXITY_CAPS = {
|
||||||
|
"simple": 2,
|
||||||
|
"moderate": 4,
|
||||||
|
"complex": 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
_PROMPTS_DIR = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||||
|
"prompts",
|
||||||
|
"research",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_prompt(name: str) -> str:
|
||||||
|
with open(os.path.join(_PROMPTS_DIR, name), "r") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
|
||||||
|
CLARIFICATION_PROMPT = _load_prompt("clarification.txt")
|
||||||
|
PLANNING_PROMPT = _load_prompt("planning.txt")
|
||||||
|
STEP_PROMPT = _load_prompt("step.txt")
|
||||||
|
SYNTHESIS_PROMPT = _load_prompt("synthesis.txt")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CitationManager
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class CitationManager:
|
||||||
|
"""Tracks and deduplicates citations across research steps."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.citations: Dict[int, Dict] = {}
|
||||||
|
self._counter = 0
|
||||||
|
|
||||||
|
def add(self, doc: Dict) -> int:
|
||||||
|
"""Register a source, return its citation number. Deduplicates by source."""
|
||||||
|
source = doc.get("source", "")
|
||||||
|
title = doc.get("title", "")
|
||||||
|
for num, existing in self.citations.items():
|
||||||
|
if existing.get("source") == source and existing.get("title") == title:
|
||||||
|
return num
|
||||||
|
self._counter += 1
|
||||||
|
self.citations[self._counter] = doc
|
||||||
|
return self._counter
|
||||||
|
|
||||||
|
def add_docs(self, docs: List[Dict]) -> str:
|
||||||
|
"""Register multiple docs, return formatted citation mapping text."""
|
||||||
|
mapping_lines = []
|
||||||
|
for doc in docs:
|
||||||
|
num = self.add(doc)
|
||||||
|
title = doc.get("title", "Untitled")
|
||||||
|
mapping_lines.append(f"[{num}] {title}")
|
||||||
|
return "\n".join(mapping_lines)
|
||||||
|
|
||||||
|
def format_references(self) -> str:
|
||||||
|
"""Generate [N] -> source mapping for report footer."""
|
||||||
|
if not self.citations:
|
||||||
|
return "No sources found."
|
||||||
|
lines = []
|
||||||
|
for num, doc in sorted(self.citations.items()):
|
||||||
|
title = doc.get("title", "Untitled")
|
||||||
|
source = doc.get("source", "Unknown")
|
||||||
|
filename = doc.get("filename", "")
|
||||||
|
display = filename or title
|
||||||
|
lines.append(f"[{num}] {display} — {source}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def get_all_docs(self) -> List[Dict]:
|
||||||
|
return list(self.citations.values())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ResearchAgent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ResearchAgent(BaseAgent):
|
||||||
|
"""Multi-step research agent with parallel execution and budget controls.
|
||||||
|
|
||||||
|
Orchestrates: Plan -> Research (per step, optionally parallel) -> Synthesize.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
retriever_config: Optional[Dict] = None,
|
||||||
|
max_steps: int = DEFAULT_MAX_STEPS,
|
||||||
|
max_sub_iterations: int = DEFAULT_MAX_SUB_ITERATIONS,
|
||||||
|
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||||
|
token_budget: int = DEFAULT_TOKEN_BUDGET,
|
||||||
|
parallel_workers: int = DEFAULT_PARALLEL_WORKERS,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.retriever_config = retriever_config or {}
|
||||||
|
self.max_steps = max_steps
|
||||||
|
self.max_sub_iterations = max_sub_iterations
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
self.token_budget = token_budget
|
||||||
|
self.parallel_workers = parallel_workers
|
||||||
|
self.citations = CitationManager()
|
||||||
|
self._start_time: float = 0
|
||||||
|
self._tokens_used: int = 0
|
||||||
|
self._last_token_snapshot: int = 0
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Budget & timeout helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _is_timed_out(self) -> bool:
|
||||||
|
return (time.monotonic() - self._start_time) >= self.timeout_seconds
|
||||||
|
|
||||||
|
def _elapsed(self) -> float:
|
||||||
|
return round(time.monotonic() - self._start_time, 1)
|
||||||
|
|
||||||
|
def _track_tokens(self, count: int):
|
||||||
|
self._tokens_used += count
|
||||||
|
|
||||||
|
def _budget_remaining(self) -> int:
|
||||||
|
return max(self.token_budget - self._tokens_used, 0)
|
||||||
|
|
||||||
|
def _is_over_budget(self) -> bool:
|
||||||
|
return self._tokens_used >= self.token_budget
|
||||||
|
|
||||||
|
def _snapshot_llm_tokens(self) -> int:
|
||||||
|
"""Read current token usage from LLM and return delta since last snapshot."""
|
||||||
|
current = self.llm.token_usage.get("prompt_tokens", 0) + self.llm.token_usage.get("generated_tokens", 0)
|
||||||
|
delta = current - self._last_token_snapshot
|
||||||
|
self._last_token_snapshot = current
|
||||||
|
return delta
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Main orchestration
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _gen_inner(
|
||||||
|
self, query: str, log_context: LogContext
|
||||||
|
) -> Generator[Dict, None, None]:
|
||||||
|
self._start_time = time.monotonic()
|
||||||
|
tools_dict = self._setup_tools()
|
||||||
|
|
||||||
|
# Phase 0: Clarification (skip if user is responding to a prior clarification)
|
||||||
|
if not self._is_follow_up():
|
||||||
|
clarification = self._clarification_phase(query)
|
||||||
|
if clarification:
|
||||||
|
yield {"metadata": {"is_clarification": True}}
|
||||||
|
yield {"answer": clarification}
|
||||||
|
yield {"sources": []}
|
||||||
|
yield {"tool_calls": []}
|
||||||
|
log_context.stacks.append(
|
||||||
|
{"component": "agent", "data": {"clarification": True}}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Phase 1: Planning (with adaptive depth)
|
||||||
|
yield {"type": "research_progress", "data": {"status": "planning"}}
|
||||||
|
plan, complexity = self._planning_phase(query)
|
||||||
|
|
||||||
|
if not plan:
|
||||||
|
logger.warning("ResearchAgent: Planning produced no steps, falling back")
|
||||||
|
plan = [{"query": query, "rationale": "Direct investigation"}]
|
||||||
|
complexity = "simple"
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"type": "research_plan",
|
||||||
|
"data": {"steps": plan, "complexity": complexity},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Phase 2: Research each step (yields progress events in real-time)
|
||||||
|
intermediate_reports = []
|
||||||
|
for i, step in enumerate(plan):
|
||||||
|
step_num = i + 1
|
||||||
|
step_query = step.get("query", query)
|
||||||
|
|
||||||
|
if self._is_timed_out():
|
||||||
|
logger.warning(
|
||||||
|
f"ResearchAgent: Timeout at step {step_num}/{len(plan)} "
|
||||||
|
f"({self._elapsed()}s)"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if self._is_over_budget():
|
||||||
|
logger.warning(
|
||||||
|
f"ResearchAgent: Token budget exhausted at step {step_num}/{len(plan)}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"type": "research_progress",
|
||||||
|
"data": {
|
||||||
|
"step": step_num,
|
||||||
|
"total": len(plan),
|
||||||
|
"query": step_query,
|
||||||
|
"status": "researching",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
report = self._research_step(step_query, tools_dict)
|
||||||
|
intermediate_reports.append({"step": step, "content": report})
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"type": "research_progress",
|
||||||
|
"data": {
|
||||||
|
"step": step_num,
|
||||||
|
"total": len(plan),
|
||||||
|
"query": step_query,
|
||||||
|
"status": "complete",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Phase 3: Synthesis (streaming)
|
||||||
|
if self._is_timed_out():
|
||||||
|
logger.warning(
|
||||||
|
f"ResearchAgent: Timeout ({self._elapsed()}s) before synthesis, "
|
||||||
|
f"synthesizing with {len(intermediate_reports)} reports"
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"type": "research_progress",
|
||||||
|
"data": {
|
||||||
|
"status": "synthesizing",
|
||||||
|
"elapsed_seconds": self._elapsed(),
|
||||||
|
"tokens_used": self._tokens_used,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
yield from self._synthesis_phase(
|
||||||
|
query, plan, intermediate_reports, tools_dict, log_context
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sources and tool calls
|
||||||
|
self.retrieved_docs = self.citations.get_all_docs()
|
||||||
|
yield {"sources": self.retrieved_docs}
|
||||||
|
yield {"tool_calls": self._get_truncated_tool_calls()}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"ResearchAgent completed: {len(intermediate_reports)}/{len(plan)} steps, "
|
||||||
|
f"{self._elapsed()}s, ~{self._tokens_used} tokens"
|
||||||
|
)
|
||||||
|
log_context.stacks.append(
|
||||||
|
{"component": "agent", "data": {"tool_calls": self.tool_calls.copy()}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Tool setup
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _setup_tools(self) -> Dict:
|
||||||
|
"""Build tools_dict with user tools + internal search + think."""
|
||||||
|
tools_dict = self.tool_executor.get_tools()
|
||||||
|
|
||||||
|
add_internal_search_tool(tools_dict, self.retriever_config)
|
||||||
|
|
||||||
|
think_entry = dict(THINK_TOOL_ENTRY)
|
||||||
|
think_entry["config"] = {}
|
||||||
|
tools_dict[THINK_TOOL_ID] = think_entry
|
||||||
|
|
||||||
|
self._prepare_tools(tools_dict)
|
||||||
|
return tools_dict
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 0: Clarification
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _is_follow_up(self) -> bool:
|
||||||
|
"""Check if the user is responding to a prior clarification.
|
||||||
|
|
||||||
|
Uses the metadata flag stored in the conversation DB — no string matching.
|
||||||
|
Only skip clarification when the last query was explicitly flagged
|
||||||
|
as a clarification by this agent.
|
||||||
|
"""
|
||||||
|
if not self.chat_history:
|
||||||
|
return False
|
||||||
|
last = self.chat_history[-1]
|
||||||
|
meta = last.get("metadata", {})
|
||||||
|
return bool(meta.get("is_clarification"))
|
||||||
|
|
||||||
|
    def _clarification_phase(self, question: str) -> Optional[str]:
        """Ask the LLM whether the question needs clarification.

        Returns formatted clarification text if needed, or None to proceed.
        Uses response_format to force valid JSON output.
        """
        messages = [
            {"role": "system", "content": CLARIFICATION_PROMPT},
            {"role": "user", "content": question},
        ]

        try:
            response = self.llm.gen(
                model=self.model_id,
                messages=messages,
                tools=None,
                response_format={"type": "json_object"},
            )
            text = self._extract_text(response)
            self._track_tokens(self._snapshot_llm_tokens())
            logger.info(f"ResearchAgent clarification response: {text[:300]}")

            data = self._parse_clarification_json(text)
            if not data or not data.get("needs_clarification"):
                return None

            questions = data.get("questions", [])
            if not questions:
                return None

            # Format as a friendly response
            lines = [
                "Before I begin researching, I'd like to clarify a few things:\n"
            ]
            for i, q in enumerate(questions[:3], 1):
                lines.append(f"{i}. {q}")
            lines.append(
                "\nPlease provide these details and I'll start the research."
            )
            return "\n".join(lines)

        except Exception as e:
            logger.error(f"Clarification phase failed: {e}", exc_info=True)
            return None  # proceed with research on failure

    def _parse_clarification_json(self, text: str) -> Optional[Dict]:
        """Parse clarification JSON from LLM response."""
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass

        # Try extracting from code fences
        for marker in ["```json", "```"]:
            if marker in text:
                start = text.index(marker) + len(marker)
                end = text.index("```", start) if "```" in text[start:] else len(text)
                try:
                    return json.loads(text[start:end].strip())
                except (json.JSONDecodeError, ValueError):
                    pass

        # Try finding JSON object
        for i, ch in enumerate(text):
            if ch == "{":
                for j in range(len(text) - 1, i, -1):
                    if text[j] == "}":
                        try:
                            return json.loads(text[i : j + 1])
                        except json.JSONDecodeError:
                            continue
                break

        return None

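    # Illustrative sketch (not part of the original change): the fallback chain
    # above is meant to survive models that wrap their JSON in prose or markdown
    # fences, e.g. (made-up response text):
    #
    #     raw = 'Sure! ```json\n{"needs_clarification": true, "questions": ["Which OS?"]}\n```'
    #     self._parse_clarification_json(raw)
    #     # -> {"needs_clarification": True, "questions": ["Which OS?"]}
    #
    # Direct json.loads() fails on the surrounding prose, the code-fence branch
    # extracts the fenced payload, and the brace scan is the last resort.
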
    # ------------------------------------------------------------------
    # Phase 1: Planning (with adaptive depth)
    # ------------------------------------------------------------------

    def _planning_phase(self, question: str) -> tuple[List[Dict], str]:
        """Decompose the question into research steps via LLM.

        Returns (steps, complexity) where complexity is simple/moderate/complex.
        """
        messages = [
            {"role": "system", "content": PLANNING_PROMPT},
            {"role": "user", "content": question},
        ]

        try:
            response = self.llm.gen(
                model=self.model_id,
                messages=messages,
                tools=None,
                response_format={"type": "json_object"},
            )
            text = self._extract_text(response)
            self._track_tokens(self._snapshot_llm_tokens())
            logger.info(f"ResearchAgent planning LLM response: {text[:500]}")

            plan_data = self._parse_plan_json(text)
            if isinstance(plan_data, dict):
                complexity = plan_data.get("complexity", "moderate")
                steps = plan_data.get("steps", [])
            else:
                complexity = "moderate"
                steps = plan_data

            # Adaptive depth: cap steps based on assessed complexity
            cap = COMPLEXITY_CAPS.get(complexity, self.max_steps)
            cap = min(cap, self.max_steps)
            steps = steps[:cap]

            logger.info(
                f"ResearchAgent plan: complexity={complexity}, "
                f"steps={len(steps)} (cap={cap})"
            )
            return steps, complexity

        except Exception as e:
            logger.error(f"Planning phase failed: {e}", exc_info=True)
            return (
                [{"query": question, "rationale": "Direct investigation (planning failed)"}],
                "simple",
            )

    def _parse_plan_json(self, text: str):
        """Extract JSON plan from LLM response. Returns dict or list."""
        # Try direct parse
        try:
            data = json.loads(text)
            if isinstance(data, dict) and "steps" in data:
                return data
            if isinstance(data, list):
                return data
        except json.JSONDecodeError:
            pass

        # Try extracting from markdown code fences
        for marker in ["```json", "```"]:
            if marker in text:
                start = text.index(marker) + len(marker)
                end = text.index("```", start) if "```" in text[start:] else len(text)
                try:
                    data = json.loads(text[start:end].strip())
                    if isinstance(data, dict) and "steps" in data:
                        return data
                    if isinstance(data, list):
                        return data
                except (json.JSONDecodeError, ValueError):
                    pass

        # Try finding JSON object in text
        for i, ch in enumerate(text):
            if ch == "{":
                for j in range(len(text) - 1, i, -1):
                    if text[j] == "}":
                        try:
                            data = json.loads(text[i : j + 1])
                            if isinstance(data, dict) and "steps" in data:
                                return data
                        except json.JSONDecodeError:
                            continue
                break

        logger.warning(f"Could not parse plan JSON from: {text[:200]}")
        return []

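    # Illustrative sketch (not part of the original change): a well-formed
    # planning response is expected to look roughly like
    #
    #     {"complexity": "moderate",
    #      "steps": [{"query": "...", "rationale": "..."}, ...]}
    #
    # and _planning_phase() then caps the list; with a hypothetical
    # COMPLEXITY_CAPS = {"simple": 1, "moderate": 3, "complex": 5}, a "simple"
    # plan would keep at most min(1, self.max_steps) steps.
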
# ------------------------------------------------------------------
|
||||||
|
# Phase 2: Research step (core loop)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _research_step(self, step_query: str, tools_dict: Dict) -> str:
|
||||||
|
"""Run a focused research loop for one sub-question (sequential path)."""
|
||||||
|
report = self._research_step_with_executor(
|
||||||
|
step_query, tools_dict, self.tool_executor
|
||||||
|
)
|
||||||
|
self._collect_step_sources()
|
||||||
|
return report
|
||||||
|
|
||||||
|
def _research_step_with_executor(
|
||||||
|
self, step_query: str, tools_dict: Dict, executor: ToolExecutor
|
||||||
|
) -> str:
|
||||||
|
"""Core research loop. Works with any ToolExecutor instance."""
|
||||||
|
system_prompt = STEP_PROMPT.replace("{step_query}", step_query)
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": step_query},
|
||||||
|
]
|
||||||
|
|
||||||
|
last_search_empty = False
|
||||||
|
|
||||||
|
for iteration in range(self.max_sub_iterations):
|
||||||
|
# Check timeout and budget
|
||||||
|
if self._is_timed_out():
|
||||||
|
logger.info(
|
||||||
|
f"Research step '{step_query[:50]}' timed out at iteration {iteration}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if self._is_over_budget():
|
||||||
|
logger.info(
|
||||||
|
f"Research step '{step_query[:50]}' hit token budget at iteration {iteration}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.llm.gen(
|
||||||
|
model=self.model_id,
|
||||||
|
messages=messages,
|
||||||
|
tools=self.tools if self.tools else None,
|
||||||
|
)
|
||||||
|
self._track_tokens(self._snapshot_llm_tokens())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Research step LLM call failed (iteration {iteration}): {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
parsed = self.llm_handler.parse_response(response)
|
||||||
|
|
||||||
|
if not parsed.requires_tool_call:
|
||||||
|
return parsed.content or "No findings for this step."
|
||||||
|
|
||||||
|
# Execute tool calls
|
||||||
|
messages, last_search_empty = self._execute_step_tools_with_refinement(
|
||||||
|
parsed.tool_calls, tools_dict, messages, executor, last_search_empty
|
||||||
|
)
|
||||||
|
|
||||||
|
# Max iterations / timeout / budget — ask for summary
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Please summarize your findings so far based on the information gathered.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = self.llm.gen(
|
||||||
|
model=self.model_id, messages=messages, tools=None
|
||||||
|
)
|
||||||
|
self._track_tokens(self._snapshot_llm_tokens())
|
||||||
|
text = self._extract_text(response)
|
||||||
|
return text or "Research step completed."
|
||||||
|
except Exception:
|
||||||
|
return "Research step completed."
|
||||||
|
|
||||||
|
def _execute_step_tools_with_refinement(
|
||||||
|
self,
|
||||||
|
tool_calls,
|
||||||
|
tools_dict: Dict,
|
||||||
|
messages: List[Dict],
|
||||||
|
executor: ToolExecutor,
|
||||||
|
last_search_empty: bool,
|
||||||
|
) -> tuple[List[Dict], bool]:
|
||||||
|
"""Execute tool calls with query refinement on empty results.
|
||||||
|
|
||||||
|
Returns (updated_messages, was_last_search_empty).
|
||||||
|
"""
|
||||||
|
search_returned_empty = False
|
||||||
|
|
||||||
|
for call in tool_calls:
|
||||||
|
gen = executor.execute(
|
||||||
|
tools_dict, call, self.llm.__class__.__name__
|
||||||
|
)
|
||||||
|
result = None
|
||||||
|
call_id = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event = next(gen)
|
||||||
|
# Log tool_call status events instead of discarding them
|
||||||
|
if isinstance(event, dict) and event.get("type") == "tool_call":
|
||||||
|
logger.debug(
|
||||||
|
"Tool %s status: %s",
|
||||||
|
event.get("data", {}).get("action_name", ""),
|
||||||
|
event.get("data", {}).get("status", ""),
|
||||||
|
)
|
||||||
|
except StopIteration as e:
|
||||||
|
result, call_id = e.value
|
||||||
|
break
|
||||||
|
|
||||||
|
# Detect empty search results for refinement
|
||||||
|
is_search = "search" in (call.name or "").lower()
|
||||||
|
result_str = str(result) if result else ""
|
||||||
|
if is_search and "No documents found" in result_str:
|
||||||
|
search_returned_empty = True
|
||||||
|
if last_search_empty:
|
||||||
|
# Two consecutive empty searches — inject refinement hint
|
||||||
|
result_str += (
|
||||||
|
"\n\nHint: Previous search also returned no results. "
|
||||||
|
"Try a very different query with different keywords, "
|
||||||
|
"or broaden your search terms."
|
||||||
|
)
|
||||||
|
result = result_str
|
||||||
|
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
args_str = (
|
||||||
|
_json.dumps(call.arguments)
|
||||||
|
if isinstance(call.arguments, dict)
|
||||||
|
else call.arguments
|
||||||
|
)
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": call.name, "arguments": args_str},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
tool_message = self.llm_handler.create_tool_message(call, result)
|
||||||
|
messages.append(tool_message)
|
||||||
|
|
||||||
|
return messages, search_returned_empty
|
||||||
|
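    # Illustrative sketch (not part of the original change): for each executed
    # call the loop above appends an assistant message followed by a tool
    # message, roughly
    #
    #     {"role": "assistant", "content": None,
    #      "tool_calls": [{"id": call_id, "type": "function",
    #                      "function": {"name": call.name, "arguments": '{"query": "..."}'}}]}
    #     {"role": "tool", "tool_call_id": call_id, "content": str(result)}
    #
    # The second dict is an assumption about what create_tool_message() returns;
    # only the assistant message shape is taken verbatim from the code above.
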
|
||||||
|
def _collect_step_sources(self):
|
||||||
|
"""Collect sources from InternalSearchTool and register with CitationManager."""
|
||||||
|
cache_key = f"internal_search:{INTERNAL_TOOL_ID}:{self.user or ''}"
|
||||||
|
tool = self.tool_executor._loaded_tools.get(cache_key)
|
||||||
|
if tool and hasattr(tool, "retrieved_docs"):
|
||||||
|
for doc in tool.retrieved_docs:
|
||||||
|
self.citations.add(doc)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 3: Synthesis
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _synthesis_phase(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
plan: List[Dict],
|
||||||
|
intermediate_reports: List[Dict],
|
||||||
|
tools_dict: Dict,
|
||||||
|
log_context: LogContext,
|
||||||
|
) -> Generator[Dict, None, None]:
|
||||||
|
"""Compile all findings into a final cited report (streaming)."""
|
||||||
|
plan_lines = []
|
||||||
|
for i, step in enumerate(plan, 1):
|
||||||
|
plan_lines.append(
|
||||||
|
f"{i}. {step.get('query', 'Unknown')} — {step.get('rationale', '')}"
|
||||||
|
)
|
||||||
|
plan_summary = "\n".join(plan_lines)
|
||||||
|
|
||||||
|
findings_parts = []
|
||||||
|
for i, report in enumerate(intermediate_reports, 1):
|
||||||
|
step_query = report["step"].get("query", "Unknown")
|
||||||
|
content = report["content"]
|
||||||
|
findings_parts.append(
|
||||||
|
f"--- Step {i}: {step_query} ---\n{content}"
|
||||||
|
)
|
||||||
|
findings = "\n\n".join(findings_parts)
|
||||||
|
|
||||||
|
references = self.citations.format_references()
|
||||||
|
|
||||||
|
synthesis_prompt = SYNTHESIS_PROMPT.replace("{question}", question)
|
||||||
|
synthesis_prompt = synthesis_prompt.replace("{plan_summary}", plan_summary)
|
||||||
|
synthesis_prompt = synthesis_prompt.replace("{findings}", findings)
|
||||||
|
synthesis_prompt = synthesis_prompt.replace("{references}", references)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": synthesis_prompt},
|
||||||
|
{"role": "user", "content": f"Please write the research report for: {question}"},
|
||||||
|
]
|
||||||
|
|
||||||
|
llm_response = self.llm.gen_stream(
|
||||||
|
model=self.model_id, messages=messages, tools=None
|
||||||
|
)
|
||||||
|
|
||||||
|
if log_context:
|
||||||
|
from application.logging import build_stack_data
|
||||||
|
|
||||||
|
log_context.stacks.append(
|
||||||
|
{"component": "synthesis_llm", "data": build_stack_data(self.llm)}
|
||||||
|
)
|
||||||
|
|
||||||
|
yield from self._handle_response(
|
||||||
|
llm_response, tools_dict, messages, log_context
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _extract_text(self, response) -> str:
|
||||||
|
"""Extract text content from a non-streaming LLM response."""
|
||||||
|
if isinstance(response, str):
|
||||||
|
return response
|
||||||
|
if hasattr(response, "message") and hasattr(response.message, "content"):
|
||||||
|
return response.message.content or ""
|
||||||
|
if hasattr(response, "choices") and response.choices:
|
||||||
|
choice = response.choices[0]
|
||||||
|
if hasattr(choice, "message") and hasattr(choice.message, "content"):
|
||||||
|
return choice.message.content or ""
|
||||||
|
if hasattr(response, "content") and isinstance(response.content, list):
|
||||||
|
if response.content and hasattr(response.content[0], "text"):
|
||||||
|
return response.content[0].text or ""
|
||||||
|
return str(response) if response else ""
|
||||||
application/agents/tool_executor.py (new file)
@@ -0,0 +1,494 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections import Counter
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from application.agents.tools.tool_action_parser import ToolActionParser
|
||||||
|
from application.agents.tools.tool_manager import ToolManager
|
||||||
|
from application.security.encryption import decrypt_credentials
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.repositories.user_tools import UserToolsRepository
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutor:
|
||||||
|
"""Handles tool discovery, preparation, and execution.
|
||||||
|
|
||||||
|
Extracted from BaseAgent to separate concerns and enable tool caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_api_key: Optional[str] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
decoded_token: Optional[Dict] = None,
|
||||||
|
):
|
||||||
|
self.user_api_key = user_api_key
|
||||||
|
self.user = user
|
||||||
|
self.decoded_token = decoded_token
|
||||||
|
self.tool_calls: List[Dict] = []
|
||||||
|
self._loaded_tools: Dict[str, object] = {}
|
||||||
|
self.conversation_id: Optional[str] = None
|
||||||
|
self.client_tools: Optional[List[Dict]] = None
|
||||||
|
self._name_to_tool: Dict[str, Tuple[str, str]] = {}
|
||||||
|
self._tool_to_name: Dict[Tuple[str, str], str] = {}
|
||||||
|
|
||||||
|
def get_tools(self) -> Dict[str, Dict]:
|
||||||
|
"""Load tool configs from DB based on user context.
|
||||||
|
|
||||||
|
If *client_tools* have been set on this executor, they are
|
||||||
|
automatically merged into the returned dict.
|
||||||
|
"""
|
||||||
|
if self.user_api_key:
|
||||||
|
tools = self._get_tools_by_api_key(self.user_api_key)
|
||||||
|
else:
|
||||||
|
tools = self._get_user_tools(self.user or "local")
|
||||||
|
if self.client_tools:
|
||||||
|
self.merge_client_tools(tools, self.client_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
def _get_tools_by_api_key(self, api_key: str) -> Dict[str, Dict]:
|
||||||
|
# Per-operation session: the answer pipeline spans a long-lived
|
||||||
|
# generator; wrapping it in a single connection would pin a PG
|
||||||
|
# conn for the whole stream. Open, fetch, close.
|
||||||
|
with db_readonly() as conn:
|
||||||
|
agent_data = AgentsRepository(conn).find_by_key(api_key)
|
||||||
|
tool_ids = agent_data.get("tools", []) if agent_data else []
|
||||||
|
if not tool_ids:
|
||||||
|
return {}
|
||||||
|
tools_repo = UserToolsRepository(conn)
|
||||||
|
tools: List[Dict] = []
|
||||||
|
owner = (agent_data.get("user_id") or agent_data.get("user")) if agent_data else None
|
||||||
|
for tid in tool_ids:
|
||||||
|
row = None
|
||||||
|
if owner:
|
||||||
|
row = tools_repo.get_any(str(tid), owner)
|
||||||
|
if row is not None:
|
||||||
|
tools.append(row)
|
||||||
|
return {str(tool["id"]): tool for tool in tools} if tools else {}
|
||||||
|
|
||||||
|
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict]:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
user_tools = UserToolsRepository(conn).list_active_for_user(user)
|
||||||
|
return {str(i): tool for i, tool in enumerate(user_tools)}
|
||||||
|
|
||||||
|
def merge_client_tools(
|
||||||
|
self, tools_dict: Dict, client_tools: List[Dict]
|
||||||
|
) -> Dict:
|
||||||
|
"""Merge client-provided tool definitions into tools_dict.
|
||||||
|
|
||||||
|
Client tools use the standard function-calling format::
|
||||||
|
|
||||||
|
[{"type": "function", "function": {"name": "get_weather",
|
||||||
|
"description": "...", "parameters": {...}}}]
|
||||||
|
|
||||||
|
They are stored in *tools_dict* with ``client_side: True`` so that
|
||||||
|
:meth:`check_pause` returns a pause signal instead of trying to
|
||||||
|
execute them server-side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tools_dict: The mutable server tools dict (will be modified in place).
|
||||||
|
client_tools: List of tool definitions in function-calling format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated *tools_dict* (same reference, for convenience).
|
||||||
|
"""
|
||||||
|
for i, ct in enumerate(client_tools):
|
||||||
|
func = ct.get("function", ct) # tolerate bare {"name":..} too
|
||||||
|
name = func.get("name", f"clienttool{i}")
|
||||||
|
tool_id = f"ct{i}"
|
||||||
|
|
||||||
|
tools_dict[tool_id] = {
|
||||||
|
"name": name,
|
||||||
|
"client_side": True,
|
||||||
|
"actions": [
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"description": func.get("description", ""),
|
||||||
|
"active": True,
|
||||||
|
"parameters": func.get("parameters", {}),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return tools_dict
|
||||||
|
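    # Illustrative sketch (not part of the original change): given a made-up
    # client tool definition
    #
    #     client_tools = [{"type": "function",
    #                      "function": {"name": "get_weather",
    #                                   "description": "Current weather",
    #                                   "parameters": {"type": "object", "properties": {}}}}]
    #
    # merge_client_tools() adds an entry keyed "ct0" with name "get_weather",
    # client_side=True and one active action, so check_pause() later returns a
    # requires_client_execution pause instead of running it server-side.
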
|
||||||
|
def prepare_tools_for_llm(self, tools_dict: Dict) -> List[Dict]:
|
||||||
|
"""Convert tool configs to LLM function schemas.
|
||||||
|
|
||||||
|
Action names are kept clean for the LLM:
|
||||||
|
- Unique action names appear as-is (e.g. ``get_weather``).
|
||||||
|
- Duplicate action names get numbered suffixes (e.g. ``search_1``,
|
||||||
|
``search_2``).
|
||||||
|
|
||||||
|
A reverse mapping is stored in ``_name_to_tool`` so that tool calls
|
||||||
|
can be routed back to the correct ``(tool_id, action_name)`` without
|
||||||
|
brittle string splitting.
|
||||||
|
"""
|
||||||
|
# Pass 1: collect entries and count action name occurrences
|
||||||
|
entries: List[Tuple[str, str, Dict, bool]] = [] # (tool_id, action_name, action, is_client)
|
||||||
|
name_counts: Counter = Counter()
|
||||||
|
|
||||||
|
for tool_id, tool in tools_dict.items():
|
||||||
|
is_api = tool["name"] == "api_tool"
|
||||||
|
is_client = tool.get("client_side", False)
|
||||||
|
|
||||||
|
if is_api and "actions" not in tool.get("config", {}):
|
||||||
|
continue
|
||||||
|
if not is_api and "actions" not in tool:
|
||||||
|
continue
|
||||||
|
|
||||||
|
actions = (
|
||||||
|
tool["config"]["actions"].values()
|
||||||
|
if is_api
|
||||||
|
else tool["actions"]
|
||||||
|
)
|
||||||
|
|
||||||
|
for action in actions:
|
||||||
|
if not action.get("active", True):
|
||||||
|
continue
|
||||||
|
entries.append((tool_id, action["name"], action, is_client))
|
||||||
|
name_counts[action["name"]] += 1
|
||||||
|
|
||||||
|
# Pass 2: assign LLM-visible names and build mappings
|
||||||
|
self._name_to_tool = {}
|
||||||
|
self._tool_to_name = {}
|
||||||
|
collision_counters: Dict[str, int] = {}
|
||||||
|
all_llm_names: set = set()
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for tool_id, action_name, action, is_client in entries:
|
||||||
|
if name_counts[action_name] == 1:
|
||||||
|
llm_name = action_name
|
||||||
|
else:
|
||||||
|
counter = collision_counters.get(action_name, 1)
|
||||||
|
candidate = f"{action_name}_{counter}"
|
||||||
|
# Skip if candidate collides with a unique action name
|
||||||
|
while candidate in all_llm_names or (
|
||||||
|
candidate in name_counts and name_counts[candidate] == 1
|
||||||
|
):
|
||||||
|
counter += 1
|
||||||
|
candidate = f"{action_name}_{counter}"
|
||||||
|
collision_counters[action_name] = counter + 1
|
||||||
|
llm_name = candidate
|
||||||
|
|
||||||
|
all_llm_names.add(llm_name)
|
||||||
|
self._name_to_tool[llm_name] = (tool_id, action_name)
|
||||||
|
self._tool_to_name[(tool_id, action_name)] = llm_name
|
||||||
|
|
||||||
|
if is_client:
|
||||||
|
params = action.get("parameters", {})
|
||||||
|
else:
|
||||||
|
params = self._build_tool_parameters(action)
|
||||||
|
|
||||||
|
result.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": llm_name,
|
||||||
|
"description": action.get("description", ""),
|
||||||
|
"parameters": params,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
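    # Illustrative sketch (not part of the original change): if two different
    # tools both expose an action named "search", the LLM sees "search_1" and
    # "search_2" while _name_to_tool keeps the reverse mapping, e.g.
    #
    #     {"search_1": ("17", "search"), "search_2": ("42", "search")}
    #
    # so a call to "search_2" is routed back to (tool_id="42", action_name="search")
    # without string splitting. The tool ids "17" and "42" are invented here.
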
|
||||||
|
def _build_tool_parameters(self, action: Dict) -> Dict:
|
||||||
|
params = {"type": "object", "properties": {}, "required": []}
|
||||||
|
for param_type in ["query_params", "headers", "body", "parameters"]:
|
||||||
|
if param_type in action and action[param_type].get("properties"):
|
||||||
|
for k, v in action[param_type]["properties"].items():
|
||||||
|
if v.get("filled_by_llm", True):
|
||||||
|
params["properties"][k] = {
|
||||||
|
key: value
|
||||||
|
for key, value in v.items()
|
||||||
|
if key not in ("filled_by_llm", "value", "required")
|
||||||
|
}
|
||||||
|
if v.get("required", False):
|
||||||
|
params["required"].append(k)
|
||||||
|
return params
|
||||||
|
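    # Illustrative sketch (not part of the original change): an action configured
    # with (made-up field)
    #
    #     {"parameters": {"properties": {"city": {"type": "string", "required": True,
    #                                             "filled_by_llm": True, "value": ""}}}}
    #
    # is exposed to the LLM as
    #
    #     {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
    #
    # because filled_by_llm, value and required are internal bookkeeping keys.
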
|
||||||
|
def check_pause(
|
||||||
|
self, tools_dict: Dict, call, llm_class_name: str
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
"""Check if a tool call requires pausing for approval or client execution.
|
||||||
|
|
||||||
|
Returns a dict describing the pending action if pause is needed, None otherwise.
|
||||||
|
"""
|
||||||
|
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||||
|
tool_id, action_name, call_args = parser.parse_args(call)
|
||||||
|
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||||
|
llm_name = getattr(call, "name", "")
|
||||||
|
|
||||||
|
if tool_id is None or action_name is None or tool_id not in tools_dict:
|
||||||
|
return None # Will be handled as error by execute()
|
||||||
|
|
||||||
|
tool_data = tools_dict[tool_id]
|
||||||
|
|
||||||
|
# Client-side tools
|
||||||
|
if tool_data.get("client_side"):
|
||||||
|
return {
|
||||||
|
"call_id": call_id,
|
||||||
|
"name": llm_name,
|
||||||
|
"tool_name": tool_data.get("name", "unknown"),
|
||||||
|
"tool_id": tool_id,
|
||||||
|
"action_name": action_name,
|
||||||
|
"llm_name": llm_name,
|
||||||
|
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||||
|
"pause_type": "requires_client_execution",
|
||||||
|
"thought_signature": getattr(call, "thought_signature", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Approval required
|
||||||
|
if tool_data["name"] == "api_tool":
|
||||||
|
action_data = tool_data.get("config", {}).get("actions", {}).get(
|
||||||
|
action_name, {}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
action_data = next(
|
||||||
|
(a for a in tool_data.get("actions", []) if a["name"] == action_name),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
if action_data.get("require_approval"):
|
||||||
|
return {
|
||||||
|
"call_id": call_id,
|
||||||
|
"name": llm_name,
|
||||||
|
"tool_name": tool_data.get("name", "unknown"),
|
||||||
|
"tool_id": tool_id,
|
||||||
|
"action_name": action_name,
|
||||||
|
"llm_name": llm_name,
|
||||||
|
"arguments": call_args if isinstance(call_args, dict) else {},
|
||||||
|
"pause_type": "awaiting_approval",
|
||||||
|
"thought_signature": getattr(call, "thought_signature", None),
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
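    # Illustrative sketch (not part of the original change): for the client-side
    # "get_weather" example above, check_pause() would return roughly
    #
    #     {"call_id": "<uuid>", "name": "get_weather", "tool_name": "get_weather",
    #      "tool_id": "ct0", "action_name": "get_weather", "llm_name": "get_weather",
    #      "arguments": {"city": "Berlin"}, "pause_type": "requires_client_execution",
    #      "thought_signature": None}
    #
    # which the caller can surface to the client; the arguments are invented here.
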
|
||||||
|
def execute(self, tools_dict: Dict, call, llm_class_name: str):
|
||||||
|
"""Execute a tool call. Yields status events, returns (result, call_id)."""
|
||||||
|
parser = ToolActionParser(llm_class_name, name_mapping=self._name_to_tool)
|
||||||
|
tool_id, action_name, call_args = parser.parse_args(call)
|
||||||
|
llm_name = getattr(call, "name", "unknown")
|
||||||
|
|
||||||
|
call_id = getattr(call, "id", None) or str(uuid.uuid4())
|
||||||
|
|
||||||
|
if tool_id is None or action_name is None:
|
||||||
|
error_message = f"Error: Failed to parse LLM tool call. Tool name: {llm_name}"
|
||||||
|
logger.error(error_message)
|
||||||
|
|
||||||
|
tool_call_data = {
|
||||||
|
"tool_name": "unknown",
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": llm_name,
|
||||||
|
"arguments": call_args or {},
|
||||||
|
"result": f"Failed to parse tool call. Invalid tool name format: {llm_name}",
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
|
self.tool_calls.append(tool_call_data)
|
||||||
|
return "Failed to parse tool call.", call_id
|
||||||
|
|
||||||
|
if tool_id not in tools_dict:
|
||||||
|
error_message = f"Error: Tool ID '{tool_id}' extracted from LLM call not found in available tools_dict. Available IDs: {list(tools_dict.keys())}"
|
||||||
|
logger.error(error_message)
|
||||||
|
|
||||||
|
tool_call_data = {
|
||||||
|
"tool_name": "unknown",
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": llm_name,
|
||||||
|
"arguments": call_args,
|
||||||
|
"result": f"Tool with ID {tool_id} not found. Available tools: {list(tools_dict.keys())}",
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
|
self.tool_calls.append(tool_call_data)
|
||||||
|
return f"Tool with ID {tool_id} not found.", call_id
|
||||||
|
|
||||||
|
tool_call_data = {
|
||||||
|
"tool_name": tools_dict[tool_id]["name"],
|
||||||
|
"call_id": call_id,
|
||||||
|
"action_name": llm_name,
|
||||||
|
"arguments": call_args,
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "pending"}}
|
||||||
|
|
||||||
|
tool_data = tools_dict[tool_id]
|
||||||
|
action_data = (
|
||||||
|
tool_data["config"]["actions"][action_name]
|
||||||
|
if tool_data["name"] == "api_tool"
|
||||||
|
else next(
|
||||||
|
action
|
||||||
|
for action in tool_data["actions"]
|
||||||
|
if action["name"] == action_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
query_params, headers, body, parameters = {}, {}, {}, {}
|
||||||
|
param_types = {
|
||||||
|
"query_params": query_params,
|
||||||
|
"headers": headers,
|
||||||
|
"body": body,
|
||||||
|
"parameters": parameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
for param_type, target_dict in param_types.items():
|
||||||
|
if param_type in action_data and action_data[param_type].get("properties"):
|
||||||
|
for param, details in action_data[param_type]["properties"].items():
|
||||||
|
if (
|
||||||
|
param not in call_args
|
||||||
|
and "value" in details
|
||||||
|
and details["value"]
|
||||||
|
):
|
||||||
|
target_dict[param] = details["value"]
|
||||||
|
for param, value in call_args.items():
|
||||||
|
for param_type, target_dict in param_types.items():
|
||||||
|
if param_type in action_data and param in action_data[param_type].get(
|
||||||
|
"properties", {}
|
||||||
|
):
|
||||||
|
target_dict[param] = value
|
||||||
|
|
||||||
|
# Load tool (with caching)
|
||||||
|
tool = self._get_or_load_tool(
|
||||||
|
tool_data, tool_id, action_name,
|
||||||
|
headers=headers, query_params=query_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool is None:
|
||||||
|
error_message = (
|
||||||
|
f"Failed to load tool '{tool_data.get('name')}' (tool_id key={tool_id}): "
|
||||||
|
"missing 'id' on tool row."
|
||||||
|
)
|
||||||
|
logger.error(error_message)
|
||||||
|
tool_call_data["result"] = error_message
|
||||||
|
yield {"type": "tool_call", "data": {**tool_call_data, "status": "error"}}
|
||||||
|
self.tool_calls.append(tool_call_data)
|
||||||
|
return error_message, call_id
|
||||||
|
|
||||||
|
resolved_arguments = (
|
||||||
|
{"query_params": query_params, "headers": headers, "body": body}
|
||||||
|
if tool_data["name"] == "api_tool"
|
||||||
|
else parameters
|
||||||
|
)
|
||||||
|
if tool_data["name"] == "api_tool":
|
||||||
|
logger.debug(
|
||||||
|
f"Executing api: {action_name} with query_params: {query_params}, headers: {headers}, body: {body}"
|
||||||
|
)
|
||||||
|
result = tool.execute_action(action_name, **body)
|
||||||
|
else:
|
||||||
|
logger.debug(f"Executing tool: {action_name} with args: {call_args}")
|
||||||
|
result = tool.execute_action(action_name, **parameters)
|
||||||
|
|
||||||
|
get_artifact_id = (
|
||||||
|
getattr(tool, "get_artifact_id", None)
|
||||||
|
if tool_data["name"] != "api_tool"
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
artifact_id = None
|
||||||
|
if callable(get_artifact_id):
|
||||||
|
try:
|
||||||
|
artifact_id = get_artifact_id(action_name, **parameters)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to extract artifact_id from tool %s for action %s",
|
||||||
|
tool_data["name"],
|
||||||
|
action_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
artifact_id = str(artifact_id).strip() if artifact_id is not None else ""
|
||||||
|
if artifact_id:
|
||||||
|
tool_call_data["artifact_id"] = artifact_id
|
||||||
|
result_full = str(result)
|
||||||
|
tool_call_data["resolved_arguments"] = resolved_arguments
|
||||||
|
tool_call_data["result_full"] = result_full
|
||||||
|
tool_call_data["result"] = (
|
||||||
|
f"{result_full[:50]}..." if len(result_full) > 50 else result_full
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_tool_call_data = {
|
||||||
|
key: value
|
||||||
|
for key, value in tool_call_data.items()
|
||||||
|
if key not in {"result_full", "resolved_arguments"}
|
||||||
|
}
|
||||||
|
yield {"type": "tool_call", "data": {**stream_tool_call_data, "status": "completed"}}
|
||||||
|
self.tool_calls.append(tool_call_data)
|
||||||
|
|
||||||
|
return result, call_id
|
||||||
|
|
||||||
|
def _get_or_load_tool(
|
||||||
|
self, tool_data: Dict, tool_id: str, action_name: str,
|
||||||
|
headers: Optional[Dict] = None, query_params: Optional[Dict] = None,
|
||||||
|
):
|
||||||
|
"""Load a tool, using cache when possible."""
|
||||||
|
cache_key = f"{tool_data['name']}:{tool_id}:{self.user or ''}"
|
||||||
|
if cache_key in self._loaded_tools:
|
||||||
|
return self._loaded_tools[cache_key]
|
||||||
|
|
||||||
|
tm = ToolManager(config={})
|
||||||
|
|
||||||
|
if tool_data["name"] == "api_tool":
|
||||||
|
action_config = tool_data["config"]["actions"][action_name]
|
||||||
|
tool_config = {
|
||||||
|
"url": action_config["url"],
|
||||||
|
"method": action_config["method"],
|
||||||
|
"headers": headers or {},
|
||||||
|
"query_params": query_params or {},
|
||||||
|
}
|
||||||
|
if "body_content_type" in action_config:
|
||||||
|
tool_config["body_content_type"] = action_config.get(
|
||||||
|
"body_content_type", "application/json"
|
||||||
|
)
|
||||||
|
tool_config["body_encoding_rules"] = action_config.get(
|
||||||
|
"body_encoding_rules", {}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tool_config = tool_data["config"].copy() if tool_data["config"] else {}
|
||||||
|
if tool_config.get("encrypted_credentials") and self.user:
|
||||||
|
decrypted = decrypt_credentials(
|
||||||
|
tool_config["encrypted_credentials"], self.user
|
||||||
|
)
|
||||||
|
tool_config.update(decrypted)
|
||||||
|
tool_config["auth_credentials"] = decrypted
|
||||||
|
tool_config.pop("encrypted_credentials", None)
|
||||||
|
row_id = tool_data.get("id")
|
||||||
|
if not row_id:
|
||||||
|
logger.error(
|
||||||
|
"Tool data missing 'id' for tool name=%s (enumerate-key tool_id=%s); "
|
||||||
|
"skipping load to avoid binding a non-UUID downstream.",
|
||||||
|
tool_data.get("name"),
|
||||||
|
tool_id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
tool_config["tool_id"] = str(row_id)
|
||||||
|
if self.conversation_id:
|
||||||
|
tool_config["conversation_id"] = self.conversation_id
|
||||||
|
if tool_data["name"] == "mcp_tool":
|
||||||
|
tool_config["query_mode"] = True
|
||||||
|
|
||||||
|
tool = tm.load_tool(
|
||||||
|
tool_data["name"],
|
||||||
|
tool_config=tool_config,
|
||||||
|
user_id=self.user,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Don't cache api_tool since config varies by action
|
||||||
|
if tool_data["name"] != "api_tool":
|
||||||
|
self._loaded_tools[cache_key] = tool
|
||||||
|
|
||||||
|
return tool
|
||||||
|
|
||||||
|
def get_truncated_tool_calls(self) -> List[Dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"tool_name": tool_call.get("tool_name"),
|
||||||
|
"call_id": tool_call.get("call_id"),
|
||||||
|
"action_name": tool_call.get("action_name"),
|
||||||
|
"arguments": tool_call.get("arguments"),
|
||||||
|
"artifact_id": tool_call.get("artifact_id"),
|
||||||
|
"result": (
|
||||||
|
f"{str(tool_call['result'])[:50]}..."
|
||||||
|
if len(str(tool_call["result"])) > 50
|
||||||
|
else tool_call["result"]
|
||||||
|
),
|
||||||
|
"status": "completed",
|
||||||
|
}
|
||||||
|
for tool_call in self.tool_calls
|
||||||
|
]
|
||||||
application/agents/tools/api_body_serializer.py (new file)
@@ -0,0 +1,323 @@
|
|||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
from urllib.parse import quote, urlencode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContentType(str, Enum):
|
||||||
|
"""Supported content types for request bodies."""
|
||||||
|
|
||||||
|
JSON = "application/json"
|
||||||
|
FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||||
|
MULTIPART_FORM_DATA = "multipart/form-data"
|
||||||
|
TEXT_PLAIN = "text/plain"
|
||||||
|
XML = "application/xml"
|
||||||
|
OCTET_STREAM = "application/octet-stream"
|
||||||
|
|
||||||
|
|
||||||
|
class RequestBodySerializer:
|
||||||
|
"""Serializes request bodies according to content-type and OpenAPI 3.1 spec."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def serialize(
|
||||||
|
body_data: Dict[str, Any],
|
||||||
|
content_type: str = ContentType.JSON,
|
||||||
|
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
|
) -> tuple[Union[str, bytes], Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Serialize body data to appropriate format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
body_data: Dictionary of body parameters
|
||||||
|
content_type: Content-Type header value
|
||||||
|
encoding_rules: OpenAPI Encoding Object rules per field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (serialized_body, updated_headers_dict)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If serialization fails
|
||||||
|
"""
|
||||||
|
if not body_data:
|
||||||
|
return None, {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_type_lower = content_type.lower().split(";")[0].strip()
|
||||||
|
|
||||||
|
if content_type_lower == ContentType.JSON:
|
||||||
|
return RequestBodySerializer._serialize_json(body_data)
|
||||||
|
|
||||||
|
elif content_type_lower == ContentType.FORM_URLENCODED:
|
||||||
|
return RequestBodySerializer._serialize_form_urlencoded(
|
||||||
|
body_data, encoding_rules
|
||||||
|
)
|
||||||
|
|
||||||
|
elif content_type_lower == ContentType.MULTIPART_FORM_DATA:
|
||||||
|
return RequestBodySerializer._serialize_multipart_form_data(
|
||||||
|
body_data, encoding_rules
|
||||||
|
)
|
||||||
|
|
||||||
|
elif content_type_lower == ContentType.TEXT_PLAIN:
|
||||||
|
return RequestBodySerializer._serialize_text_plain(body_data)
|
||||||
|
|
||||||
|
elif content_type_lower == ContentType.XML:
|
||||||
|
return RequestBodySerializer._serialize_xml(body_data)
|
||||||
|
|
||||||
|
elif content_type_lower == ContentType.OCTET_STREAM:
|
||||||
|
return RequestBodySerializer._serialize_octet_stream(body_data)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Unknown content type: {content_type}, treating as JSON"
|
||||||
|
)
|
||||||
|
return RequestBodySerializer._serialize_json(body_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error serializing body: {str(e)}", exc_info=True)
|
||||||
|
raise ValueError(f"Failed to serialize request body: {str(e)}")
|
||||||
|
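    # Illustrative sketch (not part of the original change): a typical call site
    # would look like
    #
    #     body, extra_headers = RequestBodySerializer.serialize(
    #         {"name": "DocsGPT", "tags": ["rag", "agents"]},
    #         content_type="application/x-www-form-urlencoded",
    #     )
    #     # body == "name=DocsGPT&tags=rag&tags=agents"
    #     # extra_headers == {"Content-Type": "application/x-www-form-urlencoded"}
    #
    # The dict values are made up; only the signature mirrors serialize() above.
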
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_json(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||||
|
"""Serialize body as JSON per OpenAPI spec."""
|
||||||
|
try:
|
||||||
|
serialized = json.dumps(
|
||||||
|
body_data, separators=(",", ":"), ensure_ascii=False
|
||||||
|
)
|
||||||
|
headers = {"Content-Type": ContentType.JSON.value}
|
||||||
|
return serialized, headers
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
raise ValueError(f"Failed to serialize JSON body: {str(e)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_form_urlencoded(
|
||||||
|
body_data: Dict[str, Any],
|
||||||
|
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
|
) -> tuple[str, Dict[str, str]]:
|
||||||
|
"""Serialize body as application/x-www-form-urlencoded per RFC1866/RFC3986."""
|
||||||
|
encoding_rules = encoding_rules or {}
|
||||||
|
params = []
|
||||||
|
|
||||||
|
for key, value in body_data.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
rule = encoding_rules.get(key, {})
|
||||||
|
style = rule.get("style", "form")
|
||||||
|
explode = rule.get("explode", style == "form")
|
||||||
|
content_type = rule.get("contentType", "text/plain")
|
||||||
|
|
||||||
|
serialized_value = RequestBodySerializer._serialize_form_value(
|
||||||
|
value, style, explode, content_type, key
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(serialized_value, list):
|
||||||
|
for sv in serialized_value:
|
||||||
|
params.append((key, sv))
|
||||||
|
else:
|
||||||
|
params.append((key, serialized_value))
|
||||||
|
|
||||||
|
# Use standard urlencode (replaces space with +)
|
||||||
|
serialized = urlencode(params, safe="")
|
||||||
|
headers = {"Content-Type": ContentType.FORM_URLENCODED.value}
|
||||||
|
return serialized, headers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_form_value(
|
||||||
|
value: Any, style: str, explode: bool, content_type: str, key: str
|
||||||
|
) -> Union[str, list]:
|
||||||
|
"""Serialize individual form value with encoding rules."""
|
||||||
|
if isinstance(value, dict):
|
||||||
|
if content_type == "application/json":
|
||||||
|
return json.dumps(value, separators=(",", ":"))
|
||||||
|
elif content_type == "application/xml":
|
||||||
|
return RequestBodySerializer._dict_to_xml(value)
|
||||||
|
else:
|
||||||
|
if style == "deepObject" and explode:
|
||||||
|
return [
|
||||||
|
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||||
|
for v in value.values()
|
||||||
|
]
|
||||||
|
elif explode:
|
||||||
|
return [
|
||||||
|
f"{RequestBodySerializer._percent_encode(str(v))}"
|
||||||
|
for v in value.values()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
pairs = [f"{k},{v}" for k, v in value.items()]
|
||||||
|
return RequestBodySerializer._percent_encode(",".join(pairs))
|
||||||
|
|
||||||
|
elif isinstance(value, (list, tuple)):
|
||||||
|
if explode:
|
||||||
|
return [
|
||||||
|
RequestBodySerializer._percent_encode(str(item)) for item in value
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return RequestBodySerializer._percent_encode(
|
||||||
|
",".join(str(v) for v in value)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return RequestBodySerializer._percent_encode(str(value))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_multipart_form_data(
|
||||||
|
body_data: Dict[str, Any],
|
||||||
|
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
|
) -> tuple[bytes, Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Serialize body as multipart/form-data per RFC7578.
|
||||||
|
|
||||||
|
Supports file uploads and encoding rules.
|
||||||
|
"""
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
encoding_rules = encoding_rules or {}
|
||||||
|
boundary = f"----DocsGPT{secrets.token_hex(16)}"
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
for key, value in body_data.items():
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
rule = encoding_rules.get(key, {})
|
||||||
|
content_type = rule.get("contentType", "text/plain")
|
||||||
|
headers_rule = rule.get("headers", {})
|
||||||
|
|
||||||
|
part = RequestBodySerializer._create_multipart_part(
|
||||||
|
key, value, content_type, headers_rule
|
||||||
|
)
|
||||||
|
parts.append(part)
|
||||||
|
|
||||||
|
body_bytes = f"--{boundary}\r\n".encode("utf-8")
|
||||||
|
body_bytes += f"--{boundary}\r\n".join(parts).encode("utf-8")
|
||||||
|
body_bytes += f"\r\n--{boundary}--\r\n".encode("utf-8")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": f"multipart/form-data; boundary={boundary}",
|
||||||
|
}
|
||||||
|
return body_bytes, headers
|
||||||
|
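    # Illustrative sketch (not part of the original change): for a single made-up
    # text field {"note": "hello"} the generated body looks roughly like
    #
    #     ------DocsGPT<32 hex chars>
    #     Content-Disposition: form-data; name="note"
    #
    #     hello
    #
    #     ------DocsGPT<32 hex chars>--
    #
    # with CRLF line endings and a random boundary from secrets.token_hex(16).
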
|
||||||
|
@staticmethod
|
||||||
|
def _create_multipart_part(
|
||||||
|
name: str, value: Any, content_type: str, headers_rule: Dict[str, Any]
|
||||||
|
) -> str:
|
||||||
|
"""Create a single multipart/form-data part."""
|
||||||
|
headers = [
|
||||||
|
f'Content-Disposition: form-data; name="{RequestBodySerializer._percent_encode(name)}"'
|
||||||
|
]
|
||||||
|
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
if content_type == "application/octet-stream":
|
||||||
|
value_encoded = base64.b64encode(value).decode("utf-8")
|
||||||
|
else:
|
||||||
|
value_encoded = value.decode("utf-8", errors="replace")
|
||||||
|
headers.append(f"Content-Type: {content_type}")
|
||||||
|
headers.append("Content-Transfer-Encoding: base64")
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
if content_type == "application/json":
|
||||||
|
value_encoded = json.dumps(value, separators=(",", ":"))
|
||||||
|
elif content_type == "application/xml":
|
||||||
|
value_encoded = RequestBodySerializer._dict_to_xml(value)
|
||||||
|
else:
|
||||||
|
value_encoded = str(value)
|
||||||
|
headers.append(f"Content-Type: {content_type}")
|
||||||
|
elif isinstance(value, str) and content_type != "text/plain":
|
||||||
|
try:
|
||||||
|
if content_type == "application/json":
|
||||||
|
json.loads(value)
|
||||||
|
value_encoded = value
|
||||||
|
elif content_type == "application/xml":
|
||||||
|
value_encoded = value
|
||||||
|
else:
|
||||||
|
value_encoded = str(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
value_encoded = str(value)
|
||||||
|
headers.append(f"Content-Type: {content_type}")
|
||||||
|
else:
|
||||||
|
value_encoded = str(value)
|
||||||
|
if content_type != "text/plain":
|
||||||
|
headers.append(f"Content-Type: {content_type}")
|
||||||
|
|
||||||
|
part = "\r\n".join(headers) + "\r\n\r\n" + value_encoded + "\r\n"
|
||||||
|
return part
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_text_plain(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||||
|
"""Serialize body as plain text."""
|
||||||
|
if len(body_data) == 1:
|
||||||
|
value = list(body_data.values())[0]
|
||||||
|
return str(value), {"Content-Type": ContentType.TEXT_PLAIN.value}
|
||||||
|
else:
|
||||||
|
text = "\n".join(f"{k}: {v}" for k, v in body_data.items())
|
||||||
|
return text, {"Content-Type": ContentType.TEXT_PLAIN.value}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_xml(body_data: Dict[str, Any]) -> tuple[str, Dict[str, str]]:
|
||||||
|
"""Serialize body as XML."""
|
||||||
|
xml_str = RequestBodySerializer._dict_to_xml(body_data)
|
||||||
|
return xml_str, {"Content-Type": ContentType.XML.value}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_octet_stream(
|
||||||
|
body_data: Dict[str, Any],
|
||||||
|
) -> tuple[bytes, Dict[str, str]]:
|
||||||
|
"""Serialize body as binary octet stream."""
|
||||||
|
if isinstance(body_data, bytes):
|
||||||
|
return body_data, {"Content-Type": ContentType.OCTET_STREAM.value}
|
||||||
|
elif isinstance(body_data, str):
|
||||||
|
return body_data.encode("utf-8"), {
|
||||||
|
"Content-Type": ContentType.OCTET_STREAM.value
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
serialized = json.dumps(body_data)
|
||||||
|
return serialized.encode("utf-8"), {
|
||||||
|
"Content-Type": ContentType.OCTET_STREAM.value
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _percent_encode(value: str, safe_chars: str = "") -> str:
|
||||||
|
"""
|
||||||
|
Percent-encode per RFC3986.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: String to encode
|
||||||
|
safe_chars: Additional characters to not encode
|
||||||
|
"""
|
||||||
|
return quote(value, safe=safe_chars)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dict_to_xml(data: Dict[str, Any], root_name: str = "root") -> str:
|
||||||
|
"""
|
||||||
|
Convert dict to simple XML format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def build_xml(obj: Any, name: str) -> str:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
inner = "".join(build_xml(v, k) for k, v in obj.items())
|
||||||
|
return f"<{name}>{inner}</{name}>"
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
items = "".join(
|
||||||
|
build_xml(item, f"{name[:-1] if name.endswith('s') else name}")
|
||||||
|
for item in obj
|
||||||
|
)
|
||||||
|
return items
|
||||||
|
else:
|
||||||
|
return f"<{name}>{RequestBodySerializer._escape_xml(str(obj))}</{name}>"
|
||||||
|
|
||||||
|
root = build_xml(data, root_name)
|
||||||
|
return f'<?xml version="1.0" encoding="UTF-8"?>{root}'
|
||||||
|
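    # Illustrative sketch (not part of the original change): with made-up input,
    #
    #     RequestBodySerializer._dict_to_xml({"user": {"name": "Ada", "roles": ["admin", "dev"]}})
    #
    # returns
    #
    #     '<?xml version="1.0" encoding="UTF-8"?><root><user><name>Ada</name><role>admin</role><role>dev</role></user></root>'
    #
    # note that list items are emitted under the singularized key ("roles" -> "role").
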
|
||||||
|
    @staticmethod
    def _escape_xml(value: str) -> str:
        """Escape XML special characters."""
        return (
            value.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&apos;")
        )
|
||||||
@@ -1,72 +1,280 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from application.agents.tools.api_body_serializer import (
|
||||||
|
ContentType,
|
||||||
|
RequestBodySerializer,
|
||||||
|
)
|
||||||
from application.agents.tools.base import Tool
|
from application.agents.tools.base import Tool
|
||||||
|
from application.core.url_validation import validate_url, SSRFError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = 90 # seconds
|
||||||
|
|
||||||
|
|
||||||
class APITool(Tool):
|
class APITool(Tool):
|
||||||
"""
|
"""
|
||||||
API Tool
|
API Tool
|
||||||
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs
|
A flexible tool for performing various API actions (e.g., sending messages, retrieving data) via custom user-specified APIs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.url = config.get("url", "")
|
self.url = config.get("url", "")
|
||||||
self.method = config.get("method", "GET")
|
self.method = config.get("method", "GET")
|
||||||
self.headers = config.get("headers", {"Content-Type": "application/json"})
|
self.headers = config.get("headers", {})
|
||||||
self.query_params = config.get("query_params", {})
|
self.query_params = config.get("query_params", {})
|
||||||
|
self.body_content_type = config.get("body_content_type", ContentType.JSON)
|
||||||
|
self.body_encoding_rules = config.get("body_encoding_rules", {})
|
||||||
|
|
||||||
def execute_action(self, action_name, **kwargs):
|
def execute_action(self, action_name, **kwargs):
|
||||||
|
"""Execute an API action with the given arguments."""
|
||||||
return self._make_api_call(
|
return self._make_api_call(
|
||||||
self.url, self.method, self.headers, self.query_params, kwargs
|
self.url,
|
||||||
|
self.method,
|
||||||
|
self.headers,
|
||||||
|
self.query_params,
|
||||||
|
kwargs,
|
||||||
|
self.body_content_type,
|
||||||
|
self.body_encoding_rules,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_api_call(self, url, method, headers, query_params, body):
|
def _make_api_call(
|
||||||
if query_params:
|
self,
|
||||||
url = f"{url}?{requests.compat.urlencode(query_params)}"
|
url: str,
|
||||||
# if isinstance(body, dict):
|
method: str,
|
||||||
# body = json.dumps(body)
|
headers: Dict[str, str],
|
||||||
|
query_params: Dict[str, Any],
|
||||||
|
body: Dict[str, Any],
|
||||||
|
content_type: str = ContentType.JSON,
|
||||||
|
encoding_rules: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Make an API call with proper body serialization and error handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: API endpoint URL
|
||||||
|
method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS)
|
||||||
|
headers: Request headers dict
|
||||||
|
query_params: URL query parameters
|
||||||
|
body: Request body as dict
|
||||||
|
content_type: Content-Type for serialization
|
||||||
|
encoding_rules: OpenAPI encoding rules
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with status_code, data, and message
|
||||||
|
"""
|
||||||
|
request_url = url
|
||||||
|
request_headers = headers.copy() if headers else {}
|
||||||
|
response = None
|
||||||
|
|
||||||
|
# Validate URL to prevent SSRF attacks
|
||||||
try:
|
try:
|
||||||
print(f"Making API call: {method} {url} with body: {body}")
|
validate_url(request_url)
|
||||||
if body == "{}":
|
except SSRFError as e:
|
||||||
body = None
|
logger.error(f"URL validation failed: {e}")
|
||||||
response = requests.request(method, url, headers=headers, data=body)
|
return {
|
||||||
response.raise_for_status()
|
"status_code": None,
|
||||||
content_type = response.headers.get(
|
"message": f"URL validation error: {e}",
|
||||||
"Content-Type", "application/json"
|
"data": None,
|
||||||
).lower()
|
}
|
||||||
if "application/json" in content_type:
|
|
||||||
|
try:
|
||||||
|
path_params_used = set()
|
||||||
|
if query_params:
|
||||||
|
for match in re.finditer(r"\{([^}]+)\}", request_url):
|
||||||
|
param_name = match.group(1)
|
||||||
|
if param_name in query_params:
|
||||||
|
request_url = request_url.replace(
|
||||||
|
f"{{{param_name}}}", str(query_params[param_name])
|
||||||
|
)
|
||||||
|
path_params_used.add(param_name)
|
||||||
|
remaining_params = {
|
||||||
|
k: v for k, v in query_params.items() if k not in path_params_used
|
||||||
|
}
|
||||||
|
if remaining_params:
|
||||||
|
query_string = urlencode(remaining_params)
|
||||||
|
separator = "&" if "?" in request_url else "?"
|
||||||
|
request_url = f"{request_url}{separator}{query_string}"
|
||||||
|
|
||||||
|
# Re-validate URL after parameter substitution to prevent SSRF via path params
|
||||||
|
try:
|
||||||
|
validate_url(request_url)
|
||||||
|
except SSRFError as e:
|
||||||
|
logger.error(f"URL validation failed after parameter substitution: {e}")
|
||||||
|
return {
|
||||||
|
"status_code": None,
|
||||||
|
"message": f"URL validation error: {e}",
|
||||||
|
"data": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Serialize body based on content type
|
||||||
|
|
||||||
|
if body and body != {}:
|
||||||
try:
|
try:
|
||||||
data = response.json()
|
serialized_body, body_headers = RequestBodySerializer.serialize(
|
||||||
except json.JSONDecodeError as e:
|
body, content_type, encoding_rules
|
||||||
print(f"Error decoding JSON: {e}. Raw response: {response.text}")
|
)
|
||||||
|
request_headers.update(body_headers)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Body serialization failed: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code,
|
"status_code": None,
|
||||||
"message": f"API call returned invalid JSON. Error: {e}",
|
"message": f"Body serialization error: {str(e)}",
|
||||||
"data": response.text,
|
"data": None,
|
||||||
}
|
}
|
||||||
elif "text/" in content_type or "application/xml" in content_type:
|
|
||||||
data = response.text
|
|
||||||
elif not response.content:
|
|
||||||
data = None
|
|
||||||
else:
|
else:
|
||||||
print(f"Unsupported content type: {content_type}")
|
serialized_body = None
|
||||||
data = response.content
|
if "Content-Type" not in request_headers and method not in [
|
||||||
|
"GET",
|
||||||
|
"HEAD",
|
||||||
|
"DELETE",
|
||||||
|
]:
|
||||||
|
request_headers["Content-Type"] = ContentType.JSON
|
||||||
|
logger.debug(
|
||||||
|
f"API Call: {method} {request_url} | Content-Type: {request_headers.get('Content-Type', 'N/A')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if method.upper() == "GET":
|
||||||
|
response = requests.get(
|
||||||
|
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||||
|
)
|
||||||
|
elif method.upper() == "POST":
|
||||||
|
response = requests.post(
|
||||||
|
request_url,
|
||||||
|
data=serialized_body,
|
||||||
|
headers=request_headers,
|
||||||
|
timeout=DEFAULT_TIMEOUT,
|
||||||
|
)
|
||||||
|
elif method.upper() == "PUT":
|
||||||
|
response = requests.put(
|
||||||
|
request_url,
|
||||||
|
data=serialized_body,
|
||||||
|
headers=request_headers,
|
||||||
|
timeout=DEFAULT_TIMEOUT,
|
||||||
|
)
|
||||||
|
elif method.upper() == "DELETE":
|
||||||
|
response = requests.delete(
|
||||||
|
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||||
|
)
|
||||||
|
elif method.upper() == "PATCH":
|
||||||
|
response = requests.patch(
|
||||||
|
request_url,
|
||||||
|
data=serialized_body,
|
||||||
|
headers=request_headers,
|
||||||
|
timeout=DEFAULT_TIMEOUT,
|
||||||
|
)
|
||||||
|
elif method.upper() == "HEAD":
|
||||||
|
response = requests.head(
|
||||||
|
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||||
|
)
|
||||||
|
elif method.upper() == "OPTIONS":
|
||||||
|
response = requests.options(
|
||||||
|
request_url, headers=request_headers, timeout=DEFAULT_TIMEOUT
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"status_code": None,
|
||||||
|
"message": f"Unsupported HTTP method: {method}",
|
||||||
|
"data": None,
|
||||||
|
}
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = self._parse_response(response)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code,
|
"status_code": response.status_code,
|
||||||
"data": data,
|
"data": data,
|
||||||
"message": "API call successful.",
|
"message": "API call successful.",
|
||||||
}
|
}
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
logger.error(f"Request timeout for {request_url}")
|
||||||
|
return {
|
||||||
|
"status_code": None,
|
||||||
|
"message": f"Request timeout ({DEFAULT_TIMEOUT}s exceeded)",
|
||||||
|
"data": None,
|
||||||
|
}
|
||||||
|
except requests.exceptions.ConnectionError as e:
|
||||||
|
logger.error(f"Connection error: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status_code": None,
|
||||||
|
"message": f"Connection error: {str(e)}",
|
||||||
|
"data": None,
|
||||||
|
}
|
||||||
|
except requests.exceptions.HTTPError as e:
|
||||||
|
logger.error(f"HTTP error {response.status_code}: {str(e)}")
|
||||||
|
try:
|
||||||
|
error_data = response.json()
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
error_data = response.text
|
||||||
|
return {
|
||||||
|
"status_code": response.status_code,
|
||||||
|
"message": f"HTTP Error {response.status_code}",
|
||||||
|
"data": error_data,
|
||||||
|
}
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Request failed: {str(e)}")
|
||||||
return {
|
return {
|
||||||
"status_code": response.status_code if response else None,
|
"status_code": response.status_code if response else None,
|
||||||
"message": f"API call failed: {str(e)}",
|
"message": f"API call failed: {str(e)}",
|
||||||
|
"data": None,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error in API call: {str(e)}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"status_code": None,
|
||||||
|
"message": f"Unexpected error: {str(e)}",
|
||||||
|
"data": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _parse_response(self, response: requests.Response) -> Any:
|
||||||
|
"""
|
||||||
|
Parse response based on Content-Type header.
|
||||||
|
|
||||||
|
Supports: JSON, XML, plain text, binary data.
|
||||||
|
"""
|
||||||
|
content_type = response.headers.get("Content-Type", "").lower()
|
||||||
|
|
||||||
|
if not response.content:
|
||||||
|
return None
|
||||||
|
# JSON response
|
||||||
|
|
||||||
|
if "application/json" in content_type:
|
||||||
|
try:
|
||||||
|
return response.json()
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"Failed to parse JSON response: {str(e)}")
|
||||||
|
return response.text
|
||||||
|
# XML response
|
||||||
|
|
||||||
|
elif "application/xml" in content_type or "text/xml" in content_type:
|
||||||
|
return response.text
|
||||||
|
# Plain text response
|
||||||
|
|
||||||
|
elif "text/plain" in content_type or "text/html" in content_type:
|
||||||
|
return response.text
|
||||||
|
# Binary/unknown response
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Try to decode as text first, fall back to base64
|
||||||
|
|
||||||
|
try:
|
||||||
|
return response.text
|
||||||
|
except (UnicodeDecodeError, AttributeError):
|
||||||
|
import base64
|
||||||
|
|
||||||
|
return base64.b64encode(response.content).decode("utf-8")
|
||||||
|
|
||||||
def get_actions_metadata(self):
|
def get_actions_metadata(self):
|
||||||
|
"""Return metadata for available actions (none for API Tool - actions are user-defined)."""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_config_requirements(self):
|
def get_config_requirements(self):
|
||||||
|
"""Return configuration requirements for the tool."""
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
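Every branch above resolves to the same three-key envelope, so callers only need to branch on status_code. A hedged usage sketch (the api_tool variable and the action name are assumptions for illustration, not taken from this diff):

result = api_tool.execute_action("api_call", method="GET", url="https://example.com/items")

if result["status_code"] is None:
    # The request never completed: timeout, connection error, or unsupported method.
    print("Transport failure:", result["message"])
elif 200 <= result["status_code"] < 300:
    print(result["data"])
else:
    # HTTP error: "data" carries the parsed error body when one was returned.
    print(f"HTTP {result['status_code']}:", result["message"])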
@@ -2,6 +2,8 @@ from abc import ABC, abstractmethod


 class Tool(ABC):
+    internal: bool = False
+
     @abstractmethod
     def execute_action(self, action_name: str, **kwargs):
         pass
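A brief sketch of how a subclass picks up the new flag (the EchoTool class is illustrative and not part of this changeset):

from application.agents.tools.base import Tool


class EchoTool(Tool):
    # Platform-provided tools can mark themselves internal; the new class
    # attribute defaults to False for ordinary user-configured tools.
    internal = True

    def execute_action(self, action_name: str, **kwargs):
        if action_name != "echo":
            raise ValueError(f"Unknown action: {action_name}")
        return kwargs

    def get_actions_metadata(self):
        return [{"name": "echo", "description": "Return the arguments unchanged."}]

    def get_config_requirements(self):
        return {}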
@@ -1,6 +1,11 @@
+import logging
+
 import requests

 from application.agents.tools.base import Tool

+logger = logging.getLogger(__name__)
+
+
 class BraveSearchTool(Tool):
     """
@@ -41,7 +46,7 @@ class BraveSearchTool(Tool):
         """
         Performs a web search using the Brave Search API.
         """
-        print(f"Performing Brave web search for: {query}")
+        logger.debug("Performing Brave web search for: %s", query)

         url = f"{self.base_url}/web/search"

@@ -68,7 +73,7 @@ class BraveSearchTool(Tool):
             "X-Subscription-Token": self.token,
         }

-        response = requests.get(url, params=params, headers=headers)
+        response = requests.get(url, params=params, headers=headers, timeout=100)

         if response.status_code == 200:
             return {
@@ -94,7 +99,7 @@ class BraveSearchTool(Tool):
         """
         Performs an image search using the Brave Search API.
         """
-        print(f"Performing Brave image search for: {query}")
+        logger.debug("Performing Brave image search for: %s", query)

         url = f"{self.base_url}/images/search"

@@ -113,7 +118,7 @@ class BraveSearchTool(Tool):
             "X-Subscription-Token": self.token,
         }

-        response = requests.get(url, params=params, headers=headers)
+        response = requests.get(url, params=params, headers=headers, timeout=100)

         if response.status_code == 200:
             return {
@@ -177,6 +182,10 @@ class BraveSearchTool(Tool):
         return {
             "token": {
                 "type": "string",
+                "label": "API Key",
                 "description": "Brave Search API key for authentication",
+                "required": True,
+                "secret": True,
+                "order": 1,
             },
         }
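The timeout added above bounds how long a stalled Brave endpoint can block the worker. A small sketch of the failure path it introduces (the URL and token are placeholders, assumed for illustration):

import requests

try:
    response = requests.get(
        "https://example.invalid/web/search",  # stands in for f"{self.base_url}/web/search"
        params={"q": "docsgpt"},
        headers={"X-Subscription-Token": "<token>"},
        timeout=100,  # raises instead of hanging indefinitely
    )
    response.raise_for_status()
except requests.exceptions.Timeout:
    print("Brave Search did not respond within 100 seconds")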
@@ -28,7 +28,7 @@ class CryptoPriceTool(Tool):
         returns price in USD.
         """
         url = f"https://min-api.cryptocompare.com/data/price?fsym={symbol.upper()}&tsyms={currency.upper()}"
-        response = requests.get(url)
+        response = requests.get(url, timeout=100)
         if response.status_code == 200:
             data = response.json()
             if currency.upper() in data:
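The CryptoCompare call follows the same pattern; a standalone sketch using the endpoint shown above (symbol and currency values are examples):

import requests

symbol, currency = "BTC", "USD"
url = (
    "https://min-api.cryptocompare.com/data/price"
    f"?fsym={symbol.upper()}&tsyms={currency.upper()}"
)
response = requests.get(url, timeout=100)
if response.status_code == 200:
    data = response.json()
    if currency.upper() in data:
        print(f"{symbol.upper()} = {data[currency.upper()]} {currency.upper()}")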
@@ -1,5 +1,14 @@
+import logging
+import time
+from typing import Any, Dict, Optional
+
 from application.agents.tools.base import Tool
-from duckduckgo_search import DDGS
+
+logger = logging.getLogger(__name__)
+
+MAX_RETRIES = 3
+RETRY_DELAY = 2.0
+DEFAULT_TIMEOUT = 15
+
+
 class DuckDuckGoSearchTool(Tool):
@@ -10,71 +19,123 @@ class DuckDuckGoSearchTool(Tool):

     def __init__(self, config):
         self.config = config
+        self.timeout = config.get("timeout", DEFAULT_TIMEOUT)
+
+    def _get_ddgs_client(self):
+        from ddgs import DDGS
+
+        return DDGS(timeout=self.timeout)
+
+    def _execute_with_retry(self, operation, operation_name: str) -> Dict[str, Any]:
+        last_error = None
+        for attempt in range(1, MAX_RETRIES + 1):
+            try:
+                results = operation()
+                return {
+                    "status_code": 200,
+                    "results": list(results) if results else [],
+                    "message": f"{operation_name} completed successfully.",
+                }
+            except Exception as e:
+                last_error = e
+                error_str = str(e).lower()
+                if "ratelimit" in error_str or "429" in error_str:
+                    if attempt < MAX_RETRIES:
+                        delay = RETRY_DELAY * attempt
+                        logger.warning(
+                            f"{operation_name} rate limited, retrying in {delay}s (attempt {attempt}/{MAX_RETRIES})"
+                        )
+                        time.sleep(delay)
+                        continue
+                logger.error(f"{operation_name} failed: {e}")
+                break
+        return {
+            "status_code": 500,
+            "results": [],
+            "message": f"{operation_name} failed: {str(last_error)}",
+        }

     def execute_action(self, action_name, **kwargs):
         actions = {
             "ddg_web_search": self._web_search,
             "ddg_image_search": self._image_search,
+            "ddg_news_search": self._news_search,
         }
-        if action_name in actions:
-            return actions[action_name](**kwargs)
-        else:
+        if action_name not in actions:
             raise ValueError(f"Unknown action: {action_name}")
+        return actions[action_name](**kwargs)

     def _web_search(
         self,
-        query,
-        max_results=5,
-    ):
-        print(f"Performing DuckDuckGo web search for: {query}")
+        query: str,
+        max_results: int = 5,
+        region: str = "wt-wt",
+        safesearch: str = "moderate",
+        timelimit: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        logger.info(f"DuckDuckGo web search: {query}")

-        try:
-            results = DDGS().text(
+        def operation():
+            client = self._get_ddgs_client()
+            return client.text(
                 query,
-                max_results=max_results,
+                region=region,
+                safesearch=safesearch,
+                timelimit=timelimit,
+                max_results=min(max_results, 20),
             )

-            return {
-                "status_code": 200,
-                "results": results,
-                "message": "Web search completed successfully.",
-            }
-        except Exception as e:
-            return {
-                "status_code": 500,
-                "message": f"Web search failed: {str(e)}",
-            }
+        return self._execute_with_retry(operation, "Web search")

     def _image_search(
         self,
-        query,
-        max_results=5,
-    ):
-        print(f"Performing DuckDuckGo image search for: {query}")
+        query: str,
+        max_results: int = 5,
+        region: str = "wt-wt",
+        safesearch: str = "moderate",
+        timelimit: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        logger.info(f"DuckDuckGo image search: {query}")

-        try:
-            results = DDGS().images(
-                keywords=query,
-                max_results=max_results,
+        def operation():
+            client = self._get_ddgs_client()
+            return client.images(
+                query,
+                region=region,
+                safesearch=safesearch,
+                timelimit=timelimit,
+                max_results=min(max_results, 50),
             )

-            return {
-                "status_code": 200,
-                "results": results,
-                "message": "Image search completed successfully.",
-            }
-        except Exception as e:
-            return {
-                "status_code": 500,
-                "message": f"Image search failed: {str(e)}",
-            }
+        return self._execute_with_retry(operation, "Image search")
+
+    def _news_search(
+        self,
+        query: str,
+        max_results: int = 5,
+        region: str = "wt-wt",
+        safesearch: str = "moderate",
+        timelimit: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        logger.info(f"DuckDuckGo news search: {query}")
+
+        def operation():
+            client = self._get_ddgs_client()
+            return client.news(
+                query,
+                region=region,
+                safesearch=safesearch,
+                timelimit=timelimit,
+                max_results=min(max_results, 20),
+            )
+
+        return self._execute_with_retry(operation, "News search")

     def get_actions_metadata(self):
         return [
             {
                 "name": "ddg_web_search",
-                "description": "Perform a web search using DuckDuckGo.",
+                "description": "Search the web using DuckDuckGo. Returns titles, URLs, and snippets.",
                 "parameters": {
                     "type": "object",
                     "properties": {
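The retry helper above only backs off on rate-limit errors and gives up immediately on anything else. The same control flow as a self-contained sketch (constants and the wrapped operation are illustrative):

import time

MAX_RETRIES = 3
RETRY_DELAY = 2.0


def execute_with_retry(operation, operation_name):
    last_error = None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            return {"status_code": 200, "results": list(operation() or [])}
        except Exception as e:
            last_error = e
            if "ratelimit" in str(e).lower() and attempt < MAX_RETRIES:
                # Linear backoff: 2s after the first failure, 4s after the second.
                time.sleep(RETRY_DELAY * attempt)
                continue
            break
    return {"status_code": 500, "results": [], "message": f"{operation_name} failed: {last_error}"}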
@@ -84,7 +145,15 @@ class DuckDuckGoSearchTool(Tool):
                         },
                         "max_results": {
                             "type": "integer",
-                            "description": "Number of results to return (default: 5)",
+                            "description": "Number of results (default: 5, max: 20)",
+                        },
+                        "region": {
+                            "type": "string",
+                            "description": "Region code (default: wt-wt for worldwide, us-en for US)",
+                        },
+                        "timelimit": {
+                            "type": "string",
+                            "description": "Time filter: d (day), w (week), m (month), y (year)",
                         },
                     },
                     "required": ["query"],
@@ -92,17 +161,43 @@ class DuckDuckGoSearchTool(Tool):
             },
             {
                 "name": "ddg_image_search",
-                "description": "Perform an image search using DuckDuckGo.",
+                "description": "Search for images using DuckDuckGo. Returns image URLs and metadata.",
                 "parameters": {
                     "type": "object",
                     "properties": {
                         "query": {
                             "type": "string",
-                            "description": "Search query",
+                            "description": "Image search query",
                         },
                         "max_results": {
                             "type": "integer",
-                            "description": "Number of results to return (default: 5, max: 50)",
+                            "description": "Number of results (default: 5, max: 50)",
+                        },
+                        "region": {
+                            "type": "string",
+                            "description": "Region code (default: wt-wt for worldwide)",
+                        },
+                    },
+                    "required": ["query"],
+                },
+            },
+            {
+                "name": "ddg_news_search",
+                "description": "Search for news articles using DuckDuckGo. Returns recent news.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "query": {
+                            "type": "string",
+                            "description": "News search query",
+                        },
+                        "max_results": {
+                            "type": "integer",
+                            "description": "Number of results (default: 5, max: 20)",
+                        },
+                        "timelimit": {
+                            "type": "string",
+                            "description": "Time filter: d (day), w (week), m (month)",
                         },
                     },
                     "required": ["query"],
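Given the metadata above, a model-issued call to the new news action carries arguments shaped like this (values are examples):

action_name = "ddg_news_search"
kwargs = {
    "query": "open source RAG frameworks",
    "max_results": 5,   # capped at 20 by _news_search
    "timelimit": "w",   # restrict to the last week
}
# Dispatched via DuckDuckGoSearchTool.execute_action(action_name, **kwargs)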
application/agents/tools/internal_search.py (new file, 458 lines)
@@ -0,0 +1,458 @@
import json
import logging
from typing import Dict, List, Optional

from application.agents.tools.base import Tool
from application.core.settings import settings
from application.retriever.retriever_creator import RetrieverCreator

logger = logging.getLogger(__name__)


class InternalSearchTool(Tool):
    """Wraps the ClassicRAG retriever as an LLM-callable tool.

    Instead of pre-fetching docs into the prompt, the LLM decides
    when and what to search. Supports multiple searches per session.

    Optional capabilities (enabled when sources have directory_structure):
    - path_filter on search: restrict results to a specific file/folder
    - list_files action: browse the file/folder structure
    """

    internal = True

    def __init__(self, config: Dict):
        self.config = config
        self.retrieved_docs: List[Dict] = []
        self._retriever = None
        self._directory_structure: Optional[Dict] = None
        self._dir_structure_loaded = False

    def _get_retriever(self):
        if self._retriever is None:
            self._retriever = RetrieverCreator.create_retriever(
                self.config.get("retriever_name", "classic"),
                source=self.config.get("source", {}),
                chat_history=[],
                prompt="",
                chunks=int(self.config.get("chunks", 2)),
                doc_token_limit=int(self.config.get("doc_token_limit", 50000)),
                model_id=self.config.get("model_id", "docsgpt-local"),
                user_api_key=self.config.get("user_api_key"),
                agent_id=self.config.get("agent_id"),
                llm_name=self.config.get("llm_name", settings.LLM_PROVIDER),
                api_key=self.config.get("api_key", settings.API_KEY),
                decoded_token=self.config.get("decoded_token"),
            )
        return self._retriever

    def _get_directory_structure(self) -> Optional[Dict]:
        """Load directory structure from Postgres for the configured sources."""
        if self._dir_structure_loaded:
            return self._directory_structure

        self._dir_structure_loaded = True
        source = self.config.get("source", {})
        active_docs = source.get("active_docs", [])
        if not active_docs:
            return None

        try:
            # Per-operation session: this tool runs inside the answer
            # generator hot path, so we open a short-lived read
            # connection for the batch lookup and release immediately.
            from application.storage.db.repositories.sources import (
                SourcesRepository,
            )
            from application.storage.db.session import db_readonly

            if isinstance(active_docs, str):
                active_docs = [active_docs]

            decoded_token = self.config.get("decoded_token") or {}
            user_id = decoded_token.get("sub") if decoded_token else None

            merged_structure = {}
            with db_readonly() as conn:
                repo = SourcesRepository(conn)
                for doc_id in active_docs:
                    try:
                        source_doc = repo.get_any(str(doc_id), user_id) if user_id else None
                        if not source_doc:
                            continue
                        dir_str = source_doc.get("directory_structure")
                        if dir_str:
                            if isinstance(dir_str, str):
                                dir_str = json.loads(dir_str)
                            source_name = source_doc.get("name", doc_id)
                            if len(active_docs) > 1:
                                merged_structure[source_name] = dir_str
                            else:
                                merged_structure = dir_str
                    except Exception as e:
                        logger.debug(f"Could not load dir structure for {doc_id}: {e}")

            self._directory_structure = merged_structure if merged_structure else None
        except Exception as e:
            logger.debug(f"Failed to load directory structures: {e}")

        return self._directory_structure

    def execute_action(self, action_name: str, **kwargs):
        if action_name == "search":
            return self._execute_search(**kwargs)
        elif action_name == "list_files":
            return self._execute_list_files(**kwargs)
        return f"Unknown action: {action_name}"

    def _execute_search(self, **kwargs) -> str:
        query = kwargs.get("query", "")
        path_filter = kwargs.get("path_filter", "")

        if not query:
            return "Error: 'query' parameter is required."

        try:
            retriever = self._get_retriever()
            docs = retriever.search(query)
        except Exception as e:
            logger.error(f"Internal search failed: {e}", exc_info=True)
            return "Search failed: an internal error occurred."

        if not docs:
            return "No documents found matching your query."

        # Apply path filter if specified
        if path_filter:
            path_lower = path_filter.lower()
            docs = [
                d
                for d in docs
                if path_lower in d.get("source", "").lower()
                or path_lower in d.get("filename", "").lower()
                or path_lower in d.get("title", "").lower()
            ]
            if not docs:
                return f"No documents found matching query '{query}' in path '{path_filter}'."

        # Accumulate for source tracking
        for doc in docs:
            if doc not in self.retrieved_docs:
                self.retrieved_docs.append(doc)

        # Format results for the LLM
        formatted = []
        for i, doc in enumerate(docs, 1):
            title = doc.get("title", "Untitled")
            text = doc.get("text", "")
            source = doc.get("source", "Unknown")
            filename = doc.get("filename", "")
            header = filename or title
            formatted.append(f"[{i}] {header} (source: {source})\n{text}")

        return "\n\n---\n\n".join(formatted)

    def _execute_list_files(self, **kwargs) -> str:
        path = kwargs.get("path", "")
        dir_structure = self._get_directory_structure()

        if not dir_structure:
            return "No file structure available for the current sources."

        # Navigate to the requested path
        current = dir_structure
        if path:
            for part in path.strip("/").split("/"):
                if not part:
                    continue
                if isinstance(current, dict) and part in current:
                    current = current[part]
                else:
                    return f"Path '{path}' not found in the file structure."

        # Format the structure for the LLM
        return self._format_structure(current, path or "/")

    def _format_structure(self, node: Dict, current_path: str) -> str:
        if not isinstance(node, dict):
            return f"'{current_path}' is a file, not a directory."

        lines = [f"File structure at '{current_path}':\n"]
        folders = []
        files = []

        for name, value in sorted(node.items()):
            if isinstance(value, dict):
                # Check if it's a file metadata dict or a folder
                if "type" in value or "size_bytes" in value or "token_count" in value:
                    # It's a file with metadata
                    size = value.get("token_count", "")
                    ftype = value.get("type", "")
                    info_parts = []
                    if ftype:
                        info_parts.append(ftype)
                    if size:
                        info_parts.append(f"{size} tokens")
                    info = f" ({', '.join(info_parts)})" if info_parts else ""
                    files.append(f"  {name}{info}")
                else:
                    # It's a folder
                    count = self._count_files(value)
                    folders.append(f"  {name}/ ({count} items)")
            else:
                files.append(f"  {name}")

        if folders:
            lines.append("Folders:")
            lines.extend(folders)
        if files:
            lines.append("Files:")
            lines.extend(files)
        if not folders and not files:
            lines.append("  (empty)")

        return "\n".join(lines)

    def _count_files(self, node: Dict) -> int:
        count = 0
        for value in node.values():
            if isinstance(value, dict):
                if "type" in value or "size_bytes" in value or "token_count" in value:
                    count += 1
                else:
                    count += self._count_files(value)
            else:
                count += 1
        return count

    def get_actions_metadata(self):
        actions = [
            {
                "name": "search",
                "description": (
                    "Search the user's uploaded documents and knowledge base. "
                    "Use this to find relevant information before answering questions. "
                    "You can call this multiple times with different queries."
                ),
                "parameters": {
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query. Be specific and focused.",
                            "filled_by_llm": True,
                            "required": True,
                        },
                    }
                },
            }
        ]

        # Add path_filter and list_files only if directory structure exists
        has_structure = self.config.get("has_directory_structure", False)
        if has_structure:
            actions[0]["parameters"]["properties"]["path_filter"] = {
                "type": "string",
                "description": (
                    "Optional: filter results to a specific file or folder path. "
                    "Use list_files first to see available paths."
                ),
                "filled_by_llm": True,
                "required": False,
            }
            actions.append(
                {
                    "name": "list_files",
                    "description": (
                        "Browse the file and folder structure of the knowledge base. "
                        "Use this to see what files are available before searching. "
                        "Optionally provide a path to browse a specific folder."
                    ),
                    "parameters": {
                        "properties": {
                            "path": {
                                "type": "string",
                                "description": "Optional: folder path to browse. Leave empty for root.",
                                "filled_by_llm": True,
                                "required": False,
                            }
                        }
                    },
                }
            )

        return actions

    def get_config_requirements(self):
        return {}


# Constants for building synthetic tools_dict entries
INTERNAL_TOOL_ID = "internal"


def build_internal_tool_entry(has_directory_structure: bool = False) -> Dict:
    """Build the tools_dict entry for InternalSearchTool.

    Dynamically includes list_files and path_filter based on
    whether the sources have directory structure.
    """
    search_params = {
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query. Be specific and focused.",
                "filled_by_llm": True,
                "required": True,
            }
        }
    }

    actions = [
        {
            "name": "search",
            "description": (
                "Search the user's uploaded documents and knowledge base. "
                "Use this to find relevant information before answering questions. "
                "You can call this multiple times with different queries."
            ),
            "active": True,
            "parameters": search_params,
        }
    ]

    if has_directory_structure:
        search_params["properties"]["path_filter"] = {
            "type": "string",
            "description": (
                "Optional: filter results to a specific file or folder path. "
                "Use list_files first to see available paths."
            ),
            "filled_by_llm": True,
            "required": False,
        }
        actions.append(
            {
                "name": "list_files",
                "description": (
                    "Browse the file and folder structure of the knowledge base. "
                    "Use this to see what files are available before searching. "
                    "Optionally provide a path to browse a specific folder."
                ),
                "active": True,
                "parameters": {
                    "properties": {
                        "path": {
                            "type": "string",
                            "description": "Optional: folder path to browse. Leave empty for root.",
                            "filled_by_llm": True,
                            "required": False,
                        }
                    }
                },
            }
        )

    return {"name": "internal_search", "actions": actions}


# Keep backward compat
INTERNAL_TOOL_ENTRY = build_internal_tool_entry(has_directory_structure=False)


def sources_have_directory_structure(source: Dict) -> bool:
    """Check if any of the active sources have a ``directory_structure`` row."""
    active_docs = source.get("active_docs", [])
    if not active_docs:
        return False

    try:
        # TODO(pg-cutover): SourcesRepository.get_any requires ``user_id``
        # scoping, but callers in the agent build path don't always
        # thread the decoded token through here. Use a direct
        # short-lived SQL lookup instead of the repo until the call
        # sites are updated to propagate user context.
        from sqlalchemy import text as _text

        from application.storage.db.session import db_readonly

        if isinstance(active_docs, str):
            active_docs = [active_docs]

        with db_readonly() as conn:
            for doc_id in active_docs:
                try:
                    value = str(doc_id)
                    if len(value) == 36 and "-" in value:
                        row = conn.execute(
                            _text(
                                "SELECT directory_structure FROM sources "
                                "WHERE id = CAST(:id AS uuid)"
                            ),
                            {"id": value},
                        ).fetchone()
                    else:
                        row = conn.execute(
                            _text(
                                "SELECT directory_structure FROM sources "
                                "WHERE legacy_mongo_id = :lid"
                            ),
                            {"lid": value},
                        ).fetchone()
                    if row is not None and row[0]:
                        return True
                except Exception:
                    continue
    except Exception as e:
        logger.debug(f"Could not check directory structure: {e}")

    return False


def add_internal_search_tool(tools_dict: Dict, retriever_config: Dict) -> None:
    """Add the internal search tool to tools_dict if sources are configured.

    Shared by AgenticAgent and ResearchAgent to avoid duplicate setup logic.
    Mutates tools_dict in place.
    """
    source = retriever_config.get("source", {})
    has_sources = bool(source.get("active_docs"))
    if not retriever_config or not has_sources:
        return

    has_dir = sources_have_directory_structure(source)
    internal_entry = build_internal_tool_entry(has_directory_structure=has_dir)
    internal_entry["config"] = build_internal_tool_config(
        **retriever_config,
        has_directory_structure=has_dir,
    )
    tools_dict[INTERNAL_TOOL_ID] = internal_entry


def build_internal_tool_config(
    source: Dict,
    retriever_name: str = "classic",
    chunks: int = 2,
    doc_token_limit: int = 50000,
    model_id: str = "docsgpt-local",
    user_api_key: Optional[str] = None,
    agent_id: Optional[str] = None,
    llm_name: str = None,
    api_key: str = None,
    decoded_token: Optional[Dict] = None,
    has_directory_structure: bool = False,
) -> Dict:
    """Build the config dict for InternalSearchTool."""
    return {
        "source": source,
        "retriever_name": retriever_name,
        "chunks": chunks,
        "doc_token_limit": doc_token_limit,
        "model_id": model_id,
        "user_api_key": user_api_key,
        "agent_id": agent_id,
        "llm_name": llm_name or settings.LLM_PROVIDER,
        "api_key": api_key or settings.API_KEY,
        "decoded_token": decoded_token,
        "has_directory_structure": has_directory_structure,
    }
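A short usage sketch of the module-level helpers defined above (the retriever_config values are illustrative, and the lookup assumes a reachable Postgres session):

tools_dict = {}
retriever_config = {
    "source": {"active_docs": ["my-source-id"]},
    "chunks": 2,
    "decoded_token": {"sub": "user-123"},
}

# Adds a tools_dict["internal"] entry; list_files and path_filter are only
# exposed when the active sources carry a directory_structure.
add_internal_search_tool(tools_dict, retriever_config)
print(tools_dict[INTERNAL_TOOL_ID]["name"])  # internal_search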
@@ -1,20 +1,12 @@
 import asyncio
 import base64
+import concurrent.futures
 import json
 import logging
 import time
 from typing import Any, Dict, List, Optional
 from urllib.parse import parse_qs, urlparse

-from application.agents.tools.base import Tool
-from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
-from application.cache import get_redis_instance
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
-from application.security.encryption import decrypt_credentials
 from fastmcp import Client
 from fastmcp.client.auth import BearerAuth
 from fastmcp.client.transports import (
@@ -24,12 +16,17 @@ from fastmcp.client.transports import (
 )
 from mcp.client.auth import OAuthClientProvider, TokenStorage
 from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
+
 from pydantic import AnyHttpUrl, ValidationError
 from redis import Redis

-mongo = MongoDB.get_client()
-db = mongo[settings.MONGO_DB_NAME]
+from application.agents.tools.base import Tool
+from application.api.user.tasks import mcp_oauth_status_task, mcp_oauth_task
+from application.cache import get_redis_instance
+from application.core.settings import settings
+from application.core.url_validation import SSRFError, validate_url
+from application.security.encryption import decrypt_credentials
+
+logger = logging.getLogger(__name__)

 _mcp_clients_cache = {}
@@ -56,11 +53,13 @@ class MCPTool(Tool):
                 - args: Arguments for STDIO transport
                 - oauth_scopes: OAuth scopes for oauth auth type
                 - oauth_client_name: OAuth client name for oauth auth type
+                - query_mode: If True, use non-interactive OAuth (fail-fast on 401)
             user_id: User ID for decrypting credentials (required if encrypted_credentials exist)
         """
         self.config = config
         self.user_id = user_id
-        self.server_url = config.get("server_url", "")
+        raw_url = config.get("server_url", "")
+        self.server_url = self._validate_server_url(raw_url) if raw_url else ""
         self.transport_type = config.get("transport_type", "auto")
         self.auth_type = config.get("auth_type", "none")
         self.timeout = config.get("timeout", 30)
@@ -76,23 +75,53 @@ class MCPTool(Tool):
         self.oauth_scopes = config.get("oauth_scopes", [])
         self.oauth_task_id = config.get("oauth_task_id", None)
         self.oauth_client_name = config.get("oauth_client_name", "DocsGPT-MCP")
-        self.redirect_uri = f"{settings.API_URL}/api/mcp_server/callback"
+        self.redirect_uri = self._resolve_redirect_uri(config.get("redirect_uri"))

         self.available_tools = []
         self._cache_key = self._generate_cache_key()
         self._client = None
+        self.query_mode = config.get("query_mode", False)
-        # Only validate and setup if server_url is provided and not OAuth
+
         if self.server_url and self.auth_type != "oauth":
             self._setup_client()
+
+    @staticmethod
+    def _validate_server_url(server_url: str) -> str:
+        """Validate server_url to prevent SSRF to internal networks.
+
+        Raises:
+            ValueError: If the URL points to a private/internal address.
+        """
+        try:
+            return validate_url(server_url)
+        except SSRFError as exc:
+            raise ValueError(f"Invalid MCP server URL: {exc}") from exc
+
+    def _resolve_redirect_uri(self, configured_redirect_uri: Optional[str]) -> str:
+        if configured_redirect_uri:
+            return configured_redirect_uri.rstrip("/")
+
+        explicit = getattr(settings, "MCP_OAUTH_REDIRECT_URI", None)
+        if explicit:
+            return explicit.rstrip("/")
+
+        connector_base = getattr(settings, "CONNECTOR_REDIRECT_BASE_URI", None)
+        if connector_base:
+            parsed = urlparse(connector_base)
+            if parsed.scheme and parsed.netloc:
+                return f"{parsed.scheme}://{parsed.netloc}/api/mcp_server/callback"
+
+        return f"{settings.API_URL.rstrip('/')}/api/mcp_server/callback"

     def _generate_cache_key(self) -> str:
         """Generate a unique cache key for this MCP server configuration."""
         auth_key = ""
         if self.auth_type == "oauth":
             scopes_str = ",".join(self.oauth_scopes) if self.oauth_scopes else "none"
-            auth_key = f"oauth:{self.oauth_client_name}:{scopes_str}"
+            oauth_identity = self.user_id or self.oauth_task_id or "anonymous"
+            auth_key = (
+                f"oauth:{oauth_identity}:{self.oauth_client_name}:{scopes_str}:{self.redirect_uri}"
+            )
         elif self.auth_type in ["bearer"]:
             token = self.auth_credentials.get(
                 "bearer_token", ""
@@ -109,11 +138,10 @@ class MCPTool(Tool):
         return f"{self.server_url}#{self.transport_type}#{auth_key}"

     def _setup_client(self):
-        """Setup FastMCP client with proper transport and authentication."""
         global _mcp_clients_cache
         if self._cache_key in _mcp_clients_cache:
             cached_data = _mcp_clients_cache[self._cache_key]
-            if time.time() - cached_data["created_at"] < 1800:
+            if time.time() - cached_data["created_at"] < 300:
                 self._client = cached_data["client"]
                 return
             else:
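The cache lifetime above drops from 1800 seconds to 300. The lookup pattern in isolation, as a sketch with an illustrative cache shape:

import time

_cache = {}
TTL_SECONDS = 300  # matches the new lifetime above


def get_cached_client(cache_key, build_client):
    entry = _cache.get(cache_key)
    if entry and time.time() - entry["created_at"] < TTL_SECONDS:
        return entry["client"]
    client = build_client()
    _cache[cache_key] = {"client": client, "created_at": time.time()}
    return client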
@@ -123,15 +151,23 @@ class MCPTool(Tool):

         if self.auth_type == "oauth":
             redis_client = get_redis_instance()
-            auth = DocsGPTOAuth(
-                mcp_url=self.server_url,
-                scopes=self.oauth_scopes,
-                redis_client=redis_client,
-                redirect_uri=self.redirect_uri,
-                task_id=self.oauth_task_id,
-                db=db,
-                user_id=self.user_id,
-            )
+            if self.query_mode:
+                auth = NonInteractiveOAuth(
+                    mcp_url=self.server_url,
+                    scopes=self.oauth_scopes,
+                    redis_client=redis_client,
+                    redirect_uri=self.redirect_uri,
+                    user_id=self.user_id,
+                )
+            else:
+                auth = DocsGPTOAuth(
+                    mcp_url=self.server_url,
+                    scopes=self.oauth_scopes,
+                    redis_client=redis_client,
+                    redirect_uri=self.redirect_uri,
+                    task_id=self.oauth_task_id,
+                    user_id=self.user_id,
+                )
         elif self.auth_type == "bearer":
             token = self.auth_credentials.get(
                 "bearer_token", ""
@@ -169,6 +205,8 @@ class MCPTool(Tool):
             transport_type = "http"
         else:
             transport_type = self.transport_type
+        if transport_type == "stdio":
+            raise ValueError("STDIO transport is disabled")
         if transport_type == "sse":
             headers.update({"Accept": "text/event-stream", "Cache-Control": "no-cache"})
             return SSETransport(url=self.server_url, headers=headers)
@@ -231,38 +269,53 @@ class MCPTool(Tool):
         else:
             raise Exception(f"Unknown operation: {operation}")

+    _ERROR_MAP = [
+        (concurrent.futures.TimeoutError, lambda op, t, _: f"Timed out after {t}s"),
+        (ConnectionRefusedError, lambda *_: "Connection refused"),
+    ]
+
+    _ERROR_PATTERNS = {
+        ("403", "Forbidden"): "Access denied (403 Forbidden)",
+        ("401", "Unauthorized"): "Authentication failed (401 Unauthorized)",
+        ("ECONNREFUSED",): "Connection refused",
+        ("SSL", "certificate"): "SSL/TLS error",
+    }
+
     def _run_async_operation(self, operation: str, *args, **kwargs):
-        """Run async operation in sync context."""
         try:
             try:
-                loop = asyncio.get_running_loop()
-                import concurrent.futures
-
-                def run_in_thread():
-                    new_loop = asyncio.new_event_loop()
-                    asyncio.set_event_loop(new_loop)
-                    try:
-                        return new_loop.run_until_complete(
-                            self._execute_with_client(operation, *args, **kwargs)
-                        )
-                    finally:
-                        new_loop.close()
-
+                asyncio.get_running_loop()
                 with concurrent.futures.ThreadPoolExecutor() as executor:
-                    future = executor.submit(run_in_thread)
+                    future = executor.submit(
+                        self._run_in_new_loop, operation, *args, **kwargs
+                    )
                     return future.result(timeout=self.timeout)
             except RuntimeError:
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-                try:
-                    return loop.run_until_complete(
-                        self._execute_with_client(operation, *args, **kwargs)
-                    )
-                finally:
-                    loop.close()
+                return self._run_in_new_loop(operation, *args, **kwargs)
         except Exception as e:
-            print(f"Error occurred while running async operation: {e}")
-            raise
+            raise self._map_error(operation, e) from e
+
+    def _run_in_new_loop(self, operation, *args, **kwargs):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(
+                self._execute_with_client(operation, *args, **kwargs)
+            )
+        finally:
+            loop.close()
+
+    def _map_error(self, operation: str, exc: Exception) -> Exception:
+        for exc_type, msg_fn in self._ERROR_MAP:
+            if isinstance(exc, exc_type):
+                return Exception(msg_fn(operation, self.timeout, exc))
+        error_msg = str(exc)
+        for patterns, friendly in self._ERROR_PATTERNS.items():
+            if any(p.lower() in error_msg.lower() for p in patterns):
+                return Exception(friendly)
+        logger.error("MCP %s failed: %s", operation, exc)
+        return exc

     def discover_tools(self) -> List[Dict]:
         """
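Both branches above now funnel into a single _run_in_new_loop helper. The underlying sync-over-async pattern, as a standalone sketch (the fetch_tools coroutine is a stand-in for the real MCP client call):

import asyncio
import concurrent.futures


async def fetch_tools():
    await asyncio.sleep(0.1)  # placeholder for the real async client work
    return ["tool_a", "tool_b"]


def run_sync(coro_factory, timeout=30):
    def in_new_loop():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro_factory())
        finally:
            loop.close()

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread, so it is safe to create one directly.
        return in_new_loop()
    # A loop is already running here; do the blocking wait in a worker thread.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        return executor.submit(in_new_loop).result(timeout=timeout)


print(run_sync(fetch_tools))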
@@ -283,16 +336,6 @@ class MCPTool(Tool):
             raise Exception(f"Failed to discover tools from MCP server: {str(e)}")

     def execute_action(self, action_name: str, **kwargs) -> Any:
-        """
-        Execute an action on the remote MCP server using FastMCP.
-
-        Args:
-            action_name: Name of the action to execute
-            **kwargs: Parameters for the action
-
-        Returns:
-            Result from the MCP server
-        """
         if not self.server_url:
             raise Exception("No MCP server configured")
         if not self._client:
@@ -308,7 +351,37 @@ class MCPTool(Tool):
             )
             return self._format_result(result)
         except Exception as e:
-            raise Exception(f"Failed to execute action '{action_name}': {str(e)}")
+            error_msg = str(e)
+            lower_msg = error_msg.lower()
+            is_auth_error = (
+                "401" in error_msg
+                or "unauthorized" in lower_msg
+                or "session expired" in lower_msg
+                or "re-authorize" in lower_msg
+            )
+            if is_auth_error:
+                if self.auth_type == "oauth":
+                    raise Exception(
+                        f"Action '{action_name}' failed: OAuth session expired. "
+                        "Please re-authorize this MCP server in tool settings."
+                    ) from e
+                global _mcp_clients_cache
+                _mcp_clients_cache.pop(self._cache_key, None)
+                self._client = None
+                self._setup_client()
+                try:
+                    result = self._run_async_operation(
+                        "call_tool", action_name, **cleaned_kwargs
+                    )
+                    return self._format_result(result)
+                except Exception as retry_e:
+                    raise Exception(
+                        f"Action '{action_name}' failed after re-auth attempt: {retry_e}. "
+                        "Your credentials may have expired — please re-authorize in tool settings."
+                    ) from retry_e
+            raise Exception(
+                f"Failed to execute action '{action_name}': {error_msg}"
+            ) from e

     def _format_result(self, result) -> Dict:
         """Format FastMCP result to match expected format."""
@@ -331,23 +404,35 @@ class MCPTool(Tool):
         return result

     def test_connection(self) -> Dict:
-        """
-        Test the connection to the MCP server and validate functionality.
-
-        Returns:
-            Dictionary with connection test results including tool count
-        """
         if not self.server_url:
             return {
                 "success": False,
-                "message": "No MCP server URL configured",
+                "message": "No server URL configured",
+                "tools_count": 0,
+            }
+        try:
+            parsed = urlparse(self.server_url)
+            if parsed.scheme not in ("http", "https"):
+                return {
+                    "success": False,
+                    "message": f"Invalid URL scheme '{parsed.scheme}' — use http:// or https://",
+                    "tools_count": 0,
+                }
+        except Exception:
+            return {
+                "success": False,
+                "message": "Invalid URL format",
                 "tools_count": 0,
-                "transport_type": self.transport_type,
-                "auth_type": self.auth_type,
-                "error_type": "ConfigurationError",
             }
         if not self._client:
-            self._setup_client()
+            try:
+                self._setup_client()
+            except Exception as e:
+                return {
+                    "success": False,
+                    "message": f"Client init failed: {str(e)}",
+                    "tools_count": 0,
+                }
         try:
             if self.auth_type == "oauth":
                 return self._test_oauth_connection()
@@ -358,56 +443,94 @@ class MCPTool(Tool):
                 "success": False,
                 "message": f"Connection failed: {str(e)}",
                 "tools_count": 0,
-                "transport_type": self.transport_type,
-                "auth_type": self.auth_type,
-                "error_type": type(e).__name__,
             }

     def _test_regular_connection(self) -> Dict:
-        """Test connection for non-OAuth auth types."""
+        ping_ok = False
+        ping_error = None
         try:
             self._run_async_operation("ping")
-            ping_success = True
-        except Exception:
-            ping_success = False
-        tools = self.discover_tools()
-
-        message = f"Successfully connected to MCP server. Found {len(tools)} tools."
-        if not ping_success:
-            message += " (Ping not supported, but tool discovery worked)"
-        return {
-            "success": True,
-            "message": message,
-            "tools_count": len(tools),
-            "transport_type": self.transport_type,
-            "auth_type": self.auth_type,
-            "ping_supported": ping_success,
-            "tools": [tool.get("name", "unknown") for tool in tools],
-        }
-
-    def _test_oauth_connection(self) -> Dict:
-        """Test connection for OAuth auth type with proper async handling."""
+            ping_ok = True
+        except Exception as e:
+            ping_error = str(e)
+
         try:
-            task = mcp_oauth_task.delay(config=self.config, user=self.user_id)
-            if not task:
-                raise Exception("Failed to start OAuth authentication")
-            return {
-                "success": True,
-                "requires_oauth": True,
-                "task_id": task.id,
-                "status": "pending",
-                "message": "OAuth flow started",
-            }
+            tools = self.discover_tools()
         except Exception as e:
             return {
                 "success": False,
-                "message": f"OAuth connection failed: {str(e)}",
+                "message": f"Connection failed: {ping_error or str(e)}",
                 "tools_count": 0,
-                "transport_type": self.transport_type,
-                "auth_type": self.auth_type,
-                "error_type": type(e).__name__,
             }
+
+        if not tools and not ping_ok:
+            return {
+                "success": False,
+                "message": f"Connection failed: {ping_error or 'No tools found'}",
+                "tools_count": 0,
+            }
+
+        return {
+            "success": True,
+            "message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
+            "tools_count": len(tools),
+            "tools": [
+                {
+                    "name": tool.get("name", "unknown"),
+                    "description": tool.get("description", ""),
+                }
+                for tool in tools
+            ],
+        }
+
+    def _test_oauth_connection(self) -> Dict:
+        storage = DBTokenStorage(
+            server_url=self.server_url, user_id=self.user_id,
+        )
+        loop = asyncio.new_event_loop()
+        try:
+            tokens = loop.run_until_complete(storage.get_tokens())
+        finally:
+            loop.close()
+
+        if tokens and tokens.access_token:
+            self.query_mode = True
+            _mcp_clients_cache.pop(self._cache_key, None)
+            self._client = None
+            self._setup_client()
+            try:
+                tools = self.discover_tools()
+                return {
+                    "success": True,
+                    "message": f"Connected — found {len(tools)} tool{'s' if len(tools) != 1 else ''}.",
+                    "tools_count": len(tools),
+                    "tools": [
+                        {
+                            "name": t.get("name", "unknown"),
+                            "description": t.get("description", ""),
+                        }
+                        for t in tools
+                    ],
+                }
+            except Exception as e:
+                logger.warning("OAuth token validation failed: %s", e)
+                _mcp_clients_cache.pop(self._cache_key, None)
+                self._client = None
+
+        return self._start_oauth_task()
+
+    def _start_oauth_task(self) -> Dict:
+        task_config = self.config.copy()
+        task_config.pop("query_mode", None)
+        result = mcp_oauth_task.delay(task_config, self.user_id)
+        return {
+            "success": False,
+            "requires_oauth": True,
+            "task_id": result.id,
+            "message": "OAuth authorization required.",
+            "tools_count": 0,
+        }

     def get_actions_metadata(self) -> List[Dict]:
         """
         Get metadata for all available actions.
@@ -453,107 +576,88 @@ class MCPTool(Tool):
         return actions

     def get_config_requirements(self) -> Dict:
-        """Get configuration requirements for the MCP tool."""
         return {
             "server_url": {
                 "type": "string",
-                "description": "URL of the remote MCP server (e.g., https://api.example.com/mcp or https://docs.mcp.cloudflare.com/sse)",
+                "label": "Server URL",
+                "description": "URL of the remote MCP server",
                 "required": True,
-            },
-            "transport_type": {
-                "type": "string",
-                "description": "Transport type for connection",
-                "enum": ["auto", "sse", "http", "stdio"],
-                "default": "auto",
-                "required": False,
-                "help": {
-                    "auto": "Automatically detect best transport",
-                    "sse": "Server-Sent Events (for real-time streaming)",
-                    "http": "HTTP streaming (recommended for production)",
-                    "stdio": "Standard I/O (for local servers)",
-                },
+                "secret": False,
+                "order": 1,
             },
             "auth_type": {
                 "type": "string",
-                "description": "Authentication type",
+                "label": "Authentication Type",
+                "description": "Authentication method for the MCP server",
                 "enum": ["none", "bearer", "oauth", "api_key", "basic"],
                 "default": "none",
                 "required": True,
-                "help": {
-                    "none": "No authentication",
-                    "bearer": "Bearer token authentication",
-                    "oauth": "OAuth 2.1 authentication (with frontend integration)",
-                    "api_key": "API key authentication",
-                    "basic": "Basic authentication",
-                },
+                "secret": False,
+                "order": 2,
             },
-            "auth_credentials": {
-                "type": "object",
-                "description": "Authentication credentials (varies by auth_type)",
+            "api_key": {
+                "type": "string",
+                "label": "API Key",
+                "description": "API key for authentication",
                 "required": False,
-                "properties": {
-                    "bearer_token": {
-                        "type": "string",
-                        "description": "Bearer token for bearer auth",
-                    },
-                    "access_token": {
-                        "type": "string",
-                        "description": "Access token for OAuth (if pre-obtained)",
-                    },
-                    "api_key": {
-                        "type": "string",
-                        "description": "API key for api_key auth",
-                    },
-                    "api_key_header": {
-                        "type": "string",
-                        "description": "Header name for API key (default: X-API-Key)",
-                    },
-                    "username": {
-                        "type": "string",
-                        "description": "Username for basic auth",
-                    },
-                    "password": {
-                        "type": "string",
-                        "description": "Password for basic auth",
-                    },
-                },
+                "secret": True,
+                "order": 3,
+                "depends_on": {"auth_type": "api_key"},
+            },
+            "api_key_header": {
+                "type": "string",
+                "label": "API Key Header",
+                "description": "Header name for API key (default: X-API-Key)",
+                "default": "X-API-Key",
+                "required": False,
+                "secret": False,
+                "order": 4,
+                "depends_on": {"auth_type": "api_key"},
+            },
+            "bearer_token": {
+                "type": "string",
+                "label": "Bearer Token",
+                "description": "Bearer token for authentication",
+                "required": False,
+                "secret": True,
+                "order": 3,
+                "depends_on": {"auth_type": "bearer"},
+            },
+            "username": {
+                "type": "string",
+                "label": "Username",
+                "description": "Username for basic authentication",
+                "required": False,
+                "secret": False,
+                "order": 3,
+                "depends_on": {"auth_type": "basic"},
+            },
+            "password": {
+                "type": "string",
+                "label": "Password",
+                "description": "Password for basic authentication",
+                "required": False,
+                "secret": True,
+                "order": 4,
+                "depends_on": {"auth_type": "basic"},
             },
             "oauth_scopes": {
-                "type": "array",
-                "description": "OAuth scopes to request (for oauth auth_type)",
-                "items": {"type": "string"},
-                "required": False,
-                "default": [],
-            },
-            "oauth_client_name": {
                 "type": "string",
-                "description": "Client name for OAuth registration (for oauth auth_type)",
-                "default": "DocsGPT-MCP",
-                "required": False,
-            },
-            "headers": {
-                "type": "object",
-                "description": "Custom headers to send with requests",
+                "label": "OAuth Scopes",
+                "description": "Comma-separated OAuth scopes to request",
                 "required": False,
+                "secret": False,
+                "order": 3,
+                "depends_on": {"auth_type": "oauth"},
             },
             "timeout": {
-                "type": "integer",
-                "description": "Request timeout in seconds",
+                "type": "number",
+                "label": "Timeout (seconds)",
+                "description": "Request timeout in seconds (1-300)",
                 "default": 30,
-                "minimum": 1,
-                "maximum": 300,
-                "required": False,
-            },
-            "command": {
-                "type": "string",
-                "description": "Command to run for STDIO transport (e.g., 'python')",
-                "required": False,
-            },
-            "args": {
-                "type": "array",
-                "description": "Arguments for STDIO command",
-                "items": {"type": "string"},
                 "required": False,
+                "secret": False,
+                "order": 10,
             },
         }

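Note: the reworked get_config_requirements above flattens the old nested auth_credentials object into per-field entries and adds presentation hints (label, secret, order, depends_on). A small illustrative consumer of that shape (the renderer below is an assumption; only the field keys come from the hunk):

    def visible_fields(requirements: dict, current_values: dict) -> list:
        """Return (name, spec) pairs to show, ordered by 'order' and filtered by 'depends_on'."""
        fields = []
        for name, spec in requirements.items():
            depends_on = spec.get("depends_on", {})
            # Hide e.g. the bearer_token field unless auth_type == "bearer".
            if all(current_values.get(k) == v for k, v in depends_on.items()):
                fields.append((name, spec))
        return sorted(fields, key=lambda item: item[1].get("order", 99))

    # Example: with auth_type set to "api_key", only the api_key and api_key_header
    # entries (plus the always-visible fields) survive the depends_on filter.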
@@ -573,31 +677,14 @@ class DocsGPTOAuth(OAuthClientProvider):
         scopes: str | list[str] | None = None,
         client_name: str = "DocsGPT-MCP",
         user_id=None,
-        db=None,
         additional_client_metadata: dict[str, Any] | None = None,
+        skip_redirect_validation: bool = False,
     ):
-        """
-        Initialize custom OAuth client provider for DocsGPT.
-
-        Args:
-            mcp_url: Full URL to the MCP endpoint
-            redirect_uri: Custom redirect URI for DocsGPT frontend
-            redis_client: Redis client for storing auth state
-            redis_prefix: Prefix for Redis keys
-            task_id: Task ID for tracking auth status
-            scopes: OAuth scopes to request
-            client_name: Name for this client during registration
-            user_id: User ID for token storage
-            db: Database instance for token storage
-            additional_client_metadata: Extra fields for OAuthClientMetadata
-        """
         self.redirect_uri = redirect_uri
         self.redis_client = redis_client
         self.redis_prefix = redis_prefix
         self.task_id = task_id
         self.user_id = user_id
-        self.db = db

         parsed_url = urlparse(mcp_url)
         self.server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
@@ -614,7 +701,9 @@ class DocsGPTOAuth(OAuthClientProvider):
         )

         storage = DBTokenStorage(
-            server_url=self.server_base_url, user_id=self.user_id, db_client=self.db
+            server_url=self.server_base_url,
+            user_id=self.user_id,
+            expected_redirect_uri=None if skip_redirect_validation else redirect_uri,
         )

         super().__init__(
@@ -646,22 +735,20 @@ class DocsGPTOAuth(OAuthClientProvider):
     async def redirect_handler(self, authorization_url: str) -> None:
         """Store auth URL and state in Redis for frontend to use."""
         auth_url, state = self._process_auth_url(authorization_url)
-        logging.info(
-            "[DocsGPTOAuth] Processed auth_url: %s, state: %s", auth_url, state
-        )
+        logger.info("Processed auth_url: %s, state: %s", auth_url, state)
         self.auth_url = auth_url
         self.extracted_state = state

         if self.redis_client and self.extracted_state:
             key = f"{self.redis_prefix}auth_url:{self.extracted_state}"
             self.redis_client.setex(key, 600, auth_url)
-            logging.info("[DocsGPTOAuth] Stored auth_url in Redis: %s", key)
+            logger.info("Stored auth_url in Redis: %s", key)

         if self.task_id:
             status_key = f"mcp_oauth_status:{self.task_id}"
             status_data = {
                 "status": "requires_redirect",
-                "message": "OAuth authorization required",
+                "message": "Authorization required",
                 "authorization_url": self.auth_url,
                 "state": self.extracted_state,
                 "requires_oauth": True,
@@ -681,7 +768,7 @@ class DocsGPTOAuth(OAuthClientProvider):
             status_key = f"mcp_oauth_status:{self.task_id}"
             status_data = {
                 "status": "awaiting_callback",
-                "message": "Waiting for OAuth callback...",
+                "message": "Waiting for authorization...",
                 "authorization_url": self.auth_url,
                 "state": self.extracted_state,
                 "requires_oauth": True,
@@ -706,7 +793,7 @@ class DocsGPTOAuth(OAuthClientProvider):
         if self.task_id:
             status_data = {
                 "status": "callback_received",
-                "message": "OAuth callback received, completing authentication...",
+                "message": "Completing authentication...",
                 "task_id": self.task_id,
             }
             self.redis_client.setex(status_key, 600, json.dumps(status_data))
@@ -726,67 +813,149 @@ class DocsGPTOAuth(OAuthClientProvider):
         await asyncio.sleep(poll_interval)
         self.redis_client.delete(f"{self.redis_prefix}auth_url:{self.extracted_state}")
         self.redis_client.delete(f"{self.redis_prefix}state:{self.extracted_state}")
-        raise Exception("OAuth callback timeout: no code received within 5 minutes")
+        raise Exception("OAuth timeout: no code received within 5 minutes")


+class NonInteractiveOAuth(DocsGPTOAuth):
+    """OAuth provider that fails fast on 401 instead of starting interactive auth.
+
+    Used during query execution to prevent the streaming response from blocking
+    while waiting for user authorization that will never come.
+    """
+
+    def __init__(self, **kwargs):
+        kwargs.setdefault("task_id", None)
+        kwargs["skip_redirect_validation"] = True
+        super().__init__(**kwargs)
+
+    async def redirect_handler(self, authorization_url: str) -> None:
+        raise Exception(
+            "OAuth session expired — please re-authorize this MCP server in tool settings."
+        )
+
+    async def callback_handler(self) -> tuple[str, str | None]:
+        raise Exception(
+            "OAuth session expired — please re-authorize this MCP server in tool settings."
+        )


 class DBTokenStorage(TokenStorage):
-    def __init__(self, server_url: str, user_id: str, db_client):
+    def __init__(
+        self,
+        server_url: str,
+        user_id: str,
+        expected_redirect_uri: Optional[str] = None,
+    ):
         self.server_url = server_url
         self.user_id = user_id
-        self.db_client = db_client
-        self.collection = db_client["connector_sessions"]
+        self.expected_redirect_uri = expected_redirect_uri

     @staticmethod
     def get_base_url(url: str) -> str:
         parsed = urlparse(url)
         return f"{parsed.scheme}://{parsed.netloc}"

-    def get_db_key(self) -> dict:
-        return {
-            "server_url": self.get_base_url(self.server_url),
-            "user_id": self.user_id,
-        }
+    def _pg_provider(self) -> str:
+        return f"mcp:{self.get_base_url(self.server_url)}"
+
+    def _fetch_session_data(self) -> dict:
+        """Read the JSONB ``session_data`` blob for this MCP server row."""
+        from application.storage.db.repositories.connector_sessions import (
+            ConnectorSessionsRepository,
+        )
+        from application.storage.db.session import db_readonly
+
+        base_url = self.get_base_url(self.server_url)
+        with db_readonly() as conn:
+            row = ConnectorSessionsRepository(conn).get_by_user_and_server_url(
+                self.user_id, base_url,
+            )
+        if not row:
+            return {}
+        data = row.get("session_data") or {}
+        if isinstance(data, str):
+            try:
+                data = json.loads(data)
+            except ValueError:
+                return {}
+        return data if isinstance(data, dict) else {}

     async def get_tokens(self) -> OAuthToken | None:
-        doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
-        if not doc or "tokens" not in doc:
+        data = await asyncio.to_thread(self._fetch_session_data)
+        if not data or "tokens" not in data:
             return None
         try:
-            tokens = OAuthToken.model_validate(doc["tokens"])
-            return tokens
+            return OAuthToken.model_validate(data["tokens"])
         except ValidationError as e:
-            logging.error(f"Could not load tokens: {e}")
+            logger.error("Could not load tokens: %s", e)
             return None

+    def _merge(self, patch: dict) -> None:
+        """Shallow-merge ``patch`` into this row's ``session_data``.
+
+        Threads ``server_url`` through to the repository so it lands in
+        the scalar column — ``get_by_user_and_server_url`` needs that to
+        resolve the row (``NULL = 'https://...'`` is UNKNOWN in SQL).
+        """
+        from application.storage.db.repositories.connector_sessions import (
+            ConnectorSessionsRepository,
+        )
+        from application.storage.db.session import db_session
+
+        base_url = self.get_base_url(self.server_url)
+        with db_session() as conn:
+            ConnectorSessionsRepository(conn).merge_session_data(
+                self.user_id, self._pg_provider(), base_url, patch,
+            )
+
+    def _delete(self) -> None:
+        from application.storage.db.repositories.connector_sessions import (
+            ConnectorSessionsRepository,
+        )
+        from application.storage.db.session import db_session
+
+        with db_session() as conn:
+            ConnectorSessionsRepository(conn).delete(
+                self.user_id, self._pg_provider(),
+            )
+
     async def set_tokens(self, tokens: OAuthToken) -> None:
-        await asyncio.to_thread(
-            self.collection.update_one,
-            self.get_db_key(),
-            {"$set": {"tokens": tokens.model_dump()}},
-            True,
-        )
-        logging.info(f"Saved tokens for {self.get_base_url(self.server_url)}")
+        base_url = self.get_base_url(self.server_url)
+        token_dump = tokens.model_dump()
+        await asyncio.to_thread(self._merge, {"tokens": token_dump})
+        logger.info("Saved tokens for %s", base_url)

     async def get_client_info(self) -> OAuthClientInformationFull | None:
-        doc = await asyncio.to_thread(self.collection.find_one, self.get_db_key())
-        if not doc or "client_info" not in doc:
+        data = await asyncio.to_thread(self._fetch_session_data)
+        base_url = self.get_base_url(self.server_url)
+        if not data or "client_info" not in data:
+            logger.debug("No client_info in DB for %s", base_url)
             return None
         try:
-            client_info = OAuthClientInformationFull.model_validate(doc["client_info"])
-            tokens = await self.get_tokens()
-            if tokens is None:
-                logging.debug(
-                    "No tokens found, clearing client info to force fresh registration."
-                )
-                await asyncio.to_thread(
-                    self.collection.update_one,
-                    self.get_db_key(),
-                    {"$unset": {"client_info": ""}},
-                )
-                return None
+            client_info = OAuthClientInformationFull.model_validate(data["client_info"])
+            if self.expected_redirect_uri:
+                stored_uris = [
+                    str(uri).rstrip("/") for uri in client_info.redirect_uris
+                ]
+                expected_uri = self.expected_redirect_uri.rstrip("/")
+                if expected_uri not in stored_uris:
+                    logger.warning(
+                        "Redirect URI mismatch for %s: expected=%s stored=%s — clearing.",
+                        base_url,
+                        expected_uri,
+                        stored_uris,
+                    )
+                    # Drop ``tokens`` and ``client_info`` from the JSONB
+                    # blob via merge_session_data's ``None``-drops-key
+                    # semantics — preserves the row + any other keys.
+                    await asyncio.to_thread(
+                        self._merge,
+                        {"tokens": None, "client_info": None},
+                    )
+                    return None
             return client_info
         except ValidationError as e:
-            logging.error(f"Could not load client info: {e}")
+            logger.error("Could not load client info: %s", e)
             return None

     def _serialize_client_info(self, info: dict) -> dict:
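Note: the _merge helper above depends on ConnectorSessionsRepository.merge_session_data and on the "a None value drops the key" behaviour its comments describe. A plain-Python sketch of that merge rule, to make the contract concrete (an assumption about the repository, not its actual code):

    def merge_session_blob(existing: dict, patch: dict) -> dict:
        # Shallow merge: patch keys overwrite, and a None value removes the key,
        # which is how get_client_info clears tokens/client_info without
        # deleting the whole row.
        merged = dict(existing)
        for key, value in patch.items():
            if value is None:
                merged.pop(key, None)
            else:
                merged[key] = value
        return merged

    # merge_session_blob({"tokens": {...}, "client_info": {...}}, {"tokens": None})
    # -> {"client_info": {...}}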
@@ -796,23 +965,38 @@ class DBTokenStorage(TokenStorage):

     async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
         serialized_info = self._serialize_client_info(client_info.model_dump())
+        base_url = self.get_base_url(self.server_url)
         await asyncio.to_thread(
-            self.collection.update_one,
-            self.get_db_key(),
-            {"$set": {"client_info": serialized_info}},
-            True,
+            self._merge, {"client_info": serialized_info},
         )
-        logging.info(f"Saved client info for {self.get_base_url(self.server_url)}")
+        logger.info("Saved client info for %s", base_url)

     async def clear(self) -> None:
-        await asyncio.to_thread(self.collection.delete_one, self.get_db_key())
-        logging.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
+        await asyncio.to_thread(self._delete)
+        logger.info("Cleared OAuth cache for %s", self.get_base_url(self.server_url))

     @classmethod
-    async def clear_all(cls, db_client) -> None:
-        collection = db_client["connector_sessions"]
-        await asyncio.to_thread(collection.delete_many, {})
-        logging.info("Cleared all OAuth client cache data.")
+    async def clear_all(cls, db_client=None) -> None:
+        """Delete every MCP-tagged connector session row.
+
+        ``db_client`` retained for call-site compatibility but unused —
+        storage is Postgres-only now.
+        """
+        from sqlalchemy import text
+
+        from application.storage.db.session import db_session
+
+        def _delete_all() -> None:
+            with db_session() as conn:
+                conn.execute(
+                    text(
+                        "DELETE FROM connector_sessions "
+                        "WHERE provider LIKE 'mcp:%'"
+                    )
+                )
+
+        await asyncio.to_thread(_delete_all)
+        logger.info("Cleared all OAuth client cache data.")


 class MCPOAuthManager:
@@ -851,7 +1035,7 @@ class MCPOAuthManager:

             return True
         except Exception as e:
-            logging.error(f"Error handling OAuth callback: {e}")
+            logger.error("Error handling OAuth callback: %s", e)
             return False

     def get_oauth_status(self, task_id: str) -> Dict[str, Any]:
@@ -1,12 +1,14 @@
-from datetime import datetime
 from pathlib import Path
 from typing import Any, Dict, List, Optional
-import re
+import logging
 import uuid

 from .base import Tool
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
+from application.storage.db.repositories.memories import MemoriesRepository
+from application.storage.db.session import db_readonly, db_session


+logger = logging.getLogger(__name__)


 class MemoryTool(Tool):
@@ -27,7 +29,7 @@ class MemoryTool(Tool):
         self.user_id: Optional[str] = user_id

         # Get tool_id from configuration (passed from user_tools._id in production)
-        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
+        # In production, tool_id is the UUID string from user_tools.id.
         if tool_config and "tool_id" in tool_config:
             self.tool_id = tool_config["tool_id"]
         elif user_id:
@@ -37,8 +39,35 @@ class MemoryTool(Tool):
             # Last resort fallback (shouldn't happen in normal use)
             self.tool_id = str(uuid.uuid4())

-        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
-        self.collection = db["memories"]
+    def _pg_enabled(self) -> bool:
+        """Return True if this MemoryTool's tool_id is a real ``user_tools.id``.
+
+        The ``memories`` PG table has a UUID foreign key to ``user_tools``.
+        The sentinel ``default_{uid}`` fallback tool_id is not a UUID and
+        has no row in ``user_tools``, so any storage operation would fail
+        the foreign-key check. After the Postgres cutover Postgres is the
+        only store, so for the sentinel case there is nowhere to read or
+        write — operations become no-ops and the tool returns an
+        explanatory error to the caller.
+        """
+        tool_id = getattr(self, "tool_id", None)
+        if not tool_id or not isinstance(tool_id, str):
+            return False
+        if tool_id.startswith("default_"):
+            logger.debug(
+                "Skipping Postgres operation for MemoryTool with sentinel tool_id=%s",
+                tool_id,
+            )
+            return False
+        from application.storage.db.base_repository import looks_like_uuid
+
+        if not looks_like_uuid(tool_id):
+            logger.debug(
+                "Skipping Postgres operation for MemoryTool with non-UUID tool_id=%s",
+                tool_id,
+            )
+            return False
+        return True

     # -----------------------------
     # Action implementations
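Note: the guard above leans on looks_like_uuid from application.storage.db.base_repository, whose implementation is not shown in this diff. A minimal sketch of the kind of UUID check it implies (assumed, not the project's actual code):

    import uuid

    def looks_like_uuid(value: str) -> bool:
        # Accept anything the stdlib can parse as a UUID; reject sentinels
        # such as "default_<uid>" that would fail a uuid-typed FK column.
        try:
            uuid.UUID(str(value))
            return True
        except (ValueError, AttributeError, TypeError):
            return False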
@@ -56,6 +85,12 @@ class MemoryTool(Tool):
         if not self.user_id:
             return "Error: MemoryTool requires a valid user_id."

+        if not self._pg_enabled():
+            return (
+                "Error: MemoryTool is not configured with a persistent tool_id; "
+                "memory storage is unavailable for this session."
+            )
+
         if action_name == "view":
             return self._view(
                 kwargs.get("path", "/"),
@@ -282,14 +317,10 @@ class MemoryTool(Tool):
         # Ensure path ends with / for proper prefix matching
         search_path = path if path.endswith("/") else path + "/"

-        # Find all files that start with this directory path
-        query = {
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": {"$regex": f"^{re.escape(search_path)}"}
-        }
-
-        docs = list(self.collection.find(query, {"path": 1}))
+        with db_readonly() as conn:
+            docs = MemoriesRepository(conn).list_by_prefix(
+                self.user_id, self.tool_id, search_path
+            )

         if not docs:
             return f"Directory: {path}\n(empty)"
@@ -310,7 +341,10 @@ class MemoryTool(Tool):

     def _view_file(self, path: str, view_range: Optional[List[int]] = None) -> str:
         """View file contents with optional line range."""
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": path})
+        with db_readonly() as conn:
+            doc = MemoriesRepository(conn).get_by_path(
+                self.user_id, self.tool_id, path
+            )
+
         if not doc or not doc.get("content"):
             return f"Error: File not found: {path}"
@@ -344,16 +378,10 @@ class MemoryTool(Tool):
         if validated_path == "/" or validated_path.endswith("/"):
             return "Error: Cannot create a file at directory path."

-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
-            {
-                "$set": {
-                    "content": file_text,
-                    "updated_at": datetime.now()
-                }
-            },
-            upsert=True
-        )
+        with db_session() as conn:
+            MemoriesRepository(conn).upsert(
+                self.user_id, self.tool_id, validated_path, file_text
+            )

         return f"File created: {validated_path}"

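Note: MemoriesRepository.upsert is called throughout these hunks but not defined in this diff. As an illustration only, an upsert of this shape is commonly written with INSERT ... ON CONFLICT, assuming a unique constraint on (user_id, tool_id, path); column names below are assumptions:

    from sqlalchemy import text

    def upsert_memory(conn, user_id: str, tool_id: str, path: str, content: str) -> None:
        # Hypothetical sketch; the real repository may differ.
        conn.execute(
            text(
                "INSERT INTO memories (user_id, tool_id, path, content, updated_at) "
                "VALUES (:user_id, CAST(:tool_id AS uuid), :path, :content, now()) "
                "ON CONFLICT (user_id, tool_id, path) "
                "DO UPDATE SET content = EXCLUDED.content, updated_at = now()"
            ),
            {"user_id": user_id, "tool_id": tool_id, "path": path, "content": content},
        )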
@@ -366,30 +394,29 @@ class MemoryTool(Tool):
         if not old_str:
             return "Error: old_str is required."

-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
+
         if not doc or not doc.get("content"):
             return f"Error: File not found: {validated_path}"

         current_content = str(doc["content"])

         # Check if old_str exists (case-insensitive)
         if old_str.lower() not in current_content.lower():
             return f"Error: String '{old_str}' not found in file."

-        # Replace the string (case-insensitive)
+        # Case-insensitive replace
         import re as regex_module
-        updated_content = regex_module.sub(regex_module.escape(old_str), new_str, current_content, flags=regex_module.IGNORECASE)
+        updated_content = regex_module.sub(
+            regex_module.escape(old_str),
+            new_str,
+            current_content,
+            flags=regex_module.IGNORECASE,
+        )

-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
-            {
-                "$set": {
-                    "content": updated_content,
-                    "updated_at": datetime.now()
-                }
-            }
-        )
+        repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)

         return f"File updated: {validated_path}"

@@ -402,31 +429,25 @@ class MemoryTool(Tool):
         if not insert_text:
             return "Error: insert_text is required."

-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path})
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            doc = repo.get_by_path(self.user_id, self.tool_id, validated_path)
+
         if not doc or not doc.get("content"):
             return f"Error: File not found: {validated_path}"

         current_content = str(doc["content"])
         lines = current_content.split("\n")

         # Convert to 0-indexed
         index = insert_line - 1
         if index < 0 or index > len(lines):
             return f"Error: Invalid line number. File has {len(lines)} lines."

         lines.insert(index, insert_text)
         updated_content = "\n".join(lines)

-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_path},
-            {
-                "$set": {
-                    "content": updated_content,
-                    "updated_at": datetime.now()
-                }
-            }
-        )
+        repo.upsert(self.user_id, self.tool_id, validated_path, updated_content)

         return f"Text inserted at line {insert_line} in {validated_path}"

@@ -438,39 +459,36 @@ class MemoryTool(Tool):

         if validated_path == "/":
             # Delete all files for this user and tool
-            result = self.collection.delete_many({"user_id": self.user_id, "tool_id": self.tool_id})
-            return f"Deleted {result.deleted_count} file(s) from memory."
+            with db_session() as conn:
+                deleted = MemoriesRepository(conn).delete_all(
+                    self.user_id, self.tool_id
+                )
+            return f"Deleted {deleted} file(s) from memory."

         # Check if it's a directory (ends with /)
         if validated_path.endswith("/"):
-            # Delete all files in directory
-            result = self.collection.delete_many({
-                "user_id": self.user_id,
-                "tool_id": self.tool_id,
-                "path": {"$regex": f"^{re.escape(validated_path)}"}
-            })
-            return f"Deleted directory and {result.deleted_count} file(s)."
+            with db_session() as conn:
+                deleted = MemoriesRepository(conn).delete_by_prefix(
+                    self.user_id, self.tool_id, validated_path
+                )
+            return f"Deleted directory and {deleted} file(s)."

-        # Try to delete as directory first (without trailing slash)
-        # Check if any files start with this path + /
+        # Try as directory first (without trailing slash)
         search_path = validated_path + "/"
-        directory_result = self.collection.delete_many({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": {"$regex": f"^{re.escape(search_path)}"}
-        })
-
-        if directory_result.deleted_count > 0:
-            return f"Deleted directory and {directory_result.deleted_count} file(s)."
-
-        # Delete single file
-        result = self.collection.delete_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_path
-        })
-
-        if result.deleted_count:
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            directory_deleted = repo.delete_by_prefix(
+                self.user_id, self.tool_id, search_path
+            )
+            if directory_deleted > 0:
+                return f"Deleted directory and {directory_deleted} file(s)."
+
+            # Otherwise delete a single file
+            file_deleted = repo.delete_by_path(
+                self.user_id, self.tool_id, validated_path
+            )
+
+        if file_deleted:
             return f"Deleted: {validated_path}"
         return f"Error: File not found: {validated_path}"

@@ -485,62 +503,46 @@ class MemoryTool(Tool):
         if validated_old == "/" or validated_new == "/":
             return "Error: Cannot rename root directory."

-        # Check if renaming a directory
+        # Directory rename: do all path updates inside one transaction so
+        # the rename is atomic from the caller's perspective.
         if validated_old.endswith("/"):
             # Ensure validated_new also ends with / for proper path replacement
             if not validated_new.endswith("/"):
                 validated_new = validated_new + "/"

-            # Find all files in the old directory
-            docs = list(self.collection.find({
-                "user_id": self.user_id,
-                "tool_id": self.tool_id,
-                "path": {"$regex": f"^{re.escape(validated_old)}"}
-            }))
-
-            if not docs:
-                return f"Error: Directory not found: {validated_old}"
-
-            # Update paths for all files
-            for doc in docs:
-                old_file_path = doc["path"]
-                new_file_path = old_file_path.replace(validated_old, validated_new, 1)
-
-                self.collection.update_one(
-                    {"_id": doc["_id"]},
-                    {"$set": {"path": new_file_path, "updated_at": datetime.now()}}
-                )
+            with db_session() as conn:
+                repo = MemoriesRepository(conn)
+                docs = repo.list_by_prefix(
+                    self.user_id, self.tool_id, validated_old
+                )
+
+                if not docs:
+                    return f"Error: Directory not found: {validated_old}"
+
+                for doc in docs:
+                    old_file_path = doc["path"]
+                    new_file_path = old_file_path.replace(
+                        validated_old, validated_new, 1
+                    )
+                    repo.update_path(
+                        self.user_id, self.tool_id, old_file_path, new_file_path
+                    )

             return f"Renamed directory: {validated_old} -> {validated_new} ({len(docs)} files)"

-        # Rename single file
-        doc = self.collection.find_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_old
-        })
-
-        if not doc:
-            return f"Error: File not found: {validated_old}"
-
-        # Check if new path already exists
-        existing = self.collection.find_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_new
-        })
-
-        if existing:
-            return f"Error: File already exists at {validated_new}"
-
-        # Delete the old document and create a new one with the new path
-        self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id, "path": validated_old})
-        self.collection.insert_one({
-            "user_id": self.user_id,
-            "tool_id": self.tool_id,
-            "path": validated_new,
-            "content": doc.get("content", ""),
-            "updated_at": datetime.now()
-        })
+        # Single-file rename: lookup, collision check, and update in one txn.
+        with db_session() as conn:
+            repo = MemoriesRepository(conn)
+            doc = repo.get_by_path(self.user_id, self.tool_id, validated_old)
+            if not doc:
+                return f"Error: File not found: {validated_old}"
+
+            existing = repo.get_by_path(self.user_id, self.tool_id, validated_new)
+            if existing:
+                return f"Error: File already exists at {validated_new}"
+
+            repo.update_path(
+                self.user_id, self.tool_id, validated_old, validated_new
+            )

         return f"Renamed: {validated_old} -> {validated_new}"
@@ -1,10 +1,16 @@
-from datetime import datetime
 from typing import Any, Dict, List, Optional
 import uuid

 from .base import Tool
-from application.core.mongo_db import MongoDB
-from application.core.settings import settings
+from application.storage.db.repositories.notes import NotesRepository
+from application.storage.db.session import db_readonly, db_session


+# Stable synthetic title used in the Postgres ``notes.title`` column.
+# The notes tool stores one note per (user_id, tool_id); there is no
+# user-facing title. PG requires ``title`` NOT NULL, so we write a stable
+# constant alongside the actual note body in ``content``.
+_NOTE_TITLE = "note"


 class NotesTool(Tool):
@@ -25,7 +31,6 @@ class NotesTool(Tool):
         self.user_id: Optional[str] = user_id

         # Get tool_id from configuration (passed from user_tools._id in production)
-        # In production, tool_id is the MongoDB ObjectId string from user_tools collection
         if tool_config and "tool_id" in tool_config:
             self.tool_id = tool_config["tool_id"]
         elif user_id:
@@ -35,8 +40,24 @@ class NotesTool(Tool):
             # Last resort fallback (shouldn't happen in normal use)
             self.tool_id = str(uuid.uuid4())

-        db = MongoDB.get_client()[settings.MONGO_DB_NAME]
-        self.collection = db["notes"]
+        self._last_artifact_id: Optional[str] = None
+
+    def _pg_enabled(self) -> bool:
+        """Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
+
+        ``notes.tool_id`` is a UUID FK to ``user_tools``; repo queries
+        ``CAST(:tool_id AS uuid)``. The sentinel ``default_{uid}``
+        fallback is neither a UUID nor a ``user_tools`` row, so any DB
+        operation would crash. Mirror MemoryTool's guard and no-op.
+        """
+        tool_id = getattr(self, "tool_id", None)
+        if not tool_id or not isinstance(tool_id, str):
+            return False
+        if tool_id.startswith("default_"):
+            return False
+        from application.storage.db.base_repository import looks_like_uuid
+
+        return looks_like_uuid(tool_id)

     # -----------------------------
     # Action implementations
@@ -52,7 +73,15 @@ class NotesTool(Tool):
             A human-readable string result.
         """
         if not self.user_id:
             return "Error: NotesTool requires a valid user_id."

+        if not self._pg_enabled():
+            return (
+                "Error: NotesTool is not configured with a persistent "
+                "tool_id; note storage is unavailable for this session."
+            )
+
+        self._last_artifact_id = None
+
         if action_name == "view":
             return self._get_note()
@@ -125,35 +154,51 @@ class NotesTool(Tool):
         """Return configuration requirements (none for now)."""
         return {}

+    def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
+        return self._last_artifact_id
+
     # -----------------------------
     # Internal helpers (single-note)
     # -----------------------------
+    def _fetch_note(self) -> Optional[dict]:
+        """Read the note row for this (user, tool) from Postgres."""
+        with db_readonly() as conn:
+            return NotesRepository(conn).get_for_user_tool(self.user_id, self.tool_id)
+
     def _get_note(self) -> str:
-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
-        if not doc or not doc.get("note"):
+        doc = self._fetch_note()
+        # ``content`` is the PG column; expose as ``note`` to callers via the
+        # textual return value. Frontends that read the artifact via the
+        # repo dict get ``content`` (PG-native) plus the artifact id below.
+        body = (doc or {}).get("content")
+        if not doc or not body:
             return "No note found."
-        return str(doc["note"])
+        if doc.get("id") is not None:
+            self._last_artifact_id = str(doc.get("id"))
+        return str(body)

     def _overwrite_note(self, content: str) -> str:
         content = (content or "").strip()
         if not content:
             return "Note content required."
-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id},
-            {"$set": {"note": content, "updated_at": datetime.utcnow()}},
-            upsert=True,  # ✅ create if missing
-        )
+        with db_session() as conn:
+            row = NotesRepository(conn).upsert(
+                self.user_id, self.tool_id, _NOTE_TITLE, content
+            )
+        if row and row.get("id") is not None:
+            self._last_artifact_id = str(row.get("id"))
         return "Note saved."

     def _str_replace(self, old_str: str, new_str: str) -> str:
         if not old_str:
             return "old_str is required."

-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
-        if not doc or not doc.get("note"):
+        doc = self._fetch_note()
+        existing = (doc or {}).get("content")
+        if not doc or not existing:
             return "No note found."

-        current_note = str(doc["note"])
+        current_note = str(existing)

         # Case-insensitive search
         if old_str.lower() not in current_note.lower():
@@ -163,21 +208,24 @@ class NotesTool(Tool):
         import re
         updated_note = re.sub(re.escape(old_str), new_str, current_note, flags=re.IGNORECASE)

-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id},
-            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
-        )
+        with db_session() as conn:
+            row = NotesRepository(conn).upsert(
+                self.user_id, self.tool_id, _NOTE_TITLE, updated_note
+            )
+        if row and row.get("id") is not None:
+            self._last_artifact_id = str(row.get("id"))
         return "Note updated."

     def _insert(self, line_number: int, text: str) -> str:
         if not text:
             return "Text is required."

-        doc = self.collection.find_one({"user_id": self.user_id, "tool_id": self.tool_id})
-        if not doc or not doc.get("note"):
+        doc = self._fetch_note()
+        existing = (doc or {}).get("content")
+        if not doc or not existing:
             return "No note found."

-        current_note = str(doc["note"])
+        current_note = str(existing)
         lines = current_note.split("\n")

         # Convert to 0-indexed and validate
@@ -188,12 +236,23 @@ class NotesTool(Tool):
         lines.insert(index, text)
         updated_note = "\n".join(lines)

-        self.collection.update_one(
-            {"user_id": self.user_id, "tool_id": self.tool_id},
-            {"$set": {"note": updated_note, "updated_at": datetime.utcnow()}},
-        )
+        with db_session() as conn:
+            row = NotesRepository(conn).upsert(
+                self.user_id, self.tool_id, _NOTE_TITLE, updated_note
+            )
+        if row and row.get("id") is not None:
+            self._last_artifact_id = str(row.get("id"))
         return "Text inserted."

     def _delete_note(self) -> str:
-        res = self.collection.delete_one({"user_id": self.user_id, "tool_id": self.tool_id})
-        return "Note deleted." if res.deleted_count else "No note found to delete."
+        # Capture the id (for artifact tracking) before deleting.
+        existing = self._fetch_note()
+        if not existing:
+            return "No note found to delete."
+        with db_session() as conn:
+            deleted = NotesRepository(conn).delete(self.user_id, self.tool_id)
+        if not deleted:
+            return "No note found to delete."
+        if existing.get("id") is not None:
+            self._last_artifact_id = str(existing.get("id"))
+        return "Note deleted."
@@ -71,7 +71,7 @@ class NtfyTool(Tool):
         if self.token:
             headers["Authorization"] = f"Basic {self.token}"
         data = message.encode("utf-8")
-        response = requests.post(url, headers=headers, data=data)
+        response = requests.post(url, headers=headers, data=data, timeout=100)
         return {"status_code": response.status_code, "message": "Message sent"}

     def get_actions_metadata(self):
@@ -116,12 +116,13 @@ class NtfyTool(Tool):
         ]

     def get_config_requirements(self):
-        """
-        Specify the configuration requirements.
-
-        Returns:
-            dict: Dictionary describing required config parameters.
-        """
         return {
-            "token": {"type": "string", "description": "Access token for authentication"},
+            "token": {
+                "type": "string",
+                "label": "Access Token",
+                "description": "Ntfy access token for authentication",
+                "required": True,
+                "secret": True,
+                "order": 1,
+            },
         }
@@ -1,6 +1,12 @@
-import psycopg2
+import logging
+
+import psycopg
+
 from application.agents.tools.base import Tool

+logger = logging.getLogger(__name__)
+

 class PostgresTool(Tool):
     """
     PostgreSQL Database Tool
@@ -17,25 +23,25 @@ class PostgresTool(Tool):
             "postgres_execute_sql": self._execute_sql,
             "postgres_get_schema": self._get_schema,
         }
-        if action_name in actions:
-            return actions[action_name](**kwargs)
-        else:
+        if action_name not in actions:
             raise ValueError(f"Unknown action: {action_name}")
+        return actions[action_name](**kwargs)

     def _execute_sql(self, sql_query):
         """
         Executes an SQL query against the PostgreSQL database using a connection string.
         """
-        conn = None  # Initialize conn to None for error handling
+        conn = None
         try:
-            conn = psycopg2.connect(self.connection_string)
+            conn = psycopg.connect(self.connection_string)
             cur = conn.cursor()
             cur.execute(sql_query)
             conn.commit()

             if sql_query.strip().lower().startswith("select"):
-                column_names = [desc[0] for desc in cur.description] if cur.description else []
+                column_names = (
+                    [desc[0] for desc in cur.description] if cur.description else []
+                )
                 results = []
                 rows = cur.fetchall()
                 for row in rows:
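Note: the hunk above swaps psycopg2 for psycopg (v3); connect()/cursor()/execute() keep the same shape, so the surrounding code barely changes. A minimal, self-contained sketch of the same pattern using psycopg 3's context managers (illustrative only, not the project's code):

    import psycopg

    def run_select(connection_string: str, sql: str) -> list:
        # The "with" blocks close the cursor and connection and commit
        # (or roll back on error) automatically.
        with psycopg.connect(connection_string) as conn:
            with conn.cursor() as cur:
                cur.execute(sql)
                return cur.fetchall()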
@@ -43,7 +49,9 @@ class PostgresTool(Tool):
                 response_data = {"data": results, "column_names": column_names}
             else:
                 row_count = cur.rowcount
-                response_data = {"message": f"Query executed successfully, {row_count} rows affected."}
+                response_data = {
+                    "message": f"Query executed successfully, {row_count} rows affected."
+                }

             cur.close()
             return {
@@ -52,28 +60,29 @@ class PostgresTool(Tool):
                 "response_data": response_data,
             }

-        except psycopg2.Error as e:
+        except psycopg.Error as e:
             error_message = f"Database error: {e}"
-            print(f"Database error: {e}")
+            logger.error("PostgreSQL execute_sql error: %s", e)
             return {
                 "status_code": 500,
                 "message": "Failed to execute SQL query.",
                 "error": error_message,
             }
         finally:
-            if conn:  # Ensure connection is closed even if errors occur
+            if conn:
                 conn.close()

     def _get_schema(self, db_name):
         """
         Retrieves the schema of the PostgreSQL database using a connection string.
         """
-        conn = None  # Initialize conn to None for error handling
+        conn = None
         try:
-            conn = psycopg2.connect(self.connection_string)
+            conn = psycopg.connect(self.connection_string)
             cur = conn.cursor()

-            cur.execute("""
+            cur.execute(
+                """
                 SELECT
                     table_name,
                     column_name,
@@ -87,19 +96,22 @@ class PostgresTool(Tool):
                 ORDER BY
                     table_name,
                     ordinal_position;
-            """)
+            """
+            )

             schema_data = {}
             for row in cur.fetchall():
                 table_name, column_name, data_type, column_default, is_nullable = row
                 if table_name not in schema_data:
                     schema_data[table_name] = []
-                schema_data[table_name].append({
-                    "column_name": column_name,
-                    "data_type": data_type,
-                    "column_default": column_default,
-                    "is_nullable": is_nullable
-                })
+                schema_data[table_name].append(
+                    {
+                        "column_name": column_name,
+                        "data_type": data_type,
+                        "column_default": column_default,
+                        "is_nullable": is_nullable,
+                    }
+                )

             cur.close()
             return {
@@ -108,16 +120,16 @@ class PostgresTool(Tool):
                 "schema": schema_data,
             }

-        except psycopg2.Error as e:
+        except psycopg.Error as e:
             error_message = f"Database error: {e}"
-            print(f"Database error: {e}")
+            logger.error("PostgreSQL get_schema error: %s", e)
             return {
                 "status_code": 500,
                 "message": "Failed to retrieve database schema.",
                 "error": error_message,
             }
         finally:
-            if conn:  # Ensure connection is closed even if errors occur
+            if conn:
                 conn.close()

     def get_actions_metadata(self):
@@ -158,6 +170,10 @@ class PostgresTool(Tool):
         return {
             "token": {
                 "type": "string",
-                "description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
+                "label": "Connection String",
+                "description": "PostgreSQL database connection string",
+                "required": True,
+                "secret": True,
+                "order": 1,
             },
         }
|
 import requests
 from markdownify import markdownify
 from application.agents.tools.base import Tool
-from urllib.parse import urlparse
+from application.core.url_validation import validate_url, SSRFError

 class ReadWebpageTool(Tool):
     """
@@ -31,11 +31,12 @@ class ReadWebpageTool(Tool):
         if not url:
             return "Error: URL parameter is missing."

-        # Ensure the URL has a scheme (if not, default to http)
-        parsed_url = urlparse(url)
-        if not parsed_url.scheme:
-            url = "http://" + url
+        # Validate URL to prevent SSRF attacks
+        try:
+            url = validate_url(url)
+        except SSRFError as e:
+            return f"Error: URL validation failed - {e}"

         try:
             response = requests.get(url, timeout=10, headers={'User-Agent': 'DocsGPT-Agent/1.0'})
             response.raise_for_status()  # Raise an exception for HTTP errors (4xx or 5xx)
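Note: validate_url and SSRFError are imported from application.core.url_validation, whose body is not part of this diff. A minimal sketch of the kind of SSRF guard such a helper typically performs (an assumption, not the project's implementation):

    import ipaddress
    import socket
    from urllib.parse import urlparse

    class SSRFError(Exception):
        pass

    def validate_url(url: str) -> str:
        # Normalize the scheme, then refuse URLs that resolve to private,
        # loopback, or link-local addresses.
        if "://" not in url:
            url = "https://" + url
        parsed = urlparse(url)
        if parsed.scheme not in ("http", "https") or not parsed.hostname:
            raise SSRFError(f"Unsupported or missing scheme/host in {url!r}")
        for info in socket.getaddrinfo(parsed.hostname, None):
            addr = ipaddress.ip_address(info[4][0])
            if addr.is_private or addr.is_loopback or addr.is_link_local:
                raise SSRFError(f"{parsed.hostname} resolves to a non-public address")
        return url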
application/agents/tools/spec_parser.py (new file, 342 lines)
@@ -0,0 +1,342 @@
+"""
+API Specification Parser
+
+Parses OpenAPI 3.x and Swagger 2.0 specifications and converts them
+to API Tool action definitions for use in DocsGPT.
+"""
+
+import json
+import logging
+import re
+from typing import Any, Dict, List, Optional, Tuple
+
+import yaml
+
+logger = logging.getLogger(__name__)
+
+SUPPORTED_METHODS = frozenset(
+    {"get", "post", "put", "delete", "patch", "head", "options"}
+)
+
+
+def parse_spec(spec_content: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
+    """
+    Parse an API specification and convert operations to action definitions.
+
+    Supports OpenAPI 3.x and Swagger 2.0 formats in JSON or YAML.
+
+    Args:
+        spec_content: Raw specification content as string
+
+    Returns:
+        Tuple of (metadata dict, list of action dicts)
+
+    Raises:
+        ValueError: If the spec is invalid or uses an unsupported format
+    """
+    spec = _load_spec(spec_content)
+    _validate_spec(spec)
+
+    is_swagger = "swagger" in spec
+    metadata = _extract_metadata(spec, is_swagger)
+    actions = _extract_actions(spec, is_swagger)
+
+    return metadata, actions
+
+
+def _load_spec(content: str) -> Dict[str, Any]:
+    """Parse spec content from JSON or YAML string."""
+    content = content.strip()
+    if not content:
+        raise ValueError("Empty specification content")
+    try:
+        if content.startswith("{"):
+            return json.loads(content)
+        return yaml.safe_load(content)
+    except json.JSONDecodeError as e:
+        raise ValueError(f"Invalid JSON format: {e.msg}")
+    except yaml.YAMLError as e:
+        raise ValueError(f"Invalid YAML format: {e}")
+
+
+def _validate_spec(spec: Dict[str, Any]) -> None:
+    """Validate spec version and required fields."""
+    if not isinstance(spec, dict):
+        raise ValueError("Specification must be a valid object")
+    openapi_version = spec.get("openapi", "")
+    swagger_version = spec.get("swagger", "")
+
+    if not (openapi_version.startswith("3.") or swagger_version == "2.0"):
+        raise ValueError(
+            "Unsupported specification version. Expected OpenAPI 3.x or Swagger 2.0"
+        )
+    if "paths" not in spec or not spec["paths"]:
+        raise ValueError("No API paths defined in the specification")
+
+
+def _extract_metadata(spec: Dict[str, Any], is_swagger: bool) -> Dict[str, Any]:
+    """Extract API metadata from specification."""
+    info = spec.get("info", {})
+    base_url = _get_base_url(spec, is_swagger)
+
+    return {
+        "title": info.get("title", "Untitled API"),
+        "description": (info.get("description", "") or "")[:500],
+        "version": info.get("version", ""),
+        "base_url": base_url,
+    }
+
+
+def _get_base_url(spec: Dict[str, Any], is_swagger: bool) -> str:
+    """Extract base URL from spec (handles both OpenAPI 3.x and Swagger 2.0)."""
+    if is_swagger:
+        schemes = spec.get("schemes", ["https"])
+        host = spec.get("host", "")
+        base_path = spec.get("basePath", "")
+        if host:
+            scheme = schemes[0] if schemes else "https"
+            return f"{scheme}://{host}{base_path}".rstrip("/")
+        return ""
+    servers = spec.get("servers", [])
+    if servers and isinstance(servers, list) and servers[0].get("url"):
+        return servers[0]["url"].rstrip("/")
+    return ""
||||||
|
|
||||||
|
|
||||||
|
def _extract_actions(spec: Dict[str, Any], is_swagger: bool) -> List[Dict[str, Any]]:
|
||||||
|
"""Extract all API operations as action definitions."""
|
||||||
|
actions = []
|
||||||
|
paths = spec.get("paths", {})
|
||||||
|
base_url = _get_base_url(spec, is_swagger)
|
||||||
|
|
||||||
|
components = spec.get("components", {})
|
||||||
|
definitions = spec.get("definitions", {})
|
||||||
|
|
||||||
|
for path, path_item in paths.items():
|
||||||
|
if not isinstance(path_item, dict):
|
||||||
|
continue
|
||||||
|
path_params = path_item.get("parameters", [])
|
||||||
|
|
||||||
|
for method in SUPPORTED_METHODS:
|
||||||
|
operation = path_item.get(method)
|
||||||
|
if not isinstance(operation, dict):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
action = _build_action(
|
||||||
|
path=path,
|
||||||
|
method=method,
|
||||||
|
operation=operation,
|
||||||
|
path_params=path_params,
|
||||||
|
base_url=base_url,
|
||||||
|
components=components,
|
||||||
|
definitions=definitions,
|
||||||
|
is_swagger=is_swagger,
|
||||||
|
)
|
||||||
|
actions.append(action)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to parse operation {method.upper()} {path}: {e}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
return actions
|
||||||
|
|
||||||
|
|
||||||
|
def _build_action(
|
||||||
|
path: str,
|
||||||
|
method: str,
|
||||||
|
operation: Dict[str, Any],
|
||||||
|
path_params: List[Dict],
|
||||||
|
base_url: str,
|
||||||
|
components: Dict[str, Any],
|
||||||
|
definitions: Dict[str, Any],
|
||||||
|
is_swagger: bool,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build a single action from an API operation."""
|
||||||
|
action_name = _generate_action_name(operation, method, path)
|
||||||
|
full_url = f"{base_url}{path}" if base_url else path
|
||||||
|
|
||||||
|
all_params = path_params + operation.get("parameters", [])
|
||||||
|
query_params, headers = _categorize_parameters(all_params, components, definitions)
|
||||||
|
|
||||||
|
body, body_content_type = _extract_request_body(
|
||||||
|
operation, components, definitions, is_swagger
|
||||||
|
)
|
||||||
|
|
||||||
|
description = operation.get("summary", "") or operation.get("description", "")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": action_name,
|
||||||
|
"url": full_url,
|
||||||
|
"method": method.upper(),
|
||||||
|
"description": (description or "")[:500],
|
||||||
|
"query_params": {"type": "object", "properties": query_params},
|
||||||
|
"headers": {"type": "object", "properties": headers},
|
||||||
|
"body": {"type": "object", "properties": body},
|
||||||
|
"body_content_type": body_content_type,
|
||||||
|
"active": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_action_name(operation: Dict[str, Any], method: str, path: str) -> str:
|
||||||
|
"""Generate a valid action name from operationId or method+path."""
|
||||||
|
if operation.get("operationId"):
|
||||||
|
name = operation["operationId"]
|
||||||
|
else:
|
||||||
|
path_slug = re.sub(r"[{}]", "", path)
|
||||||
|
path_slug = re.sub(r"[^a-zA-Z0-9]", "_", path_slug)
|
||||||
|
path_slug = re.sub(r"_+", "_", path_slug).strip("_")
|
||||||
|
name = f"{method}_{path_slug}"
|
||||||
|
name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
|
||||||
|
return name[:64]
|
||||||
|
|
||||||
|
|
||||||
|
def _categorize_parameters(
|
||||||
|
parameters: List[Dict],
|
||||||
|
components: Dict[str, Any],
|
||||||
|
definitions: Dict[str, Any],
|
||||||
|
) -> Tuple[Dict, Dict]:
|
||||||
|
"""Categorize parameters into query params and headers."""
|
||||||
|
query_params = {}
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
for param in parameters:
|
||||||
|
resolved = _resolve_ref(param, components, definitions)
|
||||||
|
if not resolved or "name" not in resolved:
|
||||||
|
continue
|
||||||
|
location = resolved.get("in", "query")
|
||||||
|
prop = _param_to_property(resolved)
|
||||||
|
|
||||||
|
if location in ("query", "path"):
|
||||||
|
query_params[resolved["name"]] = prop
|
||||||
|
elif location == "header":
|
||||||
|
headers[resolved["name"]] = prop
|
||||||
|
return query_params, headers
|
||||||
|
|
||||||
|
|
||||||
|
def _param_to_property(param: Dict) -> Dict[str, Any]:
|
||||||
|
"""Convert an API parameter to an action property definition."""
|
||||||
|
schema = param.get("schema", {})
|
||||||
|
param_type = schema.get("type", param.get("type", "string"))
|
||||||
|
|
||||||
|
mapped_type = "integer" if param_type in ("integer", "number") else "string"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": mapped_type,
|
||||||
|
"description": (param.get("description", "") or "")[:200],
|
||||||
|
"value": "",
|
||||||
|
"filled_by_llm": param.get("required", False),
|
||||||
|
"required": param.get("required", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_request_body(
|
||||||
|
operation: Dict[str, Any],
|
||||||
|
components: Dict[str, Any],
|
||||||
|
definitions: Dict[str, Any],
|
||||||
|
is_swagger: bool,
|
||||||
|
) -> Tuple[Dict, str]:
|
||||||
|
"""Extract request body schema and content type."""
|
||||||
|
content_types = [
|
||||||
|
"application/json",
|
||||||
|
"application/x-www-form-urlencoded",
|
||||||
|
"multipart/form-data",
|
||||||
|
"text/plain",
|
||||||
|
"application/xml",
|
||||||
|
]
|
||||||
|
|
||||||
|
if is_swagger:
|
||||||
|
consumes = operation.get("consumes", [])
|
||||||
|
body_param = next(
|
||||||
|
(p for p in operation.get("parameters", []) if p.get("in") == "body"), None
|
||||||
|
)
|
||||||
|
if not body_param:
|
||||||
|
return {}, "application/json"
|
||||||
|
selected_type = consumes[0] if consumes else "application/json"
|
||||||
|
schema = body_param.get("schema", {})
|
||||||
|
else:
|
||||||
|
request_body = operation.get("requestBody", {})
|
||||||
|
if not request_body:
|
||||||
|
return {}, "application/json"
|
||||||
|
request_body = _resolve_ref(request_body, components, definitions)
|
||||||
|
content = request_body.get("content", {})
|
||||||
|
|
||||||
|
selected_type = "application/json"
|
||||||
|
schema = {}
|
||||||
|
|
||||||
|
for ct in content_types:
|
||||||
|
if ct in content:
|
||||||
|
selected_type = ct
|
||||||
|
schema = content[ct].get("schema", {})
|
||||||
|
break
|
||||||
|
if not schema and content:
|
||||||
|
first_type = next(iter(content))
|
||||||
|
selected_type = first_type
|
||||||
|
schema = content[first_type].get("schema", {})
|
||||||
|
properties = _schema_to_properties(schema, components, definitions)
|
||||||
|
return properties, selected_type
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_to_properties(
|
||||||
|
schema: Dict,
|
||||||
|
components: Dict[str, Any],
|
||||||
|
definitions: Dict[str, Any],
|
||||||
|
depth: int = 0,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Convert schema to action body properties (limited depth to prevent recursion)."""
|
||||||
|
if depth > 3:
|
||||||
|
return {}
|
||||||
|
schema = _resolve_ref(schema, components, definitions)
|
||||||
|
if not schema or not isinstance(schema, dict):
|
||||||
|
return {}
|
||||||
|
properties = {}
|
||||||
|
schema_type = schema.get("type", "object")
|
||||||
|
|
||||||
|
if schema_type == "object":
|
||||||
|
required_fields = set(schema.get("required", []))
|
||||||
|
for prop_name, prop_schema in schema.get("properties", {}).items():
|
||||||
|
resolved = _resolve_ref(prop_schema, components, definitions)
|
||||||
|
if not isinstance(resolved, dict):
|
||||||
|
continue
|
||||||
|
prop_type = resolved.get("type", "string")
|
||||||
|
mapped_type = "integer" if prop_type in ("integer", "number") else "string"
|
||||||
|
|
||||||
|
properties[prop_name] = {
|
||||||
|
"type": mapped_type,
|
||||||
|
"description": (resolved.get("description", "") or "")[:200],
|
||||||
|
"value": "",
|
||||||
|
"filled_by_llm": prop_name in required_fields,
|
||||||
|
"required": prop_name in required_fields,
|
||||||
|
}
|
||||||
|
return properties
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_ref(
|
||||||
|
obj: Any,
|
||||||
|
components: Dict[str, Any],
|
||||||
|
definitions: Dict[str, Any],
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
"""Resolve $ref references in the specification."""
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return obj if isinstance(obj, dict) else None
|
||||||
|
if "$ref" not in obj:
|
||||||
|
return obj
|
||||||
|
ref_path = obj["$ref"]
|
||||||
|
|
||||||
|
if ref_path.startswith("#/components/"):
|
||||||
|
parts = ref_path.replace("#/components/", "").split("/")
|
||||||
|
return _traverse_path(components, parts)
|
||||||
|
elif ref_path.startswith("#/definitions/"):
|
||||||
|
parts = ref_path.replace("#/definitions/", "").split("/")
|
||||||
|
return _traverse_path(definitions, parts)
|
||||||
|
logger.debug(f"Unsupported ref path: {ref_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _traverse_path(obj: Dict, parts: List[str]) -> Optional[Dict]:
|
||||||
|
"""Traverse a nested dictionary using path parts."""
|
||||||
|
try:
|
||||||
|
for part in parts:
|
||||||
|
obj = obj[part]
|
||||||
|
return obj if isinstance(obj, dict) else None
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
return None
|
||||||
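A minimal usage sketch of the new spec parser (module path taken from the file header above); the inline OpenAPI document is illustrative only.

```python
from application.agents.tools.spec_parser import parse_spec

spec = """\
openapi: 3.0.0
info:
  title: Pet API
  version: "1.0"
servers:
  - url: https://api.example.com/v1
paths:
  /pets/{petId}:
    get:
      operationId: getPet
      summary: Fetch a pet by id
      parameters:
        - name: petId
          in: path
          required: true
          schema:
            type: integer
"""

metadata, actions = parse_spec(spec)
print(metadata["base_url"])  # https://api.example.com/v1
print(actions[0]["name"])    # getPet
print(actions[0]["method"])  # GET
print(actions[0]["url"])     # https://api.example.com/v1/pets/{petId}
```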
@@ -1,6 +1,11 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from application.agents.tools.base import Tool
|
from application.agents.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TelegramTool(Tool):
|
class TelegramTool(Tool):
|
||||||
"""
|
"""
|
||||||
@@ -18,24 +23,22 @@ class TelegramTool(Tool):
|
|||||||
"telegram_send_message": self._send_message,
|
"telegram_send_message": self._send_message,
|
||||||
"telegram_send_image": self._send_image,
|
"telegram_send_image": self._send_image,
|
||||||
}
|
}
|
||||||
|
if action_name not in actions:
|
||||||
if action_name in actions:
|
|
||||||
return actions[action_name](**kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown action: {action_name}")
|
raise ValueError(f"Unknown action: {action_name}")
|
||||||
|
return actions[action_name](**kwargs)
|
||||||
|
|
||||||
def _send_message(self, text, chat_id):
|
def _send_message(self, text, chat_id):
|
||||||
print(f"Sending message: {text}")
|
logger.debug("Sending Telegram message to chat_id=%s", chat_id)
|
||||||
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
url = f"https://api.telegram.org/bot{self.token}/sendMessage"
|
||||||
payload = {"chat_id": chat_id, "text": text}
|
payload = {"chat_id": chat_id, "text": text}
|
||||||
response = requests.post(url, data=payload)
|
response = requests.post(url, data=payload, timeout=100)
|
||||||
return {"status_code": response.status_code, "message": "Message sent"}
|
return {"status_code": response.status_code, "message": "Message sent"}
|
||||||
|
|
||||||
def _send_image(self, image_url, chat_id):
|
def _send_image(self, image_url, chat_id):
|
||||||
print(f"Sending image: {image_url}")
|
logger.debug("Sending Telegram image to chat_id=%s", chat_id)
|
||||||
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
url = f"https://api.telegram.org/bot{self.token}/sendPhoto"
|
||||||
payload = {"chat_id": chat_id, "photo": image_url}
|
payload = {"chat_id": chat_id, "photo": image_url}
|
||||||
response = requests.post(url, data=payload)
|
response = requests.post(url, data=payload, timeout=100)
|
||||||
return {"status_code": response.status_code, "message": "Image sent"}
|
return {"status_code": response.status_code, "message": "Image sent"}
|
||||||
|
|
||||||
def get_actions_metadata(self):
|
def get_actions_metadata(self):
|
||||||
@@ -82,5 +85,12 @@ class TelegramTool(Tool):
|
|||||||
|
|
||||||
def get_config_requirements(self):
|
def get_config_requirements(self):
|
||||||
return {
|
return {
|
||||||
"token": {"type": "string", "description": "Bot token for authentication"},
|
"token": {
|
||||||
|
"type": "string",
|
||||||
|
"label": "Bot Token",
|
||||||
|
"description": "Telegram bot token for authentication",
|
||||||
|
"required": True,
|
||||||
|
"secret": True,
|
||||||
|
"order": 1,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
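A minimal usage sketch of the Telegram tool after these changes. The import path and the constructor signature are assumptions (neither is shown in this hunk); the token and chat id are placeholders.

```python
# Assumed import path and constructor; only self.token usage is visible in the diff.
from application.agents.tools.telegram import TelegramTool

tool = TelegramTool({"token": "123456:ABC-placeholder"})
result = tool.execute_action(
    "telegram_send_message", text="Build finished", chat_id="987654321"
)
print(result)  # {"status_code": 200, "message": "Message sent"} on success
```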
|
|||||||
application/agents/tools/think.py (new file, 70 lines)
@@ -0,0 +1,70 @@
|
|||||||
|
from application.agents.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
|
THINK_TOOL_ID = "think"
|
||||||
|
|
||||||
|
THINK_TOOL_ENTRY = {
|
||||||
|
"name": "think",
|
||||||
|
"actions": [
|
||||||
|
{
|
||||||
|
"name": "reason",
|
||||||
|
"description": (
|
||||||
|
"Use this tool to think through your reasoning step by step "
|
||||||
|
"before deciding on your next action. Always reason before "
|
||||||
|
"searching or answering."
|
||||||
|
),
|
||||||
|
"active": True,
|
||||||
|
"parameters": {
|
||||||
|
"properties": {
|
||||||
|
"reasoning": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Your step-by-step reasoning and analysis",
|
||||||
|
"filled_by_llm": True,
|
||||||
|
"required": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ThinkTool(Tool):
|
||||||
|
"""Pseudo-tool that captures chain-of-thought reasoning.
|
||||||
|
|
||||||
|
Returns a short acknowledgment so the LLM can continue.
|
||||||
|
The reasoning content is captured in tool_call data for transparency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
internal = True
|
||||||
|
|
||||||
|
def __init__(self, config=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def execute_action(self, action_name: str, **kwargs):
|
||||||
|
return "Continue."
|
||||||
|
|
||||||
|
def get_actions_metadata(self):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "reason",
|
||||||
|
"description": (
|
||||||
|
"Use this tool to think through your reasoning step by step "
|
||||||
|
"before deciding on your next action. Always reason before "
|
||||||
|
"searching or answering."
|
||||||
|
),
|
||||||
|
"parameters": {
|
||||||
|
"properties": {
|
||||||
|
"reasoning": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Your step-by-step reasoning and analysis",
|
||||||
|
"filled_by_llm": True,
|
||||||
|
"required": True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_config_requirements(self):
|
||||||
|
return {}
|
||||||
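The ThinkTool is a pseudo-tool: it acknowledges whatever reasoning the model passes and returns "Continue.", while the reasoning itself is captured in the tool_call data. A minimal usage sketch, using the module path from the file header above:

```python
from application.agents.tools.think import THINK_TOOL_ENTRY, ThinkTool

tool = ThinkTool()
ack = tool.execute_action("reason", reasoning="Check the schema first, then query the table.")
print(ack)                       # "Continue."
print(THINK_TOOL_ENTRY["name"])  # "think"
```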
@@ -1,10 +1,19 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from .base import Tool
|
from .base import Tool
|
||||||
from application.core.mongo_db import MongoDB
|
from application.storage.db.repositories.todos import TodosRepository
|
||||||
from application.core.settings import settings
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
|
|
||||||
|
def _status_from_completed(completed: Any) -> str:
|
||||||
|
"""Translate the PG ``completed`` boolean to the legacy status string.
|
||||||
|
|
||||||
|
The frontend (and prior LLM-facing tool output) expects
|
||||||
|
``"open"`` / ``"completed"``. Keeping that contract at the tool
|
||||||
|
boundary insulates callers from the schema change.
|
||||||
|
"""
|
||||||
|
return "completed" if bool(completed) else "open"
|
||||||
|
|
||||||
|
|
||||||
class TodoListTool(Tool):
|
class TodoListTool(Tool):
|
||||||
@@ -25,7 +34,6 @@ class TodoListTool(Tool):
|
|||||||
self.user_id: Optional[str] = user_id
|
self.user_id: Optional[str] = user_id
|
||||||
|
|
||||||
# Get tool_id from configuration (passed from user_tools._id in production)
|
# Get tool_id from configuration (passed from user_tools._id in production)
|
||||||
# In production, tool_id is the MongoDB ObjectId string from user_tools collection
|
|
||||||
if tool_config and "tool_id" in tool_config:
|
if tool_config and "tool_id" in tool_config:
|
||||||
self.tool_id = tool_config["tool_id"]
|
self.tool_id = tool_config["tool_id"]
|
||||||
elif user_id:
|
elif user_id:
|
||||||
@@ -35,8 +43,26 @@ class TodoListTool(Tool):
|
|||||||
# Last resort fallback (shouldn't happen in normal use)
|
# Last resort fallback (shouldn't happen in normal use)
|
||||||
self.tool_id = str(uuid.uuid4())
|
self.tool_id = str(uuid.uuid4())
|
||||||
|
|
||||||
db = MongoDB.get_client()[settings.MONGO_DB_NAME]
|
self._last_artifact_id: Optional[str] = None
|
||||||
self.collection = db["todos"]
|
|
||||||
|
def _pg_enabled(self) -> bool:
|
||||||
|
"""Return True only when ``tool_id`` is a real ``user_tools.id`` UUID.
|
||||||
|
|
||||||
|
The ``todos`` PG table has a UUID foreign key to ``user_tools`` and
|
||||||
|
the repo queries ``CAST(:tool_id AS uuid)``. The sentinel
|
||||||
|
``default_{uid}`` fallback is neither a UUID nor a row in
|
||||||
|
``user_tools``; binding it would fail with ``invalid input syntax for
|
||||||
|
type uuid``, and even if it didn't, the FK would reject it. Mirror
|
||||||
|
the MemoryTool guard and no-op in that case.
|
||||||
|
"""
|
||||||
|
tool_id = getattr(self, "tool_id", None)
|
||||||
|
if not tool_id or not isinstance(tool_id, str):
|
||||||
|
return False
|
||||||
|
if tool_id.startswith("default_"):
|
||||||
|
return False
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
|
||||||
|
return looks_like_uuid(tool_id)
|
||||||
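The guard above leans on looks_like_uuid from application.storage.db.base_repository, which is not part of this diff. A hypothetical sketch of what such a helper typically checks; the real implementation may differ.

```python
import re

# Hypothetical: accept canonical 8-4-4-4-12 hex UUID strings only.
_UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)


def looks_like_uuid(value: str) -> bool:
    return bool(_UUID_RE.match(value))
```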
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Action implementations
|
# Action implementations
|
||||||
@@ -54,6 +80,14 @@ class TodoListTool(Tool):
|
|||||||
if not self.user_id:
|
if not self.user_id:
|
||||||
return "Error: TodoListTool requires a valid user_id."
|
return "Error: TodoListTool requires a valid user_id."
|
||||||
|
|
||||||
|
if not self._pg_enabled():
|
||||||
|
return (
|
||||||
|
"Error: TodoListTool is not configured with a persistent "
|
||||||
|
"tool_id; todo storage is unavailable for this session."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._last_artifact_id = None
|
||||||
|
|
||||||
if action_name == "list":
|
if action_name == "list":
|
||||||
return self._list()
|
return self._list()
|
||||||
|
|
||||||
@@ -165,6 +199,9 @@ class TodoListTool(Tool):
|
|||||||
"""Return configuration requirements."""
|
"""Return configuration requirements."""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def get_artifact_id(self, action_name: str, **kwargs: Any) -> Optional[str]:
|
||||||
|
return self._last_artifact_id
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@@ -184,31 +221,10 @@ class TodoListTool(Tool):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_next_todo_id(self) -> int:
|
|
||||||
"""Get the next sequential todo_id for this user and tool.
|
|
||||||
|
|
||||||
Returns a simple integer (1, 2, 3, ...) scoped to this user/tool.
|
|
||||||
With 5-10 todos max, scanning is negligible.
|
|
||||||
"""
|
|
||||||
# Find all todos for this user/tool and get their IDs
|
|
||||||
todos = list(self.collection.find(
|
|
||||||
{"user_id": self.user_id, "tool_id": self.tool_id},
|
|
||||||
{"todo_id": 1}
|
|
||||||
))
|
|
||||||
|
|
||||||
# Find the maximum todo_id
|
|
||||||
max_id = 0
|
|
||||||
for todo in todos:
|
|
||||||
todo_id = self._coerce_todo_id(todo.get("todo_id"))
|
|
||||||
if todo_id is not None:
|
|
||||||
max_id = max(max_id, todo_id)
|
|
||||||
|
|
||||||
return max_id + 1
|
|
||||||
|
|
||||||
def _list(self) -> str:
|
def _list(self) -> str:
|
||||||
"""List all todos for the user."""
|
"""List all todos for the user."""
|
||||||
cursor = self.collection.find({"user_id": self.user_id, "tool_id": self.tool_id})
|
with db_readonly() as conn:
|
||||||
todos = list(cursor)
|
todos = TodosRepository(conn).list_for_tool(self.user_id, self.tool_id)
|
||||||
|
|
||||||
if not todos:
|
if not todos:
|
||||||
return "No todos found."
|
return "No todos found."
|
||||||
@@ -217,7 +233,7 @@ class TodoListTool(Tool):
|
|||||||
for doc in todos:
|
for doc in todos:
|
||||||
todo_id = doc.get("todo_id")
|
todo_id = doc.get("todo_id")
|
||||||
title = doc.get("title", "Untitled")
|
title = doc.get("title", "Untitled")
|
||||||
status = doc.get("status", "open")
|
status = _status_from_completed(doc.get("completed"))
|
||||||
|
|
||||||
line = f"[{todo_id}] {title} ({status})"
|
line = f"[{todo_id}] {title} ({status})"
|
||||||
result_lines.append(line)
|
result_lines.append(line)
|
||||||
@@ -225,24 +241,23 @@ class TodoListTool(Tool):
|
|||||||
return "\n".join(result_lines)
|
return "\n".join(result_lines)
|
||||||
|
|
||||||
def _create(self, title: str) -> str:
|
def _create(self, title: str) -> str:
|
||||||
"""Create a new todo item."""
|
"""Create a new todo item.
|
||||||
|
|
||||||
|
``TodosRepository.create`` allocates the per-tool monotonic
|
||||||
|
``todo_id`` inside the same transaction (``COALESCE(MAX(todo_id),0)+1``
|
||||||
|
scoped to ``tool_id``), so we no longer need a separate read-then-
|
||||||
|
write step here.
|
||||||
|
"""
|
||||||
title = (title or "").strip()
|
title = (title or "").strip()
|
||||||
if not title:
|
if not title:
|
||||||
return "Error: Title is required."
|
return "Error: Title is required."
|
||||||
|
|
||||||
now = datetime.now()
|
with db_session() as conn:
|
||||||
todo_id = self._get_next_todo_id()
|
row = TodosRepository(conn).create(self.user_id, self.tool_id, title)
|
||||||
|
|
||||||
doc = {
|
todo_id = row.get("todo_id")
|
||||||
"todo_id": todo_id,
|
if row.get("id") is not None:
|
||||||
"user_id": self.user_id,
|
self._last_artifact_id = str(row.get("id"))
|
||||||
"tool_id": self.tool_id,
|
|
||||||
"title": title,
|
|
||||||
"status": "open",
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
}
|
|
||||||
self.collection.insert_one(doc)
|
|
||||||
return f"Todo created with ID {todo_id}: {title}"
|
return f"Todo created with ID {todo_id}: {title}"
|
||||||
|
|
||||||
def _get(self, todo_id: Optional[Any]) -> str:
|
def _get(self, todo_id: Optional[Any]) -> str:
|
||||||
@@ -251,21 +266,21 @@ class TodoListTool(Tool):
|
|||||||
if parsed_todo_id is None:
|
if parsed_todo_id is None:
|
||||||
return "Error: todo_id must be a positive integer."
|
return "Error: todo_id must be a positive integer."
|
||||||
|
|
||||||
doc = self.collection.find_one({
|
with db_readonly() as conn:
|
||||||
"user_id": self.user_id,
|
doc = TodosRepository(conn).get_by_tool_and_todo_id(
|
||||||
"tool_id": self.tool_id,
|
self.user_id, self.tool_id, parsed_todo_id
|
||||||
"todo_id": parsed_todo_id
|
)
|
||||||
})
|
|
||||||
|
|
||||||
if not doc:
|
if not doc:
|
||||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
|
||||||
|
if doc.get("id") is not None:
|
||||||
|
self._last_artifact_id = str(doc.get("id"))
|
||||||
|
|
||||||
title = doc.get("title", "Untitled")
|
title = doc.get("title", "Untitled")
|
||||||
status = doc.get("status", "open")
|
status = _status_from_completed(doc.get("completed"))
|
||||||
|
|
||||||
result = f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
return f"Todo [{parsed_todo_id}]:\nTitle: {title}\nStatus: {status}"
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
def _update(self, todo_id: Optional[Any], title: str) -> str:
|
||||||
"""Update a todo's title by ID."""
|
"""Update a todo's title by ID."""
|
||||||
@@ -277,13 +292,19 @@ class TodoListTool(Tool):
|
|||||||
if not title:
|
if not title:
|
||||||
return "Error: Title is required."
|
return "Error: Title is required."
|
||||||
|
|
||||||
result = self.collection.update_one(
|
with db_session() as conn:
|
||||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
repo = TodosRepository(conn)
|
||||||
{"$set": {"title": title, "updated_at": datetime.now()}}
|
existing = repo.get_by_tool_and_todo_id(
|
||||||
)
|
self.user_id, self.tool_id, parsed_todo_id
|
||||||
|
)
|
||||||
|
if not existing:
|
||||||
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
repo.update_title_by_tool_and_todo_id(
|
||||||
|
self.user_id, self.tool_id, parsed_todo_id, title
|
||||||
|
)
|
||||||
|
|
||||||
if result.matched_count == 0:
|
if existing.get("id") is not None:
|
||||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
self._last_artifact_id = str(existing.get("id"))
|
||||||
|
|
||||||
return f"Todo {parsed_todo_id} updated to: {title}"
|
return f"Todo {parsed_todo_id} updated to: {title}"
|
||||||
|
|
||||||
@@ -293,13 +314,17 @@ class TodoListTool(Tool):
|
|||||||
if parsed_todo_id is None:
|
if parsed_todo_id is None:
|
||||||
return "Error: todo_id must be a positive integer."
|
return "Error: todo_id must be a positive integer."
|
||||||
|
|
||||||
result = self.collection.update_one(
|
with db_session() as conn:
|
||||||
{"user_id": self.user_id, "tool_id": self.tool_id, "todo_id": parsed_todo_id},
|
repo = TodosRepository(conn)
|
||||||
{"$set": {"status": "completed", "updated_at": datetime.now()}}
|
existing = repo.get_by_tool_and_todo_id(
|
||||||
)
|
self.user_id, self.tool_id, parsed_todo_id
|
||||||
|
)
|
||||||
|
if not existing:
|
||||||
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
repo.set_completed(self.user_id, self.tool_id, parsed_todo_id, True)
|
||||||
|
|
||||||
if result.matched_count == 0:
|
if existing.get("id") is not None:
|
||||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
self._last_artifact_id = str(existing.get("id"))
|
||||||
|
|
||||||
return f"Todo {parsed_todo_id} marked as completed."
|
return f"Todo {parsed_todo_id} marked as completed."
|
||||||
|
|
||||||
@@ -309,13 +334,18 @@ class TodoListTool(Tool):
|
|||||||
if parsed_todo_id is None:
|
if parsed_todo_id is None:
|
||||||
return "Error: todo_id must be a positive integer."
|
return "Error: todo_id must be a positive integer."
|
||||||
|
|
||||||
result = self.collection.delete_one({
|
with db_session() as conn:
|
||||||
"user_id": self.user_id,
|
repo = TodosRepository(conn)
|
||||||
"tool_id": self.tool_id,
|
existing = repo.get_by_tool_and_todo_id(
|
||||||
"todo_id": parsed_todo_id
|
self.user_id, self.tool_id, parsed_todo_id
|
||||||
})
|
)
|
||||||
|
if not existing:
|
||||||
|
return f"Error: Todo with ID {parsed_todo_id} not found."
|
||||||
|
repo.delete_by_tool_and_todo_id(
|
||||||
|
self.user_id, self.tool_id, parsed_todo_id
|
||||||
|
)
|
||||||
|
|
||||||
if result.deleted_count == 0:
|
if existing.get("id") is not None:
|
||||||
return f"Error: Todo with ID {parsed_todo_id} not found."
|
self._last_artifact_id = str(existing.get("id"))
|
||||||
|
|
||||||
return f"Todo {parsed_todo_id} deleted."
|
return f"Todo {parsed_todo_id} deleted."
|
||||||
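A minimal sketch of the session/repository pattern the rewritten TodoListTool uses, based only on the calls visible above; the user id and tool id values are placeholders, and the repository rows are assumed to be dict-like (the tool accesses them with .get()).

```python
from application.storage.db.repositories.todos import TodosRepository
from application.storage.db.session import db_readonly, db_session

user_id = "user-123"                              # placeholder
tool_id = "6f1d2c3e-1111-4222-8333-444455556666"  # placeholder user_tools UUID

with db_session() as conn:
    row = TodosRepository(conn).create(user_id, tool_id, "Write release notes")
    print(row.get("todo_id"))  # per-tool monotonic id allocated in the same transaction

with db_readonly() as conn:
    for doc in TodosRepository(conn).list_for_tool(user_id, tool_id):
        print(doc.get("todo_id"), doc.get("title"), doc.get("completed"))
```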
|
|||||||
@@ -5,8 +5,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ToolActionParser:
|
class ToolActionParser:
|
||||||
def __init__(self, llm_type):
|
def __init__(self, llm_type, name_mapping=None):
|
||||||
self.llm_type = llm_type
|
self.llm_type = llm_type
|
||||||
|
self.name_mapping = name_mapping
|
||||||
self.parsers = {
|
self.parsers = {
|
||||||
"OpenAILLM": self._parse_openai_llm,
|
"OpenAILLM": self._parse_openai_llm,
|
||||||
"GoogleLLM": self._parse_google_llm,
|
"GoogleLLM": self._parse_google_llm,
|
||||||
@@ -16,22 +17,33 @@ class ToolActionParser:
|
|||||||
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
|
||||||
return parser(call)
|
return parser(call)
|
||||||
|
|
||||||
|
def _resolve_via_mapping(self, call_name):
|
||||||
|
"""Look up (tool_id, action_name) from the name mapping if available."""
|
||||||
|
if self.name_mapping and call_name in self.name_mapping:
|
||||||
|
return self.name_mapping[call_name]
|
||||||
|
return None
|
||||||
|
|
||||||
def _parse_openai_llm(self, call):
|
def _parse_openai_llm(self, call):
|
||||||
try:
|
try:
|
||||||
call_args = json.loads(call.arguments)
|
call_args = json.loads(call.arguments)
|
||||||
|
|
||||||
|
resolved = self._resolve_via_mapping(call.name)
|
||||||
|
if resolved:
|
||||||
|
return resolved[0], resolved[1], call_args
|
||||||
|
|
||||||
|
# Fallback: legacy split on "_" for backward compatibility
|
||||||
tool_parts = call.name.split("_")
|
tool_parts = call.name.split("_")
|
||||||
|
|
||||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
|
||||||
if len(tool_parts) < 2:
|
if len(tool_parts) < 2:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
f"Invalid tool name format: {call.name}. "
|
||||||
|
"Could not resolve via mapping or legacy parsing."
|
||||||
)
|
)
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
tool_id = tool_parts[-1]
|
tool_id = tool_parts[-1]
|
||||||
action_name = "_".join(tool_parts[:-1])
|
action_name = "_".join(tool_parts[:-1])
|
||||||
|
|
||||||
# Validate that tool_id looks like a numerical ID
|
|
||||||
if not tool_id.isdigit():
|
if not tool_id.isdigit():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||||
@@ -45,19 +57,24 @@ class ToolActionParser:
|
|||||||
def _parse_google_llm(self, call):
|
def _parse_google_llm(self, call):
|
||||||
try:
|
try:
|
||||||
call_args = call.arguments
|
call_args = call.arguments
|
||||||
|
|
||||||
|
resolved = self._resolve_via_mapping(call.name)
|
||||||
|
if resolved:
|
||||||
|
return resolved[0], resolved[1], call_args
|
||||||
|
|
||||||
|
# Fallback: legacy split on "_" for backward compatibility
|
||||||
tool_parts = call.name.split("_")
|
tool_parts = call.name.split("_")
|
||||||
|
|
||||||
# If the tool name doesn't contain an underscore, it's likely a hallucinated tool
|
|
||||||
if len(tool_parts) < 2:
|
if len(tool_parts) < 2:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Invalid tool name format: {call.name}. Expected format: action_name_tool_id"
|
f"Invalid tool name format: {call.name}. "
|
||||||
|
"Could not resolve via mapping or legacy parsing."
|
||||||
)
|
)
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
tool_id = tool_parts[-1]
|
tool_id = tool_parts[-1]
|
||||||
action_name = "_".join(tool_parts[:-1])
|
action_name = "_".join(tool_parts[:-1])
|
||||||
|
|
||||||
# Validate that tool_id looks like a numerical ID
|
|
||||||
if not tool_id.isdigit():
|
if not tool_id.isdigit():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
f"Tool ID '{tool_id}' is not numerical. This might be a hallucinated tool call."
|
||||||
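A minimal sketch of the new name-mapping resolution path. The parser's module path and the public dispatch method are not visible in this hunk, so the sketch calls the OpenAI-style parser directly and uses a stand-in object for the LLM tool call; the mapping entry is hypothetical.

```python
import json
from types import SimpleNamespace

# Assumed import path for the parser shown above.
from application.agents.tools.tool_action_parser import ToolActionParser

# Hypothetical mapping: LLM-visible call name -> (tool_id, action_name)
name_mapping = {"telegram_send_message_42": ("42", "telegram_send_message")}

parser = ToolActionParser("OpenAILLM", name_mapping=name_mapping)
call = SimpleNamespace(
    name="telegram_send_message_42",
    arguments=json.dumps({"text": "hi", "chat_id": "1"}),
)

tool_id, action_name, call_args = parser._parse_openai_llm(call)
print(tool_id, action_name, call_args)  # 42 telegram_send_message {'text': 'hi', 'chat_id': '1'}
```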
|
|||||||
@@ -19,7 +19,7 @@ class ToolManager:
|
|||||||
continue
|
continue
|
||||||
module = importlib.import_module(f"application.agents.tools.{name}")
|
module = importlib.import_module(f"application.agents.tools.{name}")
|
||||||
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
for member_name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
if issubclass(obj, Tool) and obj is not Tool:
|
if issubclass(obj, Tool) and obj is not Tool and not obj.internal:
|
||||||
tool_config = self.config.get(name, {})
|
tool_config = self.config.get(name, {})
|
||||||
self.tools[name] = obj(tool_config)
|
self.tools[name] = obj(tool_config)
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ class ToolManager:
|
|||||||
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
def execute_action(self, tool_name, action_name, user_id=None, **kwargs):
|
||||||
if tool_name not in self.tools:
|
if tool_name not in self.tools:
|
||||||
raise ValueError(f"Tool '{tool_name}' not loaded")
|
raise ValueError(f"Tool '{tool_name}' not loaded")
|
||||||
if tool_name in {"mcp_tool", "memory", "todo_list"} and user_id:
|
if tool_name in {"mcp_tool", "memory", "todo_list", "notes"} and user_id:
|
||||||
tool_config = self.config.get(tool_name, {})
|
tool_config = self.config.get(tool_name, {})
|
||||||
tool = self.load_tool(tool_name, tool_config, user_id)
|
tool = self.load_tool(tool_name, tool_config, user_id)
|
||||||
return tool.execute_action(action_name, **kwargs)
|
return tool.execute_action(action_name, **kwargs)
|
||||||
|
|||||||
application/agents/workflow_agent.py (new file, 254 lines)
@@ -0,0 +1,254 @@
|
|||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, Generator, Optional
|
||||||
|
|
||||||
|
from application.agents.base import BaseAgent
|
||||||
|
from application.agents.workflows.schemas import (
|
||||||
|
ExecutionStatus,
|
||||||
|
Workflow,
|
||||||
|
WorkflowEdge,
|
||||||
|
WorkflowGraph,
|
||||||
|
WorkflowNode,
|
||||||
|
WorkflowRun,
|
||||||
|
)
|
||||||
|
from application.agents.workflows.workflow_engine import WorkflowEngine
|
||||||
|
from application.logging import log_activity, LogContext
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||||
|
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||||
|
from application.storage.db.repositories.workflow_runs import WorkflowRunsRepository
|
||||||
|
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowAgent(BaseAgent):
|
||||||
|
"""A specialized agent that executes predefined workflows."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
workflow_id: Optional[str] = None,
|
||||||
|
workflow: Optional[Dict[str, Any]] = None,
|
||||||
|
workflow_owner: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.workflow_id = workflow_id
|
||||||
|
self.workflow_owner = workflow_owner
|
||||||
|
self._workflow_data = workflow
|
||||||
|
self._engine: Optional[WorkflowEngine] = None
|
||||||
|
|
||||||
|
@log_activity()
|
||||||
|
def gen(
|
||||||
|
self, query: str, log_context: LogContext = None
|
||||||
|
) -> Generator[Dict[str, str], None, None]:
|
||||||
|
yield from self._gen_inner(query, log_context)
|
||||||
|
|
||||||
|
def _gen_inner(
|
||||||
|
self, query: str, log_context: LogContext
|
||||||
|
) -> Generator[Dict[str, str], None, None]:
|
||||||
|
graph = self._load_workflow_graph()
|
||||||
|
if not graph:
|
||||||
|
yield {"type": "error", "error": "Failed to load workflow configuration."}
|
||||||
|
return
|
||||||
|
self._engine = WorkflowEngine(graph, self)
|
||||||
|
yield from self._engine.execute({}, query)
|
||||||
|
self._save_workflow_run(query)
|
||||||
|
|
||||||
|
def _load_workflow_graph(self) -> Optional[WorkflowGraph]:
|
||||||
|
if self._workflow_data:
|
||||||
|
return self._parse_embedded_workflow()
|
||||||
|
if self.workflow_id:
|
||||||
|
return self._load_from_database()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_embedded_workflow(self) -> Optional[WorkflowGraph]:
|
||||||
|
try:
|
||||||
|
nodes_data = self._workflow_data.get("nodes", [])
|
||||||
|
edges_data = self._workflow_data.get("edges", [])
|
||||||
|
|
||||||
|
workflow = Workflow(
|
||||||
|
name=self._workflow_data.get("name", "Embedded Workflow"),
|
||||||
|
description=self._workflow_data.get("description"),
|
||||||
|
)
|
||||||
|
|
||||||
|
nodes = []
|
||||||
|
for n in nodes_data:
|
||||||
|
node_config = n.get("data", {})
|
||||||
|
nodes.append(
|
||||||
|
WorkflowNode(
|
||||||
|
id=n["id"],
|
||||||
|
workflow_id=self.workflow_id or "embedded",
|
||||||
|
type=n["type"],
|
||||||
|
title=n.get("title", "Node"),
|
||||||
|
description=n.get("description"),
|
||||||
|
position=n.get("position", {"x": 0, "y": 0}),
|
||||||
|
config=node_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
edges = []
|
||||||
|
for e in edges_data:
|
||||||
|
edges.append(
|
||||||
|
WorkflowEdge(
|
||||||
|
id=e["id"],
|
||||||
|
workflow_id=self.workflow_id or "embedded",
|
||||||
|
source=e.get("source") or e.get("source_id"),
|
||||||
|
target=e.get("target") or e.get("target_id"),
|
||||||
|
sourceHandle=e.get("sourceHandle") or e.get("source_handle"),
|
||||||
|
targetHandle=e.get("targetHandle") or e.get("target_handle"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Invalid embedded workflow: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _load_from_database(self) -> Optional[WorkflowGraph]:
|
||||||
|
try:
|
||||||
|
if not self.workflow_id:
|
||||||
|
logger.error("Missing workflow ID for load")
|
||||||
|
return None
|
||||||
|
owner_id = self.workflow_owner
|
||||||
|
if not owner_id and isinstance(self.decoded_token, dict):
|
||||||
|
owner_id = self.decoded_token.get("sub")
|
||||||
|
if not owner_id:
|
||||||
|
logger.error(
|
||||||
|
f"Workflow owner not available for workflow load: {self.workflow_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
with db_readonly() as conn:
|
||||||
|
wf_repo = WorkflowsRepository(conn)
|
||||||
|
if looks_like_uuid(self.workflow_id):
|
||||||
|
workflow_row = wf_repo.get(self.workflow_id, owner_id)
|
||||||
|
else:
|
||||||
|
workflow_row = wf_repo.get_by_legacy_id(self.workflow_id, owner_id)
|
||||||
|
if workflow_row is None:
|
||||||
|
logger.error(
|
||||||
|
f"Workflow {self.workflow_id} not found or inaccessible "
|
||||||
|
f"for user {owner_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
pg_workflow_id = str(workflow_row["id"])
|
||||||
|
graph_version = workflow_row.get("current_graph_version", 1)
|
||||||
|
try:
|
||||||
|
graph_version = int(graph_version)
|
||||||
|
if graph_version <= 0:
|
||||||
|
graph_version = 1
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
graph_version = 1
|
||||||
|
|
||||||
|
node_rows = WorkflowNodesRepository(conn).find_by_version(
|
||||||
|
pg_workflow_id, graph_version,
|
||||||
|
)
|
||||||
|
edge_rows = WorkflowEdgesRepository(conn).find_by_version(
|
||||||
|
pg_workflow_id, graph_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow = Workflow(
|
||||||
|
name=workflow_row.get("name"),
|
||||||
|
description=workflow_row.get("description"),
|
||||||
|
)
|
||||||
|
nodes = [
|
||||||
|
WorkflowNode(
|
||||||
|
id=n["node_id"],
|
||||||
|
workflow_id=pg_workflow_id,
|
||||||
|
type=n["node_type"],
|
||||||
|
title=n.get("title") or "Node",
|
||||||
|
description=n.get("description"),
|
||||||
|
position=n.get("position") or {"x": 0, "y": 0},
|
||||||
|
config=n.get("config") or {},
|
||||||
|
)
|
||||||
|
for n in node_rows
|
||||||
|
]
|
||||||
|
edges = [
|
||||||
|
WorkflowEdge(
|
||||||
|
id=e["edge_id"],
|
||||||
|
workflow_id=pg_workflow_id,
|
||||||
|
source=e.get("source_id"),
|
||||||
|
target=e.get("target_id"),
|
||||||
|
sourceHandle=e.get("source_handle"),
|
||||||
|
targetHandle=e.get("target_handle"),
|
||||||
|
)
|
||||||
|
for e in edge_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return WorkflowGraph(workflow=workflow, nodes=nodes, edges=edges)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load workflow from database: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _save_workflow_run(self, query: str) -> None:
|
||||||
|
if not self._engine:
|
||||||
|
return
|
||||||
|
owner_id = self.workflow_owner
|
||||||
|
if not owner_id and isinstance(self.decoded_token, dict):
|
||||||
|
owner_id = self.decoded_token.get("sub")
|
||||||
|
try:
|
||||||
|
run = WorkflowRun(
|
||||||
|
workflow_id=self.workflow_id or "unknown",
|
||||||
|
user=owner_id,
|
||||||
|
status=self._determine_run_status(),
|
||||||
|
inputs={"query": query},
|
||||||
|
outputs=self._serialize_state(self._engine.state),
|
||||||
|
steps=self._engine.get_execution_summary(),
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.workflow_id or not owner_id:
|
||||||
|
return
|
||||||
|
with db_session() as conn:
|
||||||
|
wf_repo = WorkflowsRepository(conn)
|
||||||
|
if looks_like_uuid(self.workflow_id):
|
||||||
|
workflow_row = wf_repo.get(self.workflow_id, owner_id)
|
||||||
|
else:
|
||||||
|
workflow_row = wf_repo.get_by_legacy_id(
|
||||||
|
self.workflow_id, owner_id,
|
||||||
|
)
|
||||||
|
if workflow_row is None:
|
||||||
|
return
|
||||||
|
WorkflowRunsRepository(conn).create(
|
||||||
|
str(workflow_row["id"]),
|
||||||
|
owner_id,
|
||||||
|
run.status.value,
|
||||||
|
inputs=run.inputs,
|
||||||
|
result=run.outputs,
|
||||||
|
steps=[step.model_dump(mode="json") for step in run.steps],
|
||||||
|
started_at=run.created_at,
|
||||||
|
ended_at=run.completed_at,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save workflow run: {e}")
|
||||||
|
|
||||||
|
def _determine_run_status(self) -> ExecutionStatus:
|
||||||
|
if not self._engine or not self._engine.execution_log:
|
||||||
|
return ExecutionStatus.COMPLETED
|
||||||
|
for log in self._engine.execution_log:
|
||||||
|
if log.get("status") == ExecutionStatus.FAILED.value:
|
||||||
|
return ExecutionStatus.FAILED
|
||||||
|
return ExecutionStatus.COMPLETED
|
||||||
|
|
||||||
|
def _serialize_state(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
serialized: Dict[str, Any] = {}
|
||||||
|
for key, value in state.items():
|
||||||
|
serialized[key] = self._serialize_state_value(value)
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
def _serialize_state_value(self, value: Any) -> Any:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {
|
||||||
|
str(dict_key): self._serialize_state_value(dict_value)
|
||||||
|
for dict_key, dict_value in value.items()
|
||||||
|
}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [self._serialize_state_value(item) for item in value]
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return [self._serialize_state_value(item) for item in value]
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return value.isoformat()
|
||||||
|
if isinstance(value, (str, int, float, bool, type(None))):
|
||||||
|
return value
|
||||||
|
return str(value)
|
||||||
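For the embedded-workflow path, _parse_embedded_workflow above expects a dict shaped roughly like the sketch below: node dicts carry id, type, title, position, and a data payload that becomes the node config, while edge dicts carry id plus source/target (or source_id/target_id). All ids, titles, and config values here are placeholders.

```python
embedded_workflow = {
    "name": "Summarize and finish",
    "description": "One agent node between start and end.",
    "nodes": [
        {"id": "start-1", "type": "start", "title": "Start",
         "position": {"x": 0, "y": 0}, "data": {}},
        {"id": "agent-1", "type": "agent", "title": "Summarizer",
         "position": {"x": 200, "y": 0},
         "data": {"agent_type": "classic", "system_prompt": "You summarize text."}},
        {"id": "end-1", "type": "end", "title": "End",
         "position": {"x": 400, "y": 0}, "data": {}},
    ],
    "edges": [
        {"id": "e1", "source": "start-1", "target": "agent-1"},
        {"id": "e2", "source": "agent-1", "target": "end-1"},
    ],
}
```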
application/agents/workflows/cel_evaluator.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import celpy
|
||||||
|
import celpy.celtypes
|
||||||
|
|
||||||
|
|
||||||
|
class CelEvaluationError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_value(value: Any) -> Any:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return celpy.celtypes.BoolType(value)
|
||||||
|
if isinstance(value, int):
|
||||||
|
return celpy.celtypes.IntType(value)
|
||||||
|
if isinstance(value, float):
|
||||||
|
return celpy.celtypes.DoubleType(value)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return celpy.celtypes.StringType(value)
|
||||||
|
if isinstance(value, list):
|
||||||
|
return celpy.celtypes.ListType([_convert_value(item) for item in value])
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return celpy.celtypes.MapType(
|
||||||
|
{celpy.celtypes.StringType(k): _convert_value(v) for k, v in value.items()}
|
||||||
|
)
|
||||||
|
if value is None:
|
||||||
|
return celpy.celtypes.BoolType(False)
|
||||||
|
return celpy.celtypes.StringType(str(value))
|
||||||
|
|
||||||
|
|
||||||
|
def build_activation(state: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
return {k: _convert_value(v) for k, v in state.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_cel(expression: str, state: Dict[str, Any]) -> Any:
|
||||||
|
if not expression or not expression.strip():
|
||||||
|
raise CelEvaluationError("Empty expression")
|
||||||
|
try:
|
||||||
|
env = celpy.Environment()
|
||||||
|
ast = env.compile(expression)
|
||||||
|
program = env.program(ast)
|
||||||
|
activation = build_activation(state)
|
||||||
|
result = program.evaluate(activation)
|
||||||
|
except celpy.CELEvalError as exc:
|
||||||
|
raise CelEvaluationError(f"CEL evaluation error: {exc}") from exc
|
||||||
|
except Exception as exc:
|
||||||
|
raise CelEvaluationError(f"CEL error: {exc}") from exc
|
||||||
|
return cel_to_python(result)
|
||||||
|
|
||||||
|
|
||||||
|
def cel_to_python(value: Any) -> Any:
|
||||||
|
if isinstance(value, celpy.celtypes.BoolType):
|
||||||
|
return bool(value)
|
||||||
|
if isinstance(value, celpy.celtypes.IntType):
|
||||||
|
return int(value)
|
||||||
|
if isinstance(value, celpy.celtypes.DoubleType):
|
||||||
|
return float(value)
|
||||||
|
if isinstance(value, celpy.celtypes.StringType):
|
||||||
|
return str(value)
|
||||||
|
if isinstance(value, celpy.celtypes.ListType):
|
||||||
|
return [cel_to_python(item) for item in value]
|
||||||
|
if isinstance(value, celpy.celtypes.MapType):
|
||||||
|
return {str(k): cel_to_python(v) for k, v in value.items()}
|
||||||
|
return value
|
||||||
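A minimal usage sketch of the CEL helper defined above; the state values are placeholders.

```python
from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel

state = {"score": 7, "status": "open", "tags": ["urgent", "backend"]}

print(evaluate_cel("score > 5 && status == 'open'", state))  # True
print(evaluate_cel("size(tags)", state))                     # 2

try:
    evaluate_cel("", state)
except CelEvaluationError as exc:
    print(exc)  # Empty expression
```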
application/agents/workflows/node_agent.py (new file, 104 lines)
@@ -0,0 +1,104 @@
|
|||||||
|
"""Workflow Node Agents - defines specialized agents for workflow nodes."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from application.agents.agentic_agent import AgenticAgent
|
||||||
|
from application.agents.base import BaseAgent
|
||||||
|
from application.agents.classic_agent import ClassicAgent
|
||||||
|
from application.agents.research_agent import ResearchAgent
|
||||||
|
from application.agents.workflows.schemas import AgentType
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFilterMixin:
|
||||||
|
"""Mixin that filters fetched tools to only those specified in tool_ids."""
|
||||||
|
|
||||||
|
_allowed_tool_ids: List[str]
|
||||||
|
|
||||||
|
def _get_user_tools(self, user: str = "local") -> Dict[str, Dict[str, Any]]:
|
||||||
|
all_tools = super()._get_user_tools(user)
|
||||||
|
if not self._allowed_tool_ids:
|
||||||
|
return {}
|
||||||
|
filtered_tools = {
|
||||||
|
tool_id: tool
|
||||||
|
for tool_id, tool in all_tools.items()
|
||||||
|
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||||
|
}
|
||||||
|
return filtered_tools
|
||||||
|
|
||||||
|
def _get_tools(self, api_key: str = None) -> Dict[str, Dict[str, Any]]:
|
||||||
|
all_tools = super()._get_tools(api_key)
|
||||||
|
if not self._allowed_tool_ids:
|
||||||
|
return {}
|
||||||
|
filtered_tools = {
|
||||||
|
tool_id: tool
|
||||||
|
for tool_id, tool in all_tools.items()
|
||||||
|
if str(tool.get("_id", "")) in self._allowed_tool_ids
|
||||||
|
}
|
||||||
|
return filtered_tools
|
||||||
|
|
||||||
|
|
||||||
|
class _WorkflowNodeMixin:
|
||||||
|
"""Common __init__ for all workflow node agents."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
endpoint: str,
|
||||||
|
llm_name: str,
|
||||||
|
model_id: str,
|
||||||
|
api_key: str,
|
||||||
|
tool_ids: Optional[List[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
endpoint=endpoint,
|
||||||
|
llm_name=llm_name,
|
||||||
|
model_id=model_id,
|
||||||
|
api_key=api_key,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self._allowed_tool_ids = tool_ids or []
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeClassicAgent(ToolFilterMixin, _WorkflowNodeMixin, ClassicAgent):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeAgenticAgent(ToolFilterMixin, _WorkflowNodeMixin, AgenticAgent):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeResearchAgent(ToolFilterMixin, _WorkflowNodeMixin, ResearchAgent):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeAgentFactory:
|
||||||
|
|
||||||
|
_agents: Dict[AgentType, Type[BaseAgent]] = {
|
||||||
|
AgentType.CLASSIC: WorkflowNodeClassicAgent,
|
||||||
|
AgentType.REACT: WorkflowNodeClassicAgent, # backwards compat
|
||||||
|
AgentType.AGENTIC: WorkflowNodeAgenticAgent,
|
||||||
|
AgentType.RESEARCH: WorkflowNodeResearchAgent,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
agent_type: AgentType,
|
||||||
|
endpoint: str,
|
||||||
|
llm_name: str,
|
||||||
|
model_id: str,
|
||||||
|
api_key: str,
|
||||||
|
tool_ids: Optional[List[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BaseAgent:
|
||||||
|
agent_class = cls._agents.get(agent_type)
|
||||||
|
if not agent_class:
|
||||||
|
raise ValueError(f"Unsupported agent type: {agent_type}")
|
||||||
|
return agent_class(
|
||||||
|
endpoint=endpoint,
|
||||||
|
llm_name=llm_name,
|
||||||
|
model_id=model_id,
|
||||||
|
api_key=api_key,
|
||||||
|
tool_ids=tool_ids,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
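A minimal sketch of the factory call shape; endpoint, model, and key values are placeholders, and the underlying ClassicAgent may require additional keyword arguments that are not visible in this diff (they would be forwarded through **kwargs).

```python
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import AgentType

agent = WorkflowNodeAgentFactory.create(
    agent_type=AgentType.CLASSIC,
    endpoint="https://api.example.com/v1",  # placeholder
    llm_name="openai",                      # placeholder
    model_id="gpt-4o-mini",                 # placeholder
    api_key="sk-placeholder",
    tool_ids=["42"],  # restrict this node to a single user tool
)
```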
application/agents/workflows/schemas.py (new file, 168 lines)
@@ -0,0 +1,168 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class NodeType(str, Enum):
|
||||||
|
START = "start"
|
||||||
|
END = "end"
|
||||||
    AGENT = "agent"
    NOTE = "note"
    STATE = "state"
    CONDITION = "condition"


class AgentType(str, Enum):
    CLASSIC = "classic"
    REACT = "react"
    AGENTIC = "agentic"
    RESEARCH = "research"


class ExecutionStatus(str, Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


class Position(BaseModel):
    model_config = ConfigDict(extra="forbid")
    x: float = 0.0
    y: float = 0.0


class AgentNodeConfig(BaseModel):
    model_config = ConfigDict(extra="allow")
    agent_type: AgentType = AgentType.CLASSIC
    llm_name: Optional[str] = None
    system_prompt: str = "You are a helpful assistant."
    prompt_template: str = ""
    output_variable: Optional[str] = None
    stream_to_user: bool = True
    tools: List[str] = Field(default_factory=list)
    sources: List[str] = Field(default_factory=list)
    chunks: str = "2"
    retriever: str = ""
    model_id: Optional[str] = None
    json_schema: Optional[Dict[str, Any]] = None


class ConditionCase(BaseModel):
    model_config = ConfigDict(extra="forbid", populate_by_name=True)
    name: Optional[str] = None
    expression: str = ""
    source_handle: str = Field(..., alias="sourceHandle")


class ConditionNodeConfig(BaseModel):
    model_config = ConfigDict(extra="allow")
    mode: Literal["simple", "advanced"] = "simple"
    cases: List[ConditionCase] = Field(default_factory=list)


class StateOperation(BaseModel):
    model_config = ConfigDict(extra="forbid")
    expression: str = ""
    target_variable: str = ""


class WorkflowEdgeCreate(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    id: str
    workflow_id: str
    source_id: str = Field(..., alias="source")
    target_id: str = Field(..., alias="target")
    source_handle: Optional[str] = Field(None, alias="sourceHandle")
    target_handle: Optional[str] = Field(None, alias="targetHandle")


class WorkflowEdge(WorkflowEdgeCreate):
    pass


class WorkflowNodeCreate(BaseModel):
    model_config = ConfigDict(extra="allow")
    id: str
    workflow_id: str
    type: NodeType
    title: str = "Node"
    description: Optional[str] = None
    position: Position = Field(default_factory=Position)
    config: Dict[str, Any] = Field(default_factory=dict)

    @field_validator("position", mode="before")
    @classmethod
    def parse_position(cls, v: Union[Dict[str, float], Position]) -> Position:
        if isinstance(v, dict):
            return Position(**v)
        return v


class WorkflowNode(WorkflowNodeCreate):
    pass


class WorkflowCreate(BaseModel):
    model_config = ConfigDict(extra="allow")
    name: str = "New Workflow"
    description: Optional[str] = None
    user: Optional[str] = None


class Workflow(WorkflowCreate):
    id: Optional[str] = None
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class WorkflowGraph(BaseModel):
    workflow: Workflow
    nodes: List[WorkflowNode] = Field(default_factory=list)
    edges: List[WorkflowEdge] = Field(default_factory=list)

    def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]:
        for node in self.nodes:
            if node.id == node_id:
                return node
        return None

    def get_start_node(self) -> Optional[WorkflowNode]:
        for node in self.nodes:
            if node.type == NodeType.START:
                return node
        return None

    def get_outgoing_edges(self, node_id: str) -> List[WorkflowEdge]:
        return [edge for edge in self.edges if edge.source_id == node_id]


class NodeExecutionLog(BaseModel):
    model_config = ConfigDict(extra="forbid")
    node_id: str
    node_type: str
    status: ExecutionStatus
    started_at: datetime
    completed_at: Optional[datetime] = None
    error: Optional[str] = None
    state_snapshot: Dict[str, Any] = Field(default_factory=dict)


class WorkflowRunCreate(BaseModel):
    workflow_id: str
    inputs: Dict[str, str] = Field(default_factory=dict)


class WorkflowRun(BaseModel):
    model_config = ConfigDict(extra="allow")
    id: Optional[str] = None
    workflow_id: str
    user: Optional[str] = None
    status: ExecutionStatus = ExecutionStatus.PENDING
    inputs: Dict[str, str] = Field(default_factory=dict)
    outputs: Dict[str, Any] = Field(default_factory=dict)
    steps: List[NodeExecutionLog] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None
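For orientation, a minimal sketch of how these schema models compose into a graph that the engine below can walk. It is not part of the diff; the workflow id, node ids, and titles are illustrative only, and it assumes NodeType also defines START and END members (both are referenced by the engine).

# Illustrative only: build a two-node graph with the Pydantic models above.
from application.agents.workflows.schemas import (
    NodeType,
    Workflow,
    WorkflowEdge,
    WorkflowGraph,
    WorkflowNode,
)

graph = WorkflowGraph(
    workflow=Workflow(id="wf-1", name="Demo"),
    nodes=[
        WorkflowNode(id="start", workflow_id="wf-1", type=NodeType.START, title="Start"),
        WorkflowNode(id="end", workflow_id="wf-1", type=NodeType.END, title="End"),
    ],
    edges=[
        # "source"/"target" are the wire-format aliases of source_id/target_id.
        WorkflowEdge(id="e1", workflow_id="wf-1", source="start", target="end"),
    ],
)

assert graph.get_start_node().id == "start"
assert graph.get_outgoing_edges("start")[0].target_id == "end"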
application/agents/workflows/workflow_engine.py (new file, 496 lines)
@@ -0,0 +1,496 @@
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING

from application.agents.workflows.cel_evaluator import CelEvaluationError, evaluate_cel
from application.agents.workflows.node_agent import WorkflowNodeAgentFactory
from application.agents.workflows.schemas import (
    AgentNodeConfig,
    AgentType,
    ConditionNodeConfig,
    ExecutionStatus,
    NodeExecutionLog,
    NodeType,
    WorkflowGraph,
    WorkflowNode,
)
from application.core.json_schema_utils import (
    JsonSchemaValidationError,
    normalize_json_schema_payload,
)
from application.error import sanitize_api_error
from application.templates.namespaces import NamespaceManager
from application.templates.template_engine import TemplateEngine, TemplateRenderError

try:
    import jsonschema
except ImportError:  # pragma: no cover - optional dependency in some deployments.
    jsonschema = None

if TYPE_CHECKING:
    from application.agents.base import BaseAgent

logger = logging.getLogger(__name__)

StateValue = Any
WorkflowState = Dict[str, StateValue]
TEMPLATE_RESERVED_NAMESPACES = {"agent", "system", "source", "tools", "passthrough"}


class WorkflowEngine:
    MAX_EXECUTION_STEPS = 50

    def __init__(self, graph: WorkflowGraph, agent: "BaseAgent"):
        self.graph = graph
        self.agent = agent
        self.state: WorkflowState = {}
        self.execution_log: List[Dict[str, Any]] = []
        self._condition_result: Optional[str] = None
        self._template_engine = TemplateEngine()
        self._namespace_manager = NamespaceManager()

    def execute(
        self, initial_inputs: WorkflowState, query: str
    ) -> Generator[Dict[str, str], None, None]:
        self._initialize_state(initial_inputs, query)

        start_node = self.graph.get_start_node()
        if not start_node:
            yield {"type": "error", "error": "No start node found in workflow."}
            return
        current_node_id: Optional[str] = start_node.id
        steps = 0

        while current_node_id and steps < self.MAX_EXECUTION_STEPS:
            node = self.graph.get_node_by_id(current_node_id)
            if not node:
                yield {"type": "error", "error": f"Node {current_node_id} not found."}
                break
            log_entry = self._create_log_entry(node)

            yield {
                "type": "workflow_step",
                "node_id": node.id,
                "node_type": node.type.value,
                "node_title": node.title,
                "status": "running",
            }

            try:
                yield from self._execute_node(node)
                log_entry["status"] = ExecutionStatus.COMPLETED.value
                log_entry["completed_at"] = datetime.now(timezone.utc)

                output_key = f"node_{node.id}_output"
                node_output = self.state.get(output_key)

                yield {
                    "type": "workflow_step",
                    "node_id": node.id,
                    "node_type": node.type.value,
                    "node_title": node.title,
                    "status": "completed",
                    "state_snapshot": dict(self.state),
                    "output": node_output,
                }
            except Exception as e:
                logger.error(f"Error executing node {node.id}: {e}", exc_info=True)
                log_entry["status"] = ExecutionStatus.FAILED.value
                log_entry["error"] = str(e)
                log_entry["completed_at"] = datetime.now(timezone.utc)
                log_entry["state_snapshot"] = dict(self.state)
                self.execution_log.append(log_entry)

                user_friendly_error = sanitize_api_error(e)
                yield {
                    "type": "workflow_step",
                    "node_id": node.id,
                    "node_type": node.type.value,
                    "node_title": node.title,
                    "status": "failed",
                    "state_snapshot": dict(self.state),
                    "error": user_friendly_error,
                }
                yield {"type": "error", "error": user_friendly_error}
                break
            log_entry["state_snapshot"] = dict(self.state)
            self.execution_log.append(log_entry)

            if node.type == NodeType.END:
                break
            current_node_id = self._get_next_node_id(current_node_id)
            if current_node_id is None and node.type != NodeType.END:
                logger.warning(
                    f"Branch ended at node '{node.title}' ({node.id}) without reaching an end node"
                )
            steps += 1
        if steps >= self.MAX_EXECUTION_STEPS:
            logger.warning(
                f"Workflow reached max steps limit ({self.MAX_EXECUTION_STEPS})"
            )

    def _initialize_state(self, initial_inputs: WorkflowState, query: str) -> None:
        self.state.update(initial_inputs)
        self.state["query"] = query
        self.state["chat_history"] = str(self.agent.chat_history)

    def _create_log_entry(self, node: WorkflowNode) -> Dict[str, Any]:
        return {
            "node_id": node.id,
            "node_type": node.type.value,
            "started_at": datetime.now(timezone.utc),
            "completed_at": None,
            "status": ExecutionStatus.RUNNING.value,
            "error": None,
            "state_snapshot": {},
        }

    def _get_next_node_id(self, current_node_id: str) -> Optional[str]:
        node = self.graph.get_node_by_id(current_node_id)
        edges = self.graph.get_outgoing_edges(current_node_id)
        if not edges:
            return None

        if node and node.type == NodeType.CONDITION and self._condition_result:
            target_handle = self._condition_result
            self._condition_result = None
            for edge in edges:
                if edge.source_handle == target_handle:
                    return edge.target_id
            return None

        return edges[0].target_id

    def _execute_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, str], None, None]:
        logger.info(f"Executing node {node.id} ({node.type.value})")

        node_handlers = {
            NodeType.START: self._execute_start_node,
            NodeType.NOTE: self._execute_note_node,
            NodeType.AGENT: self._execute_agent_node,
            NodeType.STATE: self._execute_state_node,
            NodeType.CONDITION: self._execute_condition_node,
            NodeType.END: self._execute_end_node,
        }

        handler = node_handlers.get(node.type)
        if handler:
            yield from handler(node)

    def _execute_start_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, str], None, None]:
        yield from ()

    def _execute_note_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, str], None, None]:
        yield from ()

    def _execute_agent_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, str], None, None]:
        from application.core.model_utils import (
            get_api_key_for_provider,
            get_model_capabilities,
            get_provider_from_model_id,
        )

        node_config = AgentNodeConfig(**node.config.get("config", node.config))

        if node_config.sources:
            self._retrieve_node_sources(node_config)

        if node_config.prompt_template:
            formatted_prompt = self._format_template(node_config.prompt_template)
        else:
            formatted_prompt = self.state.get("query", "")
        node_json_schema = self._normalize_node_json_schema(
            node_config.json_schema, node.title
        )
        node_model_id = node_config.model_id or self.agent.model_id
        node_llm_name = (
            node_config.llm_name
            or get_provider_from_model_id(node_model_id or "")
            or self.agent.llm_name
        )
        node_api_key = get_api_key_for_provider(node_llm_name) or self.agent.api_key

        if node_json_schema and node_model_id:
            model_capabilities = get_model_capabilities(node_model_id)
            if model_capabilities and not model_capabilities.get(
                "supports_structured_output", False
            ):
                raise ValueError(
                    f'Model "{node_model_id}" does not support structured output for node "{node.title}"'
                )

        factory_kwargs = {
            "agent_type": node_config.agent_type,
            "endpoint": self.agent.endpoint,
            "llm_name": node_llm_name,
            "model_id": node_model_id,
            "api_key": node_api_key,
            "tool_ids": node_config.tools,
            "prompt": node_config.system_prompt,
            "chat_history": self.agent.chat_history,
            "decoded_token": self.agent.decoded_token,
            "json_schema": node_json_schema,
        }

        # Agentic/research agents need retriever_config for on-demand search
        if node_config.agent_type in (AgentType.AGENTIC, AgentType.RESEARCH):
            factory_kwargs["retriever_config"] = {
                "source": {"active_docs": node_config.sources} if node_config.sources else {},
                "retriever_name": node_config.retriever or "classic",
                "chunks": int(node_config.chunks) if node_config.chunks else 2,
                "model_id": node_model_id,
                "llm_name": node_llm_name,
                "api_key": node_api_key,
                "decoded_token": self.agent.decoded_token,
            }

        node_agent = WorkflowNodeAgentFactory.create(**factory_kwargs)

        full_response_parts: List[str] = []
        structured_response_parts: List[str] = []
        has_structured_response = False
        first_chunk = True
        for event in node_agent.gen(formatted_prompt):
            if "answer" in event:
                chunk = str(event["answer"])
                full_response_parts.append(chunk)
                if event.get("structured"):
                    has_structured_response = True
                    structured_response_parts.append(chunk)
                if node_config.stream_to_user:
                    if first_chunk and hasattr(self, "_has_streamed"):
                        yield {"answer": "\n\n"}
                    first_chunk = False
                    yield event

        if node_config.stream_to_user:
            self._has_streamed = True

        full_response = "".join(full_response_parts).strip()
        output_value: Any = full_response
        if has_structured_response:
            structured_response = "".join(structured_response_parts).strip()
            response_to_parse = structured_response or full_response
            parsed_success, parsed_structured = self._parse_structured_output(
                response_to_parse
            )
            output_value = parsed_structured if parsed_success else response_to_parse
            if node_json_schema:
                self._validate_structured_output(node_json_schema, output_value)
        elif node_json_schema:
            parsed_success, parsed_structured = self._parse_structured_output(
                full_response
            )
            if not parsed_success:
                raise ValueError(
                    "Structured output was expected but response was not valid JSON"
                )
            output_value = parsed_structured
            self._validate_structured_output(node_json_schema, output_value)

        default_output_key = f"node_{node.id}_output"
        self.state[default_output_key] = output_value

        if node_config.output_variable:
            self.state[node_config.output_variable] = output_value

    def _execute_state_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, str], None, None]:
        config = node.config.get("config", node.config)
        for op in config.get("operations", []):
            expression = op.get("expression", "")
            target_variable = op.get("target_variable", "")
            if expression and target_variable:
                self.state[target_variable] = evaluate_cel(expression, self.state)
        yield from ()

    def _execute_condition_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, str], None, None]:
        config = ConditionNodeConfig(**node.config.get("config", node.config))
        matched_handle = None

        for case in config.cases:
            if not case.expression.strip():
                continue
            try:
                if evaluate_cel(case.expression, self.state):
                    matched_handle = case.source_handle
                    break
            except CelEvaluationError:
                continue

        self._condition_result = matched_handle or "else"
        yield from ()

    def _execute_end_node(
        self, node: WorkflowNode
    ) -> Generator[Dict[str, str], None, None]:
        config = node.config.get("config", node.config)
        output_template = str(config.get("output_template", ""))
        if output_template:
            formatted_output = self._format_template(output_template)
            yield {"answer": formatted_output}

    def _parse_structured_output(self, raw_response: str) -> tuple[bool, Optional[Any]]:
        normalized_response = raw_response.strip()
        if not normalized_response:
            return False, None

        try:
            return True, json.loads(normalized_response)
        except json.JSONDecodeError:
            logger.warning(
                "Workflow agent returned structured output that was not valid JSON"
            )
            return False, None

    def _normalize_node_json_schema(
        self, schema: Optional[Dict[str, Any]], node_title: str
    ) -> Optional[Dict[str, Any]]:
        if schema is None:
            return None
        try:
            return normalize_json_schema_payload(schema)
        except JsonSchemaValidationError as exc:
            raise ValueError(
                f'Invalid JSON schema for node "{node_title}": {exc}'
            ) from exc

    def _validate_structured_output(self, schema: Dict[str, Any], output_value: Any) -> None:
        if jsonschema is None:
            logger.warning(
                "jsonschema package is not available, skipping structured output validation"
            )
            return

        try:
            normalized_schema = normalize_json_schema_payload(schema)
        except JsonSchemaValidationError as exc:
            raise ValueError(f"Invalid JSON schema: {exc}") from exc

        try:
            jsonschema.validate(instance=output_value, schema=normalized_schema)
        except jsonschema.exceptions.ValidationError as exc:
            raise ValueError(f"Structured output did not match schema: {exc.message}") from exc
        except jsonschema.exceptions.SchemaError as exc:
            raise ValueError(f"Invalid JSON schema: {exc.message}") from exc

    def _format_template(self, template: str) -> str:
        context = self._build_template_context()
        try:
            return self._template_engine.render(template, context)
        except TemplateRenderError as e:
            logger.warning(
                "Workflow template rendering failed, using raw template: %s", str(e)
            )
            return template

    def _build_template_context(self) -> Dict[str, Any]:
        docs, docs_together = self._get_source_template_data()
        passthrough_data = (
            self.state.get("passthrough")
            if isinstance(self.state.get("passthrough"), dict)
            else None
        )
        tools_data = (
            self.state.get("tools") if isinstance(self.state.get("tools"), dict) else None
        )

        context = self._namespace_manager.build_context(
            user_id=getattr(self.agent, "user", None),
            request_id=getattr(self.agent, "request_id", None),
            passthrough_data=passthrough_data,
            docs=docs,
            docs_together=docs_together,
            tools_data=tools_data,
        )

        agent_context: Dict[str, Any] = {}
        for key, value in self.state.items():
            if not isinstance(key, str):
                continue
            normalized_key = key.strip()
            if not normalized_key:
                continue
            agent_context[normalized_key] = value

        context["agent"] = agent_context

        # Keep legacy top-level variables working while namespaced variables are adopted.
        for key, value in agent_context.items():
            if key in TEMPLATE_RESERVED_NAMESPACES:
                context[f"agent_{key}"] = value
                continue
            if key not in context:
                context[key] = value

        return context

    def _get_source_template_data(self) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
        docs = getattr(self.agent, "retrieved_docs", None)
        if not isinstance(docs, list) or len(docs) == 0:
            return None, None

        docs_together_parts: List[str] = []
        for doc in docs:
            if not isinstance(doc, dict):
                continue
            text = doc.get("text")
            if not isinstance(text, str):
                continue

            filename = doc.get("filename") or doc.get("title") or doc.get("source")
            if isinstance(filename, str) and filename.strip():
                docs_together_parts.append(f"{filename}\n{text}")
            else:
                docs_together_parts.append(text)

        docs_together = "\n\n".join(docs_together_parts) if docs_together_parts else None
        return docs, docs_together

    def _retrieve_node_sources(self, node_config: AgentNodeConfig) -> None:
        """Retrieve documents from the node's sources for template resolution."""
        from application.retriever.retriever_creator import RetrieverCreator

        query = self.state.get("query", "")
        if not query:
            return

        try:
            retriever = RetrieverCreator.create_retriever(
                node_config.retriever or "classic",
                source={"active_docs": node_config.sources},
                chat_history=[],
                prompt="",
                chunks=int(node_config.chunks) if node_config.chunks else 2,
                decoded_token=self.agent.decoded_token,
            )
            docs = retriever.search(query)
            if docs:
                self.agent.retrieved_docs = docs
        except Exception:
            logger.exception("Failed to retrieve docs for workflow node")

    def get_execution_summary(self) -> List[NodeExecutionLog]:
        return [
            NodeExecutionLog(
                node_id=log["node_id"],
                node_type=log["node_type"],
                status=ExecutionStatus(log["status"]),
                started_at=log["started_at"],
                completed_at=log.get("completed_at"),
                error=log.get("error"),
                state_snapshot=log.get("state_snapshot", {}),
            )
            for log in self.execution_log
        ]
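A rough usage sketch of driving the engine above and consuming its event stream. It is not part of the diff; it assumes `graph` is an already-loaded WorkflowGraph and `agent` is a configured BaseAgent, neither of which is constructed in this changeset.

# Illustrative only: run a workflow and handle the three event shapes the
# engine yields (workflow_step progress, answer chunks, terminal errors).
engine = WorkflowEngine(graph=graph, agent=agent)
for event in engine.execute(initial_inputs={}, query="What is DocsGPT?"):
    if event.get("type") == "workflow_step":
        print(event["node_id"], event["status"])  # per-node progress
    elif event.get("type") == "error":
        print("workflow failed:", event["error"])
        break
    elif "answer" in event:
        print(event["answer"], end="")  # streamed answer chunks

summary = engine.get_execution_summary()  # List[NodeExecutionLog] for persistence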
application/alembic.ini (new file, 52 lines)
@@ -0,0 +1,52 @@
# Alembic configuration for the DocsGPT user-data Postgres database.
#
# The SQLAlchemy URL is deliberately NOT set here — env.py reads it from
# ``application.core.settings.settings.POSTGRES_URI`` so the same config
# source serves the running app and migrations. To run from the project
# root::
#
#     alembic -c application/alembic.ini upgrade head

[alembic]
script_location = %(here)s/alembic
prepend_sys_path = ..
version_path_separator = os

# sqlalchemy.url is intentionally left blank — env.py supplies it.
sqlalchemy.url =

[post_write_hooks]

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
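The CLI invocation shown in the header comment is the documented path. As a sketch, the same upgrade can also be triggered from Python with Alembic's command API, which is handy in bootstrap scripts; this assumes the process runs from the project root and that POSTGRES_URI is set, since env.py (next file) injects the URL from application settings.

# Illustrative only: programmatic equivalent of `alembic -c application/alembic.ini upgrade head`.
from alembic import command
from alembic.config import Config

cfg = Config("application/alembic.ini")
command.upgrade(cfg, "head")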
application/alembic/env.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""Alembic environment for the DocsGPT user-data Postgres database.

The URL is pulled from ``application.core.settings`` rather than
``alembic.ini`` so that a single ``POSTGRES_URI`` env var drives both the
running app and ``alembic`` CLI invocations.
"""

import sys
from logging.config import fileConfig
from pathlib import Path

# Make the project root importable regardless of cwd. env.py lives at
# <repo>/application/alembic/env.py, so parents[2] is the repo root.
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from alembic import context  # noqa: E402
from sqlalchemy import engine_from_config, pool  # noqa: E402

from application.core.settings import settings  # noqa: E402
from application.storage.db.models import metadata as target_metadata  # noqa: E402

config = context.config

# Populate the runtime URL from settings.
if settings.POSTGRES_URI:
    config.set_main_option("sqlalchemy.url", settings.POSTGRES_URI)

if config.config_file_name is not None:
    fileConfig(config.config_file_name)


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode (emits SQL without a live DB)."""
    url = config.get_main_option("sqlalchemy.url")
    if not url:
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True,
    )
    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode against a live connection."""
    if not config.get_main_option("sqlalchemy.url"):
        raise RuntimeError(
            "POSTGRES_URI is not configured. Set it in your .env to a "
            "psycopg3 URI such as "
            "'postgresql+psycopg://user:pass@host:5432/docsgpt'."
        )
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
        future=True,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True,
        )
        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
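The offline branch above exists so DBAs can review DDL before it touches a database. A small sketch of exercising it, assuming the same Config object as in the earlier example; passing sql=True makes Alembic run env.py in offline mode and print the generated SQL instead of connecting.

# Illustrative only: emit the migration SQL for review without a live DB.
from alembic import command
from alembic.config import Config

cfg = Config("application/alembic.ini")
command.upgrade(cfg, "head", sql=True)  # prints DDL to stdout via run_migrations_offline()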
application/alembic/script.py.mako (new file, 26 lines)
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
application/alembic/versions/0001_initial.py (new file, 927 lines)
@@ -0,0 +1,927 @@
"""0001 initial schema — consolidated Phase-1..3 baseline.

Revision ID: 0001_initial
Revises:
Create Date: 2026-04-13
"""

from typing import Sequence, Union

from alembic import op


revision: str = "0001_initial"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ------------------------------------------------------------------
    # Extensions
    # ------------------------------------------------------------------
    op.execute('CREATE EXTENSION IF NOT EXISTS "pgcrypto";')
    op.execute('CREATE EXTENSION IF NOT EXISTS "citext";')

    # ------------------------------------------------------------------
    # Trigger functions
    # ------------------------------------------------------------------
    op.execute(
        """
        CREATE FUNCTION set_updated_at() RETURNS trigger
        LANGUAGE plpgsql AS $$
        BEGIN
            NEW.updated_at = now();
            RETURN NEW;
        END;
        $$;
        """
    )

    op.execute(
        """
        CREATE FUNCTION ensure_user_exists() RETURNS trigger
        LANGUAGE plpgsql AS $$
        BEGIN
            IF NEW.user_id IS NOT NULL THEN
                INSERT INTO users (user_id) VALUES (NEW.user_id)
                ON CONFLICT (user_id) DO NOTHING;
            END IF;
            RETURN NEW;
        END;
        $$;
        """
    )

    op.execute(
        """
        CREATE FUNCTION cleanup_message_attachment_refs() RETURNS trigger
        LANGUAGE plpgsql AS $$
        BEGIN
            UPDATE conversation_messages
            SET attachments = array_remove(attachments, OLD.id)
            WHERE OLD.id = ANY(attachments);
            RETURN OLD;
        END;
        $$;
        """
    )

    op.execute(
        """
        CREATE FUNCTION cleanup_agent_extra_source_refs() RETURNS trigger
        LANGUAGE plpgsql AS $$
        BEGIN
            UPDATE agents
            SET extra_source_ids = array_remove(extra_source_ids, OLD.id)
            WHERE OLD.id = ANY(extra_source_ids);
            RETURN OLD;
        END;
        $$;
        """
    )

    op.execute(
        """
        CREATE FUNCTION cleanup_user_agent_prefs() RETURNS trigger
        LANGUAGE plpgsql AS $$
        DECLARE
            agent_id_text text := OLD.id::text;
        BEGIN
            UPDATE users
            SET agent_preferences = jsonb_set(
                jsonb_set(
                    agent_preferences,
                    '{pinned}',
                    COALESCE((
                        SELECT jsonb_agg(e)
                        FROM jsonb_array_elements(
                            COALESCE(agent_preferences->'pinned', '[]'::jsonb)
                        ) e
                        WHERE (e #>> '{}') <> agent_id_text
                    ), '[]'::jsonb)
                ),
                '{shared_with_me}',
                COALESCE((
                    SELECT jsonb_agg(e)
                    FROM jsonb_array_elements(
                        COALESCE(agent_preferences->'shared_with_me', '[]'::jsonb)
                    ) e
                    WHERE (e #>> '{}') <> agent_id_text
                ), '[]'::jsonb)
            )
            WHERE agent_preferences->'pinned' @> to_jsonb(agent_id_text)
               OR agent_preferences->'shared_with_me' @> to_jsonb(agent_id_text);
            RETURN OLD;
        END;
        $$;
        """
    )

    op.execute(
        """
        CREATE FUNCTION conversation_messages_fill_user_id() RETURNS trigger
        LANGUAGE plpgsql AS $$
        BEGIN
            IF NEW.user_id IS NULL THEN
                SELECT user_id INTO NEW.user_id
                FROM conversations
                WHERE id = NEW.conversation_id;
            END IF;
            RETURN NEW;
        END;
        $$;
        """
    )

    # ------------------------------------------------------------------
    # Tables
    # ------------------------------------------------------------------
    op.execute(
        """
        CREATE TABLE users (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL UNIQUE,
            agent_preferences JSONB NOT NULL
                DEFAULT '{"pinned": [], "shared_with_me": []}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )

    op.execute(
        """
        CREATE TABLE prompts (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE user_tools (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            custom_name TEXT,
            display_name TEXT,
            description TEXT,
            config JSONB NOT NULL DEFAULT '{}'::jsonb,
            config_requirements JSONB NOT NULL DEFAULT '{}'::jsonb,
            actions JSONB NOT NULL DEFAULT '[]'::jsonb,
            status BOOLEAN NOT NULL DEFAULT true,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE token_usage (
            id BIGSERIAL PRIMARY KEY,
            user_id TEXT,
            api_key TEXT,
            agent_id UUID,
            prompt_tokens INTEGER NOT NULL DEFAULT 0,
            generated_tokens INTEGER NOT NULL DEFAULT 0,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
            mongo_id TEXT
        );
        """
    )
    op.execute(
        "ALTER TABLE token_usage ADD CONSTRAINT token_usage_attribution_chk "
        "CHECK (user_id IS NOT NULL OR api_key IS NOT NULL) NOT VALID;"
    )

    op.execute(
        """
        CREATE TABLE user_logs (
            id BIGSERIAL PRIMARY KEY,
            user_id TEXT,
            endpoint TEXT,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
            data JSONB,
            mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE stack_logs (
            id BIGSERIAL PRIMARY KEY,
            activity_id TEXT NOT NULL,
            endpoint TEXT,
            level TEXT,
            user_id TEXT,
            api_key TEXT,
            query TEXT,
            stacks JSONB NOT NULL DEFAULT '[]'::jsonb,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
            mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE agent_folders (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            parent_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE sources (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            language TEXT,
            date TIMESTAMPTZ NOT NULL DEFAULT now(),
            model TEXT,
            type TEXT,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            retriever TEXT,
            sync_frequency TEXT,
            tokens TEXT,
            file_path TEXT,
            remote_data JSONB,
            directory_structure JSONB,
            file_name_map JSONB,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE agents (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            agent_type TEXT,
            status TEXT NOT NULL,
            key CITEXT UNIQUE,
            image TEXT,
            source_id UUID REFERENCES sources(id) ON DELETE SET NULL,
            extra_source_ids UUID[] NOT NULL DEFAULT '{}',
            chunks INTEGER,
            retriever TEXT,
            prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
            tools JSONB NOT NULL DEFAULT '[]'::jsonb,
            json_schema JSONB,
            models JSONB,
            default_model_id TEXT,
            folder_id UUID REFERENCES agent_folders(id) ON DELETE SET NULL,
            workflow_id UUID,
            limited_token_mode BOOLEAN NOT NULL DEFAULT false,
            token_limit INTEGER,
            limited_request_mode BOOLEAN NOT NULL DEFAULT false,
            request_limit INTEGER,
            allow_system_prompt_override BOOLEAN NOT NULL DEFAULT false,
            shared BOOLEAN NOT NULL DEFAULT false,
            shared_token CITEXT UNIQUE,
            shared_metadata JSONB,
            incoming_webhook_token CITEXT UNIQUE,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            last_used_at TIMESTAMPTZ,
            legacy_mongo_id TEXT
        );
        """
    )
    op.execute(
        "ALTER TABLE token_usage ADD CONSTRAINT token_usage_agent_fk "
        "FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE SET NULL;"
    )

    op.execute(
        """
        CREATE TABLE attachments (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            filename TEXT NOT NULL,
            upload_path TEXT NOT NULL,
            mime_type TEXT,
            size BIGINT,
            content TEXT,
            token_count INTEGER,
            openai_file_id TEXT,
            google_file_uri TEXT,
            metadata JSONB,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE memories (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            path TEXT NOT NULL,
            content TEXT NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )

    op.execute(
        """
        CREATE TABLE todos (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            todo_id INTEGER,
            title TEXT NOT NULL,
            completed BOOLEAN NOT NULL DEFAULT false,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE notes (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            tool_id UUID REFERENCES user_tools(id) ON DELETE CASCADE,
            title TEXT NOT NULL,
            content TEXT NOT NULL,
            metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE connector_sessions (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            provider TEXT NOT NULL,
            server_url TEXT,
            session_token TEXT UNIQUE,
            user_email TEXT,
            status TEXT,
            token_info JSONB,
            session_data JSONB NOT NULL DEFAULT '{}'::jsonb,
            expires_at TIMESTAMPTZ,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            legacy_mongo_id TEXT
        );
        """
    )

    op.execute(
        """
        CREATE TABLE conversations (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            agent_id UUID REFERENCES agents(id) ON DELETE SET NULL,
            name TEXT,
            api_key TEXT,
            is_shared_usage BOOLEAN NOT NULL DEFAULT false,
            shared_token TEXT,
            date TIMESTAMPTZ NOT NULL DEFAULT now(),
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            shared_with TEXT[] NOT NULL DEFAULT '{}'::text[],
            compression_metadata JSONB,
            legacy_mongo_id TEXT,
            CONSTRAINT conversations_api_key_nonempty_chk
                CHECK (api_key IS NULL OR api_key <> '')
        );
        """
    )

    op.execute(
        """
        CREATE TABLE conversation_messages (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            position INTEGER NOT NULL,
            prompt TEXT,
            response TEXT,
            thought TEXT,
            sources JSONB NOT NULL DEFAULT '[]'::jsonb,
            tool_calls JSONB NOT NULL DEFAULT '[]'::jsonb,
            attachments UUID[] NOT NULL DEFAULT '{}'::uuid[],
            model_id TEXT,
            message_metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
            feedback JSONB,
            timestamp TIMESTAMPTZ NOT NULL DEFAULT now(),
            user_id TEXT NOT NULL,
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        );
        """
    )

    op.execute(
        """
        CREATE TABLE shared_conversations (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            is_promptable BOOLEAN NOT NULL DEFAULT false,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            uuid UUID NOT NULL,
            first_n_queries INTEGER NOT NULL DEFAULT 0,
            api_key TEXT,
            prompt_id UUID REFERENCES prompts(id) ON DELETE SET NULL,
            chunks INTEGER,
            CONSTRAINT shared_conversations_api_key_nonempty_chk
                CHECK (api_key IS NULL OR api_key <> '')
        );
        """
    )

    op.execute(
        """
        CREATE TABLE pending_tool_state (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            conversation_id UUID NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            messages JSONB NOT NULL,
            pending_tool_calls JSONB NOT NULL,
            tools_dict JSONB NOT NULL,
            tool_schemas JSONB NOT NULL,
            agent_config JSONB NOT NULL,
            client_tools JSONB,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            expires_at TIMESTAMPTZ NOT NULL
        );
        """
    )

    op.execute(
        """
        CREATE TABLE workflows (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            description TEXT,
            created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            current_graph_version INTEGER NOT NULL DEFAULT 1,
            legacy_mongo_id TEXT
        );
        """
    )
    # Backfill the agents.workflow_id FK now that workflows exists.
    # The column was created without a FK (forward reference to a table
    # that hadn't been declared yet); add the constraint here so workflow
    # deletion still cascades through to agent unset.
    op.execute(
        "ALTER TABLE agents ADD CONSTRAINT agents_workflow_fk "
        "FOREIGN KEY (workflow_id) REFERENCES workflows(id) ON DELETE SET NULL;"
    )

    op.execute(
        """
        CREATE TABLE workflow_nodes (
            id UUID DEFAULT gen_random_uuid() NOT NULL,
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            graph_version INTEGER NOT NULL,
            node_type TEXT NOT NULL,
            config JSONB NOT NULL DEFAULT '{}'::jsonb,
            node_id TEXT NOT NULL,
            title TEXT,
            description TEXT,
            position JSONB NOT NULL DEFAULT '{"x": 0, "y": 0}'::jsonb,
            legacy_mongo_id TEXT,
            PRIMARY KEY (id),
            CONSTRAINT workflow_nodes_id_wf_ver_key
                UNIQUE (id, workflow_id, graph_version)
        );
        """
    )

    op.execute(
        """
        CREATE TABLE workflow_edges (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            graph_version INTEGER NOT NULL,
            from_node_id UUID NOT NULL,
            to_node_id UUID NOT NULL,
            config JSONB NOT NULL DEFAULT '{}'::jsonb,
            edge_id TEXT NOT NULL,
            source_handle TEXT,
            target_handle TEXT,
            CONSTRAINT workflow_edges_from_node_fk
                FOREIGN KEY (from_node_id, workflow_id, graph_version)
                REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE,
            CONSTRAINT workflow_edges_to_node_fk
                FOREIGN KEY (to_node_id, workflow_id, graph_version)
                REFERENCES workflow_nodes(id, workflow_id, graph_version) ON DELETE CASCADE
        );
        """
    )

    op.execute(
        """
        CREATE TABLE workflow_runs (
            id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
            workflow_id UUID NOT NULL REFERENCES workflows(id) ON DELETE CASCADE,
            user_id TEXT NOT NULL,
            status TEXT NOT NULL,
            started_at TIMESTAMPTZ NOT NULL DEFAULT now(),
            ended_at TIMESTAMPTZ,
            result JSONB,
            inputs JSONB,
            steps JSONB NOT NULL DEFAULT '[]'::jsonb,
            legacy_mongo_id TEXT,
            CONSTRAINT workflow_runs_status_chk
                CHECK (status IN ('pending', 'running', 'completed', 'failed'))
        );
        """
    )

    # ------------------------------------------------------------------
    # Indexes
    # ------------------------------------------------------------------
    op.execute("CREATE INDEX agent_folders_user_idx ON agent_folders (user_id);")

    op.execute("CREATE INDEX agents_user_idx ON agents (user_id);")
    op.execute("CREATE INDEX agents_shared_idx ON agents (shared) WHERE shared = true;")
    op.execute("CREATE INDEX agents_status_idx ON agents (status);")
    op.execute("CREATE INDEX agents_source_id_idx ON agents (source_id);")
    op.execute("CREATE INDEX agents_prompt_id_idx ON agents (prompt_id);")
    op.execute("CREATE INDEX agents_folder_id_idx ON agents (folder_id);")
    op.execute(
        "CREATE UNIQUE INDEX agents_legacy_mongo_id_uidx "
        "ON agents (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    op.execute("CREATE INDEX attachments_user_idx ON attachments (user_id);")
    op.execute(
        "CREATE UNIQUE INDEX attachments_legacy_mongo_id_uidx "
        "ON attachments (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    op.execute(
        # MCP and OAuth connectors share the ``provider`` slot, so the
        # dedup key is ``(user_id, server_url, provider)``: MCP rows
        # differentiate by server_url (one per MCP server), OAuth rows
        # have server_url = NULL and differentiate by provider alone.
        # COALESCE lets NULL server_url participate in the constraint.
        "CREATE UNIQUE INDEX connector_sessions_user_endpoint_uidx "
        "ON connector_sessions (user_id, COALESCE(server_url, ''), provider);"
    )
    op.execute(
        "CREATE INDEX connector_sessions_expiry_idx "
        "ON connector_sessions (expires_at) WHERE expires_at IS NOT NULL;"
    )
    op.execute(
        "CREATE INDEX connector_sessions_server_url_idx "
        "ON connector_sessions (server_url) WHERE server_url IS NOT NULL;"
    )
    op.execute(
        "CREATE UNIQUE INDEX connector_sessions_legacy_mongo_id_uidx "
        "ON connector_sessions (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    op.execute(
        "CREATE UNIQUE INDEX conversation_messages_conv_pos_uidx "
        "ON conversation_messages (conversation_id, position);"
    )
    op.execute(
        "CREATE INDEX conversation_messages_user_ts_idx "
        "ON conversation_messages (user_id, timestamp DESC);"
    )

    op.execute("CREATE INDEX conversations_user_date_idx ON conversations (user_id, date DESC);")
    op.execute("CREATE INDEX conversations_agent_idx ON conversations (agent_id);")
    op.execute(
        "CREATE UNIQUE INDEX conversations_shared_token_uidx "
        "ON conversations (shared_token) WHERE shared_token IS NOT NULL;"
    )
    op.execute(
        "CREATE INDEX conversations_api_key_date_idx "
        "ON conversations (api_key, date DESC) WHERE api_key IS NOT NULL;"
    )
    op.execute(
        "CREATE UNIQUE INDEX conversations_legacy_mongo_id_uidx "
        "ON conversations (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    op.execute(
        "CREATE UNIQUE INDEX memories_user_tool_path_uidx "
        "ON memories (user_id, tool_id, path);"
    )
    op.execute(
        "CREATE UNIQUE INDEX memories_user_path_null_tool_uidx "
        "ON memories (user_id, path) WHERE tool_id IS NULL;"
    )
    op.execute(
        "CREATE INDEX memories_path_prefix_idx "
        "ON memories (user_id, tool_id, path text_pattern_ops);"
    )
    op.execute("CREATE INDEX memories_tool_id_idx ON memories (tool_id);")

    op.execute("CREATE UNIQUE INDEX notes_user_tool_uidx ON notes (user_id, tool_id);")
    op.execute("CREATE INDEX notes_tool_id_idx ON notes (tool_id);")
    op.execute(
        "CREATE UNIQUE INDEX notes_legacy_mongo_id_uidx "
        "ON notes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    op.execute(
        "CREATE UNIQUE INDEX pending_tool_state_conv_user_uidx "
        "ON pending_tool_state (conversation_id, user_id);"
    )
    op.execute(
        "CREATE INDEX pending_tool_state_expires_idx ON pending_tool_state (expires_at);"
    )
|
op.execute("CREATE INDEX prompts_user_id_idx ON prompts (user_id);")
|
||||||
|
op.execute(
|
||||||
|
"CREATE UNIQUE INDEX prompts_legacy_mongo_id_uidx "
|
||||||
|
"ON prompts (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute("CREATE INDEX shared_conversations_user_idx ON shared_conversations (user_id);")
|
||||||
|
    op.execute("CREATE INDEX shared_conversations_conv_idx ON shared_conversations (conversation_id);")
    op.execute(
        "CREATE INDEX shared_conversations_prompt_id_idx ON shared_conversations (prompt_id);"
    )
    op.execute(
        "CREATE UNIQUE INDEX shared_conversations_uuid_uidx ON shared_conversations (uuid);"
    )
    op.execute(
        "CREATE UNIQUE INDEX shared_conversations_dedup_uidx "
        "ON shared_conversations (conversation_id, user_id, is_promptable, first_n_queries, COALESCE(api_key, ''));"
    )

    op.execute("CREATE INDEX sources_user_idx ON sources (user_id);")
    op.execute(
        "CREATE UNIQUE INDEX sources_legacy_mongo_id_uidx "
        "ON sources (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )
    op.execute(
        "CREATE UNIQUE INDEX user_tools_legacy_mongo_id_uidx "
        "ON user_tools (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )
    op.execute(
        "CREATE UNIQUE INDEX agent_folders_legacy_mongo_id_uidx "
        "ON agent_folders (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )
    op.execute("CREATE INDEX agent_folders_parent_idx ON agent_folders (parent_id);")
    op.execute("CREATE INDEX agents_workflow_idx ON agents (workflow_id);")

    op.execute('CREATE INDEX stack_logs_timestamp_idx ON stack_logs ("timestamp" DESC);')
    op.execute('CREATE INDEX stack_logs_user_ts_idx ON stack_logs (user_id, "timestamp" DESC);')
    op.execute('CREATE INDEX stack_logs_level_ts_idx ON stack_logs (level, "timestamp" DESC);')
    op.execute("CREATE INDEX stack_logs_activity_idx ON stack_logs (activity_id);")
    op.execute(
        "CREATE UNIQUE INDEX stack_logs_mongo_id_uidx "
        "ON stack_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
    )

    op.execute("CREATE INDEX todos_user_tool_idx ON todos (user_id, tool_id);")
    op.execute("CREATE INDEX todos_tool_id_idx ON todos (tool_id);")
    op.execute(
        "CREATE UNIQUE INDEX todos_legacy_mongo_id_uidx "
        "ON todos (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )
    op.execute(
        "CREATE UNIQUE INDEX todos_tool_todo_id_uidx "
        "ON todos (tool_id, todo_id) WHERE todo_id IS NOT NULL;"
    )

    op.execute('CREATE INDEX token_usage_user_ts_idx ON token_usage (user_id, "timestamp" DESC);')
    op.execute('CREATE INDEX token_usage_key_ts_idx ON token_usage (api_key, "timestamp" DESC);')
    op.execute('CREATE INDEX token_usage_agent_ts_idx ON token_usage (agent_id, "timestamp" DESC);')
    op.execute(
        "CREATE UNIQUE INDEX token_usage_mongo_id_uidx "
        "ON token_usage (mongo_id) WHERE mongo_id IS NOT NULL;"
    )

    op.execute('CREATE INDEX user_logs_user_ts_idx ON user_logs (user_id, "timestamp" DESC);')
    op.execute(
        "CREATE UNIQUE INDEX user_logs_mongo_id_uidx "
        "ON user_logs (mongo_id) WHERE mongo_id IS NOT NULL;"
    )

    op.execute("CREATE INDEX user_tools_user_id_idx ON user_tools (user_id);")

    op.execute("CREATE INDEX workflow_edges_from_node_idx ON workflow_edges (from_node_id);")
    op.execute("CREATE INDEX workflow_edges_to_node_idx ON workflow_edges (to_node_id);")
    op.execute(
        "CREATE UNIQUE INDEX workflow_edges_wf_ver_eid_uidx "
        "ON workflow_edges (workflow_id, graph_version, edge_id);"
    )

    op.execute(
        "CREATE UNIQUE INDEX workflow_nodes_wf_ver_nid_uidx "
        "ON workflow_nodes (workflow_id, graph_version, node_id);"
    )
    op.execute(
        "CREATE UNIQUE INDEX workflow_nodes_legacy_mongo_id_uidx "
        "ON workflow_nodes (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    op.execute("CREATE INDEX workflow_runs_workflow_idx ON workflow_runs (workflow_id);")
    op.execute("CREATE INDEX workflow_runs_user_idx ON workflow_runs (user_id);")
    op.execute(
        "CREATE INDEX workflow_runs_status_started_idx "
        "ON workflow_runs (status, started_at DESC);"
    )
    op.execute(
        "CREATE UNIQUE INDEX workflow_runs_legacy_mongo_id_uidx "
        "ON workflow_runs (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    op.execute("CREATE INDEX workflows_user_idx ON workflows (user_id);")
    op.execute(
        "CREATE UNIQUE INDEX workflows_legacy_mongo_id_uidx "
        "ON workflows (legacy_mongo_id) WHERE legacy_mongo_id IS NOT NULL;"
    )

    # ------------------------------------------------------------------
    # user_id foreign keys (deferrable so backfills can stage rows)
    # ------------------------------------------------------------------
    user_fk_tables = (
        "agent_folders",
        "agents",
        "attachments",
        "connector_sessions",
        "conversation_messages",
        "conversations",
        "memories",
        "notes",
        "pending_tool_state",
        "prompts",
        "shared_conversations",
        "sources",
        "stack_logs",
        "todos",
        "token_usage",
        "user_logs",
        "user_tools",
        "workflow_runs",
        "workflows",
    )
    for table in user_fk_tables:
        op.execute(
            f"ALTER TABLE {table} "
            f"ADD CONSTRAINT {table}_user_id_fk "
            f"FOREIGN KEY (user_id) REFERENCES users(user_id) "
            f"ON DELETE RESTRICT DEFERRABLE INITIALLY IMMEDIATE;"
        )

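    # Illustrative sketch (not part of this migration): because the FKs above
    # are DEFERRABLE, a backfill script can stage child rows before the parent
    # users row exists, as long as both are present by COMMIT. Assuming a
    # SQLAlchemy engine `engine` and a hypothetical copy_rows() helper:
    #
    #     with engine.begin() as conn:
    #         conn.exec_driver_sql("SET CONSTRAINTS ALL DEFERRED;")
    #         copy_rows(conn, "conversations")  # children can land first
    #         copy_rows(conn, "users")          # parents arrive later
    #         # FK checks run once, at commit time
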
    # ------------------------------------------------------------------
    # Triggers
    # ------------------------------------------------------------------
    updated_at_tables = (
        "agent_folders",
        "agents",
        "conversation_messages",
        "conversations",
        "memories",
        "notes",
        "prompts",
        "sources",
        "todos",
        "user_tools",
        "users",
        "workflows",
    )
    for table in updated_at_tables:
        op.execute(
            f"CREATE TRIGGER {table}_set_updated_at "
            f"BEFORE UPDATE ON {table} "
            f"FOR EACH ROW WHEN (OLD.* IS DISTINCT FROM NEW.*) "
            f"EXECUTE FUNCTION set_updated_at();"
        )

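    # The set_updated_at() function referenced above is created earlier in
    # the migration, outside this excerpt. A typical PL/pgSQL definition,
    # shown here only as an illustrative assumption, would be:
    #
    #     CREATE FUNCTION set_updated_at() RETURNS trigger AS $$
    #     BEGIN
    #         NEW.updated_at = now();
    #         RETURN NEW;
    #     END;
    #     $$ LANGUAGE plpgsql;
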
    ensure_user_tables = (
        "agent_folders",
        "agents",
        "attachments",
        "connector_sessions",
        "conversation_messages",
        "conversations",
        "memories",
        "notes",
        "pending_tool_state",
        "prompts",
        "shared_conversations",
        "sources",
        "stack_logs",
        "todos",
        "token_usage",
        "user_logs",
        "user_tools",
        "workflow_runs",
        "workflows",
    )
    for table in ensure_user_tables:
        op.execute(
            f"CREATE TRIGGER {table}_ensure_user "
            f"BEFORE INSERT OR UPDATE OF user_id ON {table} "
            f"FOR EACH ROW EXECUTE FUNCTION ensure_user_exists();"
        )

    op.execute(
        "CREATE TRIGGER conversation_messages_fill_user "
        "BEFORE INSERT ON conversation_messages "
        "FOR EACH ROW EXECUTE FUNCTION conversation_messages_fill_user_id();"
    )

    op.execute(
        "CREATE TRIGGER attachments_cleanup_message_refs "
        "AFTER DELETE ON attachments "
        "FOR EACH ROW EXECUTE FUNCTION cleanup_message_attachment_refs();"
    )
    op.execute(
        "CREATE TRIGGER agents_cleanup_user_prefs "
        "AFTER DELETE ON agents "
        "FOR EACH ROW EXECUTE FUNCTION cleanup_user_agent_prefs();"
    )
    op.execute(
        "CREATE TRIGGER sources_cleanup_agent_extra_refs "
        "AFTER DELETE ON sources "
        "FOR EACH ROW EXECUTE FUNCTION cleanup_agent_extra_source_refs();"
    )

    # ------------------------------------------------------------------
    # Seed sentinel __system__ user (system/template sources attribute here)
    # ------------------------------------------------------------------
    op.execute(
        "INSERT INTO users (user_id) VALUES ('__system__') "
        "ON CONFLICT (user_id) DO NOTHING;"
    )

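    # ensure_user_exists() is likewise defined earlier in the migration. Going
    # by its name and the ON CONFLICT seed above, a plausible sketch (an
    # assumption, not the actual definition) is a trigger function that
    # upserts the referenced users row before the insert/update proceeds:
    #
    #     CREATE FUNCTION ensure_user_exists() RETURNS trigger AS $$
    #     BEGIN
    #         INSERT INTO users (user_id) VALUES (NEW.user_id)
    #         ON CONFLICT (user_id) DO NOTHING;
    #         RETURN NEW;
    #     END;
    #     $$ LANGUAGE plpgsql;
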
def downgrade() -> None:
    # Nuclear downgrade: drop everything this migration created. The
    # ordering drops FK-bearing children before parents; CASCADE would
    # also work but explicit ordering is easier to reason about in code
    # review.
    tables_in_drop_order = (
        "workflow_edges",
        "workflow_runs",
        "workflow_nodes",
        "workflows",
        "pending_tool_state",
        "shared_conversations",
        "conversation_messages",
        "conversations",
        "connector_sessions",
        "notes",
        "todos",
        "memories",
        "attachments",
        "agents",
        "sources",
        "agent_folders",
        "stack_logs",
        "user_logs",
        "token_usage",
        "user_tools",
        "prompts",
        "users",
    )
    for table in tables_in_drop_order:
        op.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")

    for fn in (
        "conversation_messages_fill_user_id",
        "cleanup_user_agent_prefs",
        "cleanup_agent_extra_source_refs",
        "cleanup_message_attachment_refs",
        "ensure_user_exists",
        "set_updated_at",
    ):
        op.execute(f"DROP FUNCTION IF EXISTS {fn}();")
@@ -3,6 +3,7 @@ from flask import Blueprint
 from application.api import api
 from application.api.answer.routes.answer import AnswerResource
 from application.api.answer.routes.base import answer_ns
+from application.api.answer.routes.search import SearchResource
 from application.api.answer.routes.stream import StreamResource


@@ -14,6 +15,7 @@ api.add_namespace(answer_ns)
 def init_answer_routes():
     api.add_resource(StreamResource, "/stream")
     api.add_resource(AnswerResource, "/api/answer")
+    api.add_resource(SearchResource, "/api/search")


 init_answer_routes()
@@ -40,9 +40,9 @@ class AnswerResource(Resource, BaseAnswerResource):
             "chunks": fields.Integer(
                 required=False, default=2, description="Number of chunks"
             ),
-            "token_limit": fields.Integer(required=False, description="Token limit"),
             "retriever": fields.String(required=False, description="Retriever type"),
             "api_key": fields.String(required=False, description="API key"),
+            "agent_id": fields.String(required=False, description="Agent ID"),
             "active_docs": fields.String(
                 required=False, description="Active documents"
             ),
@@ -74,69 +74,80 @@ class AnswerResource(Resource, BaseAnswerResource):
         decoded_token = getattr(request, "decoded_token", None)
         processor = StreamProcessor(data, decoded_token)
         try:
-            processor.initialize()
-            if not processor.decoded_token:
-                return make_response({"error": "Unauthorized"}, 401)
-
-            docs_together, docs_list = processor.pre_fetch_docs(
-                data.get("question", "")
-            )
-            tools_data = processor.pre_fetch_tools()
-
-            agent = processor.create_agent(
-                docs_together=docs_together,
-                docs=docs_list,
-                tools_data=tools_data,
-            )
-
-            if error := self.check_usage(processor.agent_config):
-                return error
-
-            stream = self.complete_stream(
-                question=data["question"],
-                agent=agent,
-                conversation_id=processor.conversation_id,
-                user_api_key=processor.agent_config.get("user_api_key"),
-                decoded_token=processor.decoded_token,
-                isNoneDoc=data.get("isNoneDoc"),
-                index=None,
-                should_save_conversation=data.get("save_conversation", True),
-                model_id=processor.model_id,
-            )
+            # ---- Continuation mode ----
+            if data.get("tool_actions"):
+                (
+                    agent,
+                    messages,
+                    tools_dict,
+                    pending_tool_calls,
+                    tool_actions,
+                ) = processor.resume_from_tool_actions(
+                    data["tool_actions"], data["conversation_id"]
+                )
+                if not processor.decoded_token:
+                    return make_response({"error": "Unauthorized"}, 401)
+                if error := self.check_usage(processor.agent_config):
+                    return error
+                stream = self.complete_stream(
+                    question="",
+                    agent=agent,
+                    conversation_id=processor.conversation_id,
+                    user_api_key=processor.agent_config.get("user_api_key"),
+                    decoded_token=processor.decoded_token,
+                    agent_id=processor.agent_id,
+                    model_id=processor.model_id,
+                    _continuation={
+                        "messages": messages,
+                        "tools_dict": tools_dict,
+                        "pending_tool_calls": pending_tool_calls,
+                        "tool_actions": tool_actions,
+                    },
+                )
+            else:
+                # ---- Normal mode ----
+                agent = processor.build_agent(data.get("question", ""))
+                if not processor.decoded_token:
+                    return make_response({"error": "Unauthorized"}, 401)
+                if error := self.check_usage(processor.agent_config):
+                    return error
+
+                stream = self.complete_stream(
+                    question=data["question"],
+                    agent=agent,
+                    conversation_id=processor.conversation_id,
+                    user_api_key=processor.agent_config.get("user_api_key"),
+                    decoded_token=processor.decoded_token,
+                    isNoneDoc=data.get("isNoneDoc"),
+                    index=None,
+                    should_save_conversation=data.get("save_conversation", True),
+                    agent_id=processor.agent_id,
+                    is_shared_usage=processor.is_shared_usage,
+                    shared_token=processor.shared_token,
+                    model_id=processor.model_id,
+                )
+
             stream_result = self.process_response_stream(stream)
-
-            if len(stream_result) == 7:
-                (
-                    conversation_id,
-                    response,
-                    sources,
-                    tool_calls,
-                    thought,
-                    error,
-                    structured_info,
-                ) = stream_result
-            else:
-                conversation_id, response, sources, tool_calls, thought, error = (
-                    stream_result
-                )
-                structured_info = None
-
-            if error:
-                return make_response({"error": error}, 400)
+            if stream_result["error"]:
+                return make_response({"error": stream_result["error"]}, 400)
+
             result = {
-                "conversation_id": conversation_id,
-                "answer": response,
-                "sources": sources,
-                "tool_calls": tool_calls,
-                "thought": thought,
+                "conversation_id": stream_result["conversation_id"],
+                "answer": stream_result["answer"],
+                "sources": stream_result["sources"],
+                "tool_calls": stream_result["tool_calls"],
+                "thought": stream_result["thought"],
             }
-            if structured_info:
-                result.update(structured_info)
+            extra_info = stream_result.get("extra")
+            if extra_info:
+                result.update(extra_info)
         except Exception as e:
             logger.error(
                 f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}",
                 extra={"error": str(e), "traceback": traceback.format_exc()},
             )
-            return make_response({"error": str(e)}, 500)
+            return make_response({"error": "An error occurred processing your request"}, 500)
         return make_response(result, 200)
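For illustration, resuming a paused run against the non-streaming endpoint is a plain POST carrying the fields the handler above reads; the endpoint path, tool_actions and conversation_id keys come from the diff, while the host, port, auth header and the shape of a tool_actions item are assumptions in this sketch.

import requests

resp = requests.post(
    "http://localhost:7091/api/answer",                 # host/port assumed
    headers={"Authorization": "Bearer <jwt>"},           # auth scheme assumed
    json={
        "conversation_id": "2f0a9a8e-...",               # id returned by the paused stream
        "tool_actions": [
            {"call_id": "call_123", "action": "approve"}  # item shape assumed
        ],
    },
)
print(resp.json()["answer"])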
@@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Optional
 from flask import jsonify, make_response, Response
 from flask_restx import Namespace

+from application.api.answer.services.continuation_service import ContinuationService
 from application.api.answer.services.conversation_service import ConversationService
 from application.core.model_utils import (
     get_api_key_for_provider,
@@ -13,9 +14,13 @@ from application.core.model_utils import (
     get_provider_from_model_id,
 )

-from application.core.mongo_db import MongoDB
 from application.core.settings import settings
+from application.error import sanitize_api_error
 from application.llm.llm_creator import LLMCreator
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.token_usage import TokenUsageRepository
+from application.storage.db.repositories.user_logs import UserLogsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields

 logger = logging.getLogger(__name__)
@@ -28,17 +33,22 @@ class BaseAnswerResource:
     """Shared base class for answer endpoints"""

     def __init__(self):
-        mongo = MongoDB.get_client()
-        db = mongo[settings.MONGO_DB_NAME]
-        self.db = db
-        self.user_logs_collection = db["user_logs"]
         self.default_model_id = get_default_model_id()
         self.conversation_service = ConversationService()

     def validate_request(
         self, data: Dict[str, Any], require_conversation_id: bool = False
     ) -> Optional[Response]:
-        """Common request validation"""
+        """Common request validation.
+
+        Continuation requests (``tool_actions`` present) require
+        ``conversation_id`` but not ``question``.
+        """
+        if data.get("tool_actions"):
+            # Continuation mode — question is not required
+            if missing := check_required_fields(data, ["conversation_id"]):
+                return missing
+            return None
         required_fields = ["question"]
         if require_conversation_id:
             required_fields.append("conversation_id")
@@ -46,6 +56,27 @@ class BaseAnswerResource:
             return missing_fields
         return None

+    @staticmethod
+    def _prepare_tool_calls_for_logging(
+        tool_calls: Optional[List[Dict[str, Any]]], max_chars: int = 10000
+    ) -> List[Dict[str, Any]]:
+        if not tool_calls:
+            return []
+
+        prepared = []
+        for tool_call in tool_calls:
+            if not isinstance(tool_call, dict):
+                prepared.append({"result": str(tool_call)[:max_chars]})
+                continue
+
+            item = dict(tool_call)
+            for key in ("result", "result_full"):
+                value = item.get(key)
+                if isinstance(value, str) and len(value) > max_chars:
+                    item[key] = value[:max_chars]
+            prepared.append(item)
+        return prepared
+
     def check_usage(self, agent_config: Dict) -> Optional[Response]:
         """Check if there is a usage limit and if it is exceeded
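A quick sketch of what the helper above does to oversized tool results (standalone; the stub tool call is invented purely for illustration):

from application.api.answer.routes.base import BaseAnswerResource

oversized = [{"action_name": "web_search", "result": "x" * 50_000}]
trimmed = BaseAnswerResource._prepare_tool_calls_for_logging(oversized)
assert len(trimmed[0]["result"]) == 10000          # truncated to max_chars
assert trimmed[0]["action_name"] == "web_search"   # other keys pass through unchanged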
@@ -59,8 +90,8 @@ class BaseAnswerResource:
         api_key = agent_config.get("user_api_key")
         if not api_key:
             return None
-        agents_collection = self.db["agents"]
-        agent = agents_collection.find_one({"key": api_key})
+        with db_readonly() as conn:
+            agent = AgentsRepository(conn).find_by_key(api_key)

         if not agent:
             return make_response(
@@ -81,41 +112,32 @@ class BaseAnswerResource:
         )

         token_limit = int(
-            agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"])
+            agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"]
         )
         request_limit = int(
-            agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"])
+            agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"]
         )

-        token_usage_collection = self.db["token_usage"]
-
-        end_date = datetime.datetime.now()
+        end_date = datetime.datetime.now(datetime.timezone.utc)
         start_date = end_date - datetime.timedelta(hours=24)

-        match_query = {
-            "timestamp": {"$gte": start_date, "$lte": end_date},
-            "api_key": api_key,
-        }
-
-        if limited_token_mode:
-            token_pipeline = [
-                {"$match": match_query},
-                {
-                    "$group": {
-                        "_id": None,
-                        "total_tokens": {
-                            "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
-                        },
-                    }
-                },
-            ]
-            token_result = list(token_usage_collection.aggregate(token_pipeline))
-            daily_token_usage = token_result[0]["total_tokens"] if token_result else 0
+        if limited_token_mode or limited_request_mode:
+            with db_readonly() as conn:
+                token_repo = TokenUsageRepository(conn)
+                if limited_token_mode:
+                    daily_token_usage = token_repo.sum_tokens_in_range(
+                        start=start_date, end=end_date, api_key=api_key,
+                    )
+                else:
+                    daily_token_usage = 0
+                if limited_request_mode:
+                    daily_request_usage = token_repo.count_in_range(
+                        start=start_date, end=end_date, api_key=api_key,
+                    )
+                else:
+                    daily_request_usage = 0
         else:
             daily_token_usage = 0
-        if limited_request_mode:
-            daily_request_usage = token_usage_collection.count_documents(match_query)
-        else:
             daily_request_usage = 0
         if not limited_token_mode and not limited_request_mode:
             return None
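The two repository calls above replace the Mongo aggregation and count_documents. The SQL they are assumed to run is sketched below; column names are inferred from the old pipeline and the real queries live in TokenUsageRepository, which this diff does not show.

from sqlalchemy import text

# Hypothetical equivalents of sum_tokens_in_range / count_in_range:
sum_tokens_sql = text(
    'SELECT COALESCE(SUM(prompt_tokens + generated_tokens), 0) '
    'FROM token_usage WHERE api_key = :api_key '
    'AND "timestamp" BETWEEN :start AND :end'
)
count_requests_sql = text(
    'SELECT COUNT(*) FROM token_usage WHERE api_key = :api_key '
    'AND "timestamp" BETWEEN :start AND :end'
)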
@@ -155,6 +177,7 @@ class BaseAnswerResource:
         is_shared_usage: bool = False,
         shared_token: Optional[str] = None,
         model_id: Optional[str] = None,
+        _continuation: Optional[Dict] = None,
     ) -> Generator[str, None, None]:
         """
         Generator function that streams the complete conversation response.
@@ -184,9 +207,23 @@ class BaseAnswerResource:
         is_structured = False
         schema_info = None
         structured_chunks = []
+        query_metadata = {}
+        paused = False

-        for line in agent.gen(query=question):
-            if "answer" in line:
+        if _continuation:
+            gen_iter = agent.gen_continuation(
+                messages=_continuation["messages"],
+                tools_dict=_continuation["tools_dict"],
+                pending_tool_calls=_continuation["pending_tool_calls"],
+                tool_actions=_continuation["tool_actions"],
+            )
+        else:
+            gen_iter = agent.gen(query=question)
+
+        for line in gen_iter:
+            if "metadata" in line:
+                query_metadata.update(line["metadata"])
+            elif "answer" in line:
                 response_full += str(line["answer"])
                 if line.get("structured"):
                     is_structured = True
@@ -219,8 +256,21 @@ class BaseAnswerResource:
                 data = json.dumps({"type": "thought", "thought": line["thought"]})
                 yield f"data: {data}\n\n"
             elif "type" in line:
-                data = json.dumps(line)
-                yield f"data: {data}\n\n"
+                if line.get("type") == "tool_calls_pending":
+                    # Save continuation state and end the stream
+                    paused = True
+                    data = json.dumps(line)
+                    yield f"data: {data}\n\n"
+                elif line.get("type") == "error":
+                    sanitized_error = {
+                        "type": "error",
+                        "error": sanitize_api_error(line.get("error", "An error occurred"))
+                    }
+                    data = json.dumps(sanitized_error)
+                    yield f"data: {data}\n\n"
+                else:
+                    data = json.dumps(line)
+                    yield f"data: {data}\n\n"
         if is_structured and structured_chunks:
             structured_data = {
                 "type": "structured_answer",
@@ -230,6 +280,93 @@ class BaseAnswerResource:
             }
             data = json.dumps(structured_data)
             yield f"data: {data}\n\n"
+
+        # ---- Paused: save continuation state and end stream early ----
+        if paused:
+            continuation = getattr(agent, "_pending_continuation", None)
+            if continuation:
+                # Ensure we have a conversation_id — create a partial
+                # conversation if this is the first turn.
+                if not conversation_id and should_save_conversation:
+                    try:
+                        provider = (
+                            get_provider_from_model_id(model_id)
+                            if model_id
+                            else settings.LLM_PROVIDER
+                        )
+                        sys_api_key = get_api_key_for_provider(
+                            provider or settings.LLM_PROVIDER
+                        )
+                        llm = LLMCreator.create_llm(
+                            provider or settings.LLM_PROVIDER,
+                            api_key=sys_api_key,
+                            user_api_key=user_api_key,
+                            decoded_token=decoded_token,
+                            model_id=model_id,
+                            agent_id=agent_id,
+                        )
+                        conversation_id = (
+                            self.conversation_service.save_conversation(
+                                None,
+                                question,
+                                response_full,
+                                thought,
+                                source_log_docs,
+                                tool_calls,
+                                llm,
+                                model_id or self.default_model_id,
+                                decoded_token,
+                                api_key=user_api_key,
+                                agent_id=agent_id,
+                                is_shared_usage=is_shared_usage,
+                                shared_token=shared_token,
+                            )
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to create conversation for continuation: {e}",
+                            exc_info=True,
+                        )
+
+                if conversation_id:
+                    try:
+                        cont_service = ContinuationService()
+                        cont_service.save_state(
+                            conversation_id=str(conversation_id),
+                            user=decoded_token.get("sub", "local"),
+                            messages=continuation["messages"],
+                            pending_tool_calls=continuation["pending_tool_calls"],
+                            tools_dict=continuation["tools_dict"],
+                            tool_schemas=getattr(agent, "tools", []),
+                            agent_config={
+                                "model_id": model_id or self.default_model_id,
+                                "llm_name": getattr(agent, "llm_name", settings.LLM_PROVIDER),
+                                "api_key": getattr(agent, "api_key", None),
+                                "user_api_key": user_api_key,
+                                "agent_id": agent_id,
+                                "agent_type": agent.__class__.__name__,
+                                "prompt": getattr(agent, "prompt", ""),
+                                "json_schema": getattr(agent, "json_schema", None),
+                                "retriever_config": getattr(agent, "retriever_config", None),
+                            },
+                            client_tools=getattr(
+                                agent.tool_executor, "client_tools", None
+                            ),
+                        )
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to save continuation state: {str(e)}",
+                            exc_info=True,
+                        )
+
+            id_data = {"type": "id", "id": str(conversation_id)}
+            data = json.dumps(id_data)
+            yield f"data: {data}\n\n"
+
+            data = json.dumps({"type": "end"})
+            yield f"data: {data}\n\n"
+            return
+
         if isNoneDoc:
             for doc in source_log_docs:
                 doc["source"] = "None"
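From the client's side, a paused run therefore ends with a short, predictable event tail: the pending tool calls, the conversation id, then end. A minimal SSE consumer sketch; the event names come from the code above, while the requests usage, host and port are assumptions.

import json
import requests

with requests.post("http://localhost:7091/stream",
                   json={"question": "..."}, stream=True) as resp:
    for raw in resp.iter_lines():
        if not raw.startswith(b"data: "):
            continue
        event = json.loads(raw[len(b"data: "):])
        if event.get("type") == "tool_calls_pending":
            pending = event["data"]["pending_tool_calls"]  # act on these, then resume
        elif event.get("type") == "id":
            conversation_id = event["id"]                  # needed for the resume call
        elif event.get("type") == "end":
            break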
@@ -246,6 +383,7 @@ class BaseAnswerResource:
                 user_api_key=user_api_key,
                 decoded_token=decoded_token,
                 model_id=model_id,
+                agent_id=agent_id,
             )

         if should_save_conversation:
@@ -265,6 +403,7 @@ class BaseAnswerResource:
                 is_shared_usage=is_shared_usage,
                 shared_token=shared_token,
                 attachment_ids=attachment_ids,
+                metadata=query_metadata if query_metadata else None,
             )
             # Persist compression metadata/summary if it exists and wasn't saved mid-execution
             compression_meta = getattr(agent, "compression_metadata", None)
@@ -292,14 +431,20 @@ class BaseAnswerResource:
         data = json.dumps(id_data)
         yield f"data: {data}\n\n"

+        tool_calls_for_logging = self._prepare_tool_calls_for_logging(
+            getattr(agent, "tool_calls", tool_calls) or tool_calls
+        )
+
         log_data = {
             "action": "stream_answer",
             "level": "info",
             "user": decoded_token.get("sub"),
             "api_key": user_api_key,
+            "agent_id": agent_id,
             "question": question,
             "response": response_full,
             "sources": source_log_docs,
+            "tool_calls": tool_calls_for_logging,
             "attachments": attachment_ids,
             "timestamp": datetime.datetime.now(datetime.timezone.utc),
         }
@@ -312,7 +457,18 @@ class BaseAnswerResource:
         for key, value in log_data.items():
             if isinstance(value, str) and len(value) > 10000:
                 log_data[key] = value[:10000]
-        self.user_logs_collection.insert_one(log_data)
+        try:
+            with db_session() as conn:
+                UserLogsRepository(conn).insert(
+                    user_id=log_data.get("user"),
+                    endpoint="stream_answer",
+                    data=log_data,
+                )
+        except Exception as log_err:
+            logger.error(
+                f"Failed to persist stream_answer user log: {log_err}",
+                exc_info=True,
+            )

         data = json.dumps({"type": "end"})
         yield f"data: {data}\n\n"
@@ -330,6 +486,7 @@ class BaseAnswerResource:
             api_key=settings.API_KEY,
             user_api_key=user_api_key,
             decoded_token=decoded_token,
+            agent_id=agent_id,
         )
         self.conversation_service.save_conversation(
             conversation_id,
@@ -347,6 +504,7 @@ class BaseAnswerResource:
             is_shared_usage=is_shared_usage,
             shared_token=shared_token,
             attachment_ids=attachment_ids,
+            metadata=query_metadata if query_metadata else None,
         )
         compression_meta = getattr(agent, "compression_metadata", None)
         compression_saved = getattr(agent, "compression_saved", False)
@@ -383,8 +541,13 @@ class BaseAnswerResource:
             yield f"data: {data}\n\n"
             return

-    def process_response_stream(self, stream):
-        """Process the stream response for non-streaming endpoint"""
+    def process_response_stream(self, stream) -> Dict[str, Any]:
+        """Process the stream response for non-streaming endpoint.
+
+        Returns:
+            Dict with keys: conversation_id, answer, sources, tool_calls,
+            thought, error, and optional extra.
+        """
         conversation_id = ""
         response_full = ""
         source_log_docs = []
@@ -393,6 +556,7 @@ class BaseAnswerResource:
         stream_ended = False
         is_structured = False
         schema_info = None
+        pending_tool_calls = None

         for line in stream:
             try:
@@ -411,11 +575,22 @@ class BaseAnswerResource:
                     source_log_docs = event["source"]
                 elif event["type"] == "tool_calls":
                     tool_calls = event["tool_calls"]
+                elif event["type"] == "tool_calls_pending":
+                    pending_tool_calls = event.get("data", {}).get(
+                        "pending_tool_calls", []
+                    )
                 elif event["type"] == "thought":
                     thought = event["thought"]
                 elif event["type"] == "error":
                     logger.error(f"Error from stream: {event['error']}")
-                    return None, None, None, None, event["error"], None
+                    return {
+                        "conversation_id": None,
+                        "answer": None,
+                        "sources": None,
+                        "tool_calls": None,
+                        "thought": None,
+                        "error": event["error"],
+                    }
                 elif event["type"] == "end":
                     stream_ended = True
             except (json.JSONDecodeError, KeyError) as e:
@@ -423,18 +598,30 @@ class BaseAnswerResource:
                 continue
         if not stream_ended:
             logger.error("Stream ended unexpectedly without an 'end' event.")
-            return None, None, None, None, "Stream ended unexpectedly", None
-        result = (
-            conversation_id,
-            response_full,
-            source_log_docs,
-            tool_calls,
-            thought,
-            None,
-        )
+            return {
+                "conversation_id": None,
+                "answer": None,
+                "sources": None,
+                "tool_calls": None,
+                "thought": None,
+                "error": "Stream ended unexpectedly",
+            }
+
+        result: Dict[str, Any] = {
+            "conversation_id": conversation_id,
+            "answer": response_full,
+            "sources": source_log_docs,
+            "tool_calls": tool_calls,
+            "thought": thought,
+            "error": None,
+        }
+
+        if pending_tool_calls is not None:
+            result["extra"] = {"pending_tool_calls": pending_tool_calls}
+
         if is_structured:
-            result = result + ({"structured": True, "schema": schema_info},)
+            result["extra"] = {"structured": True, "schema": schema_info}

         return result

     def error_stream_generate(self, err_response):
application/api/answer/routes/search.py (new file, 166 lines)
@@ -0,0 +1,166 @@
import logging
from typing import Any, Dict, List

from flask import make_response, request
from flask_restx import fields, Resource

from application.api.answer.routes.base import answer_ns
from application.core.settings import settings
from application.storage.db.repositories.agents import AgentsRepository
from application.storage.db.session import db_readonly
from application.vectorstore.vector_creator import VectorCreator

logger = logging.getLogger(__name__)


@answer_ns.route("/api/search")
class SearchResource(Resource):
    """Fast search endpoint for retrieving relevant documents"""

    search_model = answer_ns.model(
        "SearchModel",
        {
            "question": fields.String(
                required=True, description="Search query"
            ),
            "api_key": fields.String(
                required=True, description="API key for authentication"
            ),
            "chunks": fields.Integer(
                required=False, default=5, description="Number of results to return"
            ),
        },
    )

    def _get_sources_from_api_key(self, api_key: str) -> List[str]:
        """Get source IDs connected to the API key/agent."""
        with db_readonly() as conn:
            agent_data = AgentsRepository(conn).find_by_key(api_key)
        if not agent_data:
            return []

        source_ids: List[str] = []
        # extra_source_ids is a PG ARRAY(UUID) of source UUIDs.
        extra = agent_data.get("extra_source_ids") or []
        for src in extra:
            if src:
                source_ids.append(str(src))

        if not source_ids:
            single = agent_data.get("source_id")
            if single:
                source_ids.append(str(single))

        return source_ids

    def _search_vectorstores(
        self, query: str, source_ids: List[str], chunks: int
    ) -> List[Dict[str, Any]]:
        """Search across vectorstores and return results"""
        if not source_ids:
            return []

        results = []
        chunks_per_source = max(1, chunks // len(source_ids))
        seen_texts = set()

        for source_id in source_ids:
            if not source_id or not source_id.strip():
                continue

            try:
                docsearch = VectorCreator.create_vectorstore(
                    settings.VECTOR_STORE, source_id, settings.EMBEDDINGS_KEY
                )
                docs = docsearch.search(query, k=chunks_per_source * 2)

                for doc in docs:
                    if len(results) >= chunks:
                        break

                    if hasattr(doc, "page_content") and hasattr(doc, "metadata"):
                        page_content = doc.page_content
                        metadata = doc.metadata
                    else:
                        page_content = doc.get("text", doc.get("page_content", ""))
                        metadata = doc.get("metadata", {})

                    # Skip duplicates
                    text_hash = hash(page_content[:200])
                    if text_hash in seen_texts:
                        continue
                    seen_texts.add(text_hash)

                    title = metadata.get(
                        "title", metadata.get("post_title", "")
                    )
                    if not isinstance(title, str):
                        title = str(title) if title else ""

                    # Clean up title
                    if title:
                        title = title.split("/")[-1]
                    else:
                        # Use filename or first part of content as title
                        title = metadata.get("filename", page_content[:50] + "...")

                    source = metadata.get("source", source_id)

                    results.append({
                        "text": page_content,
                        "title": title,
                        "source": source,
                    })

                if len(results) >= chunks:
                    break

            except Exception as e:
                logger.error(
                    f"Error searching vectorstore {source_id}: {e}",
                    exc_info=True,
                )
                continue

        return results[:chunks]

    @answer_ns.expect(search_model)
    @answer_ns.doc(description="Search for relevant documents based on query")
    def post(self):
        data = request.get_json()

        question = data.get("question")
        api_key = data.get("api_key")
        chunks = data.get("chunks", 5)

        if not question:
            return make_response({"error": "question is required"}, 400)

        if not api_key:
            return make_response({"error": "api_key is required"}, 400)

        # Validate API key
        with db_readonly() as conn:
            agent = AgentsRepository(conn).find_by_key(api_key)
        if not agent:
            return make_response({"error": "Invalid API key"}, 401)

        try:
            # Get sources connected to this API key
            source_ids = self._get_sources_from_api_key(api_key)

            if not source_ids:
                return make_response([], 200)

            # Perform search
            results = self._search_vectorstores(question, source_ids, chunks)

            return make_response(results, 200)

        except Exception as e:
            logger.error(
                f"/api/search - error: {str(e)}",
                extra={"error": str(e)},
                exc_info=True,
            )
            return make_response({"error": "Search failed"}, 500)
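A request against the new endpoint, using only fields from the model above; the question text and the host/port are placeholders, and each returned hit carries text, title and source as built in _search_vectorstores.

import requests

resp = requests.post(
    "http://localhost:7091/api/search",      # host/port assumed
    json={
        "question": "how do I add a data source?",
        "api_key": "<agent api key>",
        "chunks": 5,
    },
)
for hit in resp.json():
    print(hit["title"], "->", hit["source"])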
@@ -40,9 +40,9 @@ class StreamResource(Resource, BaseAnswerResource):
             "chunks": fields.Integer(
                 required=False, default=2, description="Number of chunks"
             ),
-            "token_limit": fields.Integer(required=False, description="Token limit"),
             "retriever": fields.String(required=False, description="Retriever type"),
             "api_key": fields.String(required=False, description="API key"),
+            "agent_id": fields.String(required=False, description="Agent ID"),
             "active_docs": fields.String(
                 required=False, description="Active documents"
             ),
@@ -79,15 +79,54 @@ class StreamResource(Resource, BaseAnswerResource):
             return error
         decoded_token = getattr(request, "decoded_token", None)
         processor = StreamProcessor(data, decoded_token)

         try:
-            processor.initialize()
-
-            docs_together, docs_list = processor.pre_fetch_docs(data["question"])
-            tools_data = processor.pre_fetch_tools()
-
-            agent = processor.create_agent(
-                docs_together=docs_together, docs=docs_list, tools_data=tools_data
-            )
+            # ---- Continuation mode ----
+            if data.get("tool_actions"):
+                (
+                    agent,
+                    messages,
+                    tools_dict,
+                    pending_tool_calls,
+                    tool_actions,
+                ) = processor.resume_from_tool_actions(
+                    data["tool_actions"], data["conversation_id"]
+                )
+                if not processor.decoded_token:
+                    return Response(
+                        self.error_stream_generate("Unauthorized"),
+                        status=401,
+                        mimetype="text/event-stream",
+                    )
+                if error := self.check_usage(processor.agent_config):
+                    return error
+                return Response(
+                    self.complete_stream(
+                        question="",
+                        agent=agent,
+                        conversation_id=processor.conversation_id,
+                        user_api_key=processor.agent_config.get("user_api_key"),
+                        decoded_token=processor.decoded_token,
+                        agent_id=processor.agent_id,
+                        model_id=processor.model_id,
+                        _continuation={
+                            "messages": messages,
+                            "tools_dict": tools_dict,
+                            "pending_tool_calls": pending_tool_calls,
+                            "tool_actions": tool_actions,
+                        },
+                    ),
+                    mimetype="text/event-stream",
+                )
+
+            # ---- Normal mode ----
+            agent = processor.build_agent(data["question"])
+            if not processor.decoded_token:
+                return Response(
+                    self.error_stream_generate("Unauthorized"),
+                    status=401,
+                    mimetype="text/event-stream",
+                )

             if error := self.check_usage(processor.agent_config):
                 return error
@@ -102,7 +141,7 @@ class StreamResource(Resource, BaseAnswerResource):
                 index=data.get("index"),
                 should_save_conversation=data.get("save_conversation", True),
                 attachment_ids=data.get("attachments", []),
-                agent_id=data.get("agent_id"),
+                agent_id=processor.agent_id,
                 is_shared_usage=processor.is_shared_usage,
                 shared_token=processor.shared_token,
                 model_id=processor.model_id,
@@ -1,5 +1,6 @@
 """Message reconstruction utilities for compression."""

+import json
 import logging
 import uuid
 from typing import Dict, List, Optional
@@ -49,28 +50,35 @@ class MessageBuilder:
         if include_tool_calls and "tool_calls" in query:
             for tool_call in query["tool_calls"]:
                 call_id = tool_call.get("call_id") or str(uuid.uuid4())
-                function_call_dict = {
-                    "function_call": {
-                        "name": tool_call.get("action_name"),
-                        "args": tool_call.get("arguments"),
-                        "call_id": call_id,
-                    }
-                }
-                function_response_dict = {
-                    "function_response": {
-                        "name": tool_call.get("action_name"),
-                        "response": {"result": tool_call.get("result")},
-                        "call_id": call_id,
-                    }
-                }
-
-                messages.append(
-                    {"role": "assistant", "content": [function_call_dict]}
-                )
-                messages.append(
-                    {"role": "tool", "content": [function_response_dict]}
-                )
+                args = tool_call.get("arguments")
+                args_str = (
+                    json.dumps(args)
+                    if isinstance(args, dict)
+                    else (args or "{}")
+                )
+                messages.append({
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [{
+                        "id": call_id,
+                        "type": "function",
+                        "function": {
+                            "name": tool_call.get("action_name", ""),
+                            "arguments": args_str,
+                        },
+                    }],
+                })
+                result = tool_call.get("result")
+                result_str = (
+                    json.dumps(result)
+                    if not isinstance(result, str)
+                    else (result or "")
+                )
+                messages.append({
+                    "role": "tool",
+                    "tool_call_id": call_id,
+                    "content": result_str,
+                })

         # If no recent queries (everything was compressed), add a continuation user message
         if len(recent_queries) == 0 and compressed_summary:
@@ -180,28 +188,35 @@ class MessageBuilder:
         if include_tool_calls and "tool_calls" in query:
             for tool_call in query["tool_calls"]:
                 call_id = tool_call.get("call_id") or str(uuid.uuid4())
-                function_call_dict = {
-                    "function_call": {
-                        "name": tool_call.get("action_name"),
-                        "args": tool_call.get("arguments"),
-                        "call_id": call_id,
-                    }
-                }
-                function_response_dict = {
-                    "function_response": {
-                        "name": tool_call.get("action_name"),
-                        "response": {"result": tool_call.get("result")},
-                        "call_id": call_id,
-                    }
-                }
-
-                rebuilt_messages.append(
-                    {"role": "assistant", "content": [function_call_dict]}
-                )
-                rebuilt_messages.append(
-                    {"role": "tool", "content": [function_response_dict]}
-                )
+                args = tool_call.get("arguments")
+                args_str = (
+                    json.dumps(args)
+                    if isinstance(args, dict)
+                    else (args or "{}")
+                )
+                rebuilt_messages.append({
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [{
+                        "id": call_id,
+                        "type": "function",
+                        "function": {
+                            "name": tool_call.get("action_name", ""),
+                            "arguments": args_str,
+                        },
+                    }],
+                })
+                result = tool_call.get("result")
+                result_str = (
+                    json.dumps(result)
+                    if not isinstance(result, str)
+                    else (result or "")
+                )
+                rebuilt_messages.append({
+                    "role": "tool",
+                    "tool_call_id": call_id,
+                    "content": result_str,
+                })

         # If no recent queries (everything was compressed), add a continuation user message
         if len(recent_queries) == 0 and compressed_summary:
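For a single recorded tool call, the rebuilt history now carries the OpenAI-style assistant/tool pair shown above. Concretely, a hypothetical stored entry like {"action_name": "web_search", "arguments": {"q": "docsgpt"}, "result": "...", "call_id": "call_1"} is expanded into:

rebuilt = [
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_1",
            "type": "function",
            "function": {"name": "web_search", "arguments": '{"q": "docsgpt"}'},
        }],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "..."},
]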
@@ -134,6 +134,7 @@ class CompressionOrchestrator:
             user_api_key=None,
             decoded_token=decoded_token,
             model_id=compression_model,
+            agent_id=conversation.get("agent_id"),
         )

         # Create compression service with DB update capability
157
application/api/answer/services/continuation_service.py
Normal file
157
application/api/answer/services/continuation_service.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Service for saving and restoring tool-call continuation state.
|
||||||
|
|
||||||
|
When a stream pauses (tool needs approval or client-side execution),
|
||||||
|
the full execution state is persisted to Postgres so the client can
|
||||||
|
resume later by sending tool_actions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||||
|
from application.storage.db.repositories.pending_tool_state import (
|
||||||
|
PendingToolStateRepository,
|
||||||
|
)
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# TTL for pending states — auto-cleaned after this period
|
||||||
|
PENDING_STATE_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||||
|
|
||||||
|
|
||||||
|
def _make_serializable(obj: Any) -> Any:
|
||||||
|
"""Recursively coerce non-JSON values into JSON-safe forms.
|
||||||
|
|
||||||
|
Handles ``uuid.UUID`` (from PG columns), ``bytes``, and recurses into
|
||||||
|
dicts/lists. Post-Mongo-cutover the ObjectId branch is gone — none of
|
||||||
|
our writers produce them anymore.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, UUID):
|
||||||
|
return str(obj)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {str(k): _make_serializable(v) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_make_serializable(v) for v in obj]
|
||||||
|
if isinstance(obj, bytes):
|
||||||
|
return obj.decode("utf-8", errors="replace")
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuationService:
|
||||||
|
"""Manages pending tool-call state in Postgres."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# No-op constructor retained for call-site compatibility. State
|
||||||
|
# lives in Postgres now; each operation opens its own short-lived
|
||||||
|
# session rather than holding a connection on the service.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_state(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
user: str,
|
||||||
|
messages: List[Dict],
|
||||||
|
pending_tool_calls: List[Dict],
|
||||||
|
tools_dict: Dict,
|
||||||
|
tool_schemas: List[Dict],
|
||||||
|
agent_config: Dict,
|
||||||
|
client_tools: Optional[List[Dict]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Save execution state for later continuation.
|
||||||
|
|
||||||
|
``conversation_id`` may be a Postgres UUID or the legacy Mongo
|
||||||
|
``ObjectId`` string — the latter is resolved via
|
||||||
|
``conversations.legacy_mongo_id`` to find the matching row.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The conversation this state belongs to.
|
||||||
|
user: Owner user ID.
|
||||||
|
messages: Full messages array at the pause point.
|
||||||
|
pending_tool_calls: Tool calls awaiting client action.
|
||||||
|
tools_dict: Serializable tools configuration dict.
|
||||||
|
tool_schemas: LLM-formatted tool schemas (agent.tools).
|
||||||
|
agent_config: Config needed to recreate the agent on resume.
|
||||||
|
client_tools: Client-provided tool schemas for client-side execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string ID (conversation_id as provided) of the saved state.
|
||||||
|
"""
|
||||||
|
with db_session() as conn:
|
||||||
|
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||||
|
if conv is not None:
|
||||||
|
pg_conv_id = conv["id"]
|
||||||
|
elif looks_like_uuid(conversation_id):
|
||||||
|
pg_conv_id = conversation_id
|
||||||
|
else:
|
||||||
|
# Unresolvable legacy ObjectId — downstream ``CAST AS uuid``
|
||||||
|
# would raise and poison the save. Surface the mismatch so
|
||||||
|
# the caller can decide (the stream loop in routes/base.py
|
||||||
|
# already wraps this in try/except).
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot save continuation state: conversation_id "
|
||||||
|
f"{conversation_id!r} is neither a PG UUID nor a "
|
||||||
|
f"backfilled legacy Mongo id."
|
||||||
|
)
|
||||||
|
PendingToolStateRepository(conn).save_state(
|
||||||
|
pg_conv_id,
|
||||||
|
user,
|
||||||
|
messages=_make_serializable(messages),
|
||||||
|
pending_tool_calls=_make_serializable(pending_tool_calls),
|
||||||
|
tools_dict=_make_serializable(tools_dict),
|
||||||
|
tool_schemas=_make_serializable(tool_schemas),
|
||||||
|
agent_config=_make_serializable(agent_config),
|
||||||
|
client_tools=_make_serializable(client_tools) if client_tools else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Saved continuation state for conversation {conversation_id} "
|
||||||
|
f"with {len(pending_tool_calls)} pending tool call(s)"
|
||||||
|
)
|
||||||
|
return conversation_id
|
||||||
|
|
||||||
|
def load_state(
|
||||||
|
self, conversation_id: str, user: str
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Load pending continuation state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state dict, or None if no pending state exists.
|
||||||
|
"""
|
||||||
|
with db_readonly() as conn:
|
||||||
|
conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
|
||||||
|
if conv is not None:
|
||||||
|
pg_conv_id = conv["id"]
|
||||||
|
elif looks_like_uuid(conversation_id):
|
||||||
|
pg_conv_id = conversation_id
|
||||||
|
else:
|
||||||
|
# Unresolvable legacy ObjectId → no state can exist for it.
|
||||||
|
return None
|
||||||
|
doc = PendingToolStateRepository(conn).load_state(pg_conv_id, user)
|
||||||
|
if not doc:
|
||||||
|
return None
|
||||||
|
return doc
|
||||||
|
|
||||||
|
    def delete_state(self, conversation_id: str, user: str) -> bool:
        """Delete pending state after successful resumption.

        Returns:
            True if a row was deleted.
        """
        with db_session() as conn:
            conv = ConversationsRepository(conn).get_by_legacy_id(conversation_id)
            if conv is not None:
                pg_conv_id = conv["id"]
            elif looks_like_uuid(conversation_id):
                pg_conv_id = conversation_id
            else:
                # Unresolvable legacy ObjectId → nothing to delete.
                return False
            deleted = PendingToolStateRepository(conn).delete_state(pg_conv_id, user)
            if deleted:
                logger.info(
                    f"Deleted continuation state for conversation {conversation_id}"
                )
            return deleted
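The three methods above all route on the shape of the incoming id: resolvable legacy ids go through ``conversations.legacy_mongo_id``, dashed UUIDs are used as-is, and anything else is handled before it can reach ``CAST AS uuid``. A minimal stdlib sketch of that shape check follows; the repository's actual ``looks_like_uuid`` may be implemented differently.

import uuid

def looks_like_uuid_sketch(value: str) -> bool:
    """Return True only for strings that parse as a canonical dashed UUID."""
    try:
        return str(uuid.UUID(value)) == value.lower()
    except (TypeError, ValueError):
        return False

# A 24-hex-char Mongo ObjectId is not a UUID, so it falls through to the
# conversations.legacy_mongo_id lookup path instead of CAST AS uuid.
assert looks_like_uuid_sketch("123e4567-e89b-12d3-a456-426614174000")
assert not looks_like_uuid_sketch("507f1f77bcf86cd799439011")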
@@ -1,44 +1,51 @@
|
|||||||
|
"""Conversation persistence service backed by Postgres.
|
||||||
|
|
||||||
|
Handles create / append / update / compression for conversations during
|
||||||
|
the answer-streaming path. Connections are opened per-operation rather
|
||||||
|
than held for the duration of a stream.
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from application.core.mongo_db import MongoDB
|
from sqlalchemy import text as sql_text
|
||||||
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from bson import ObjectId
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.repositories.conversations import ConversationsRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ConversationService:
|
class ConversationService:
|
||||||
def __init__(self):
|
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
self.conversations_collection = db["conversations"]
|
|
||||||
self.agents_collection = db["agents"]
|
|
||||||
|
|
||||||
def get_conversation(
|
def get_conversation(
|
||||||
self, conversation_id: str, user_id: str
|
self, conversation_id: str, user_id: str
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""Retrieve a conversation with proper access control"""
|
"""Retrieve a conversation with owner-or-shared access control.
|
||||||
|
|
||||||
|
Returns a dict in the legacy Mongo shape — ``queries`` is a list
|
||||||
|
of message dicts (prompt/response/...) — for compatibility with
|
||||||
|
the streaming pipeline that consumes this shape.
|
||||||
|
"""
|
||||||
if not conversation_id or not user_id:
|
if not conversation_id or not user_id:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
conversation = self.conversations_collection.find_one(
|
with db_readonly() as conn:
|
||||||
{
|
repo = ConversationsRepository(conn)
|
||||||
"_id": ObjectId(conversation_id),
|
conv = repo.get_any(conversation_id, user_id)
|
||||||
"$or": [{"user": user_id}, {"shared_with": user_id}],
|
if conv is None:
|
||||||
}
|
logger.warning(
|
||||||
)
|
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
||||||
|
)
|
||||||
if not conversation:
|
return None
|
||||||
logger.warning(
|
messages = repo.get_messages(str(conv["id"]))
|
||||||
f"Conversation not found or unauthorized - ID: {conversation_id}, User: {user_id}"
|
conv["queries"] = messages
|
||||||
)
|
conv["_id"] = str(conv["id"])
|
||||||
return None
|
return conv
|
||||||
conversation["_id"] = str(conversation["_id"])
|
|
||||||
return conversation
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
logger.error(f"Error fetching conversation: {str(e)}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
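``get_conversation`` keeps returning the legacy Mongo shape so the streaming pipeline does not have to change. A hedged sketch of a downstream consumer of that shape; the field names come from the docstring above, the sample data is invented.

def render_transcript(conversation: dict) -> str:
    """Flatten the legacy ``queries`` list into a plain-text transcript."""
    lines = []
    for query in conversation.get("queries", []):
        lines.append(f"User: {query.get('prompt', '')}")
        lines.append(f"Assistant: {query.get('response', '')}")
    return "\n".join(lines)

sample = {
    "_id": "0f1d2c3b-4a5e-6789-abcd-ef0123456789",
    "queries": [{"prompt": "What changed?", "response": "Conversations now live in Postgres."}],
}
print(render_transcript(sample))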
@@ -60,79 +67,61 @@ class ConversationService:
|
|||||||
is_shared_usage: bool = False,
|
is_shared_usage: bool = False,
|
||||||
shared_token: Optional[str] = None,
|
shared_token: Optional[str] = None,
|
||||||
attachment_ids: Optional[List[str]] = None,
|
attachment_ids: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Save or update a conversation in the database"""
|
"""Save or update a conversation in Postgres.
|
||||||
|
|
||||||
|
Returns the string conversation id (PG UUID as string, or the
|
||||||
|
caller-provided id if it was already a UUID).
|
||||||
|
"""
|
||||||
|
if decoded_token is None:
|
||||||
|
raise ValueError("Invalid or missing authentication token")
|
||||||
user_id = decoded_token.get("sub")
|
user_id = decoded_token.get("sub")
|
||||||
if not user_id:
|
if not user_id:
|
||||||
raise ValueError("User ID not found in token")
|
raise ValueError("User ID not found in token")
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# clean up in sources array such that we save max 1k characters for text part
|
# Trim huge inline source text to a reasonable max before persist.
|
||||||
for source in sources:
|
for source in sources:
|
||||||
if "text" in source and isinstance(source["text"], str):
|
if "text" in source and isinstance(source["text"], str):
|
||||||
source["text"] = source["text"][:1000]
|
source["text"] = source["text"][:1000]
|
||||||
|
|
||||||
|
message_payload = {
|
||||||
|
"prompt": question,
|
||||||
|
"response": response,
|
||||||
|
"thought": thought,
|
||||||
|
"sources": sources,
|
||||||
|
"tool_calls": tool_calls,
|
||||||
|
"attachments": attachment_ids,
|
||||||
|
"model_id": model_id,
|
||||||
|
"timestamp": current_time,
|
||||||
|
}
|
||||||
|
if metadata:
|
||||||
|
message_payload["metadata"] = metadata
|
||||||
|
|
||||||
if conversation_id is not None and index is not None:
|
if conversation_id is not None and index is not None:
|
||||||
# Update existing conversation with new query
|
with db_session() as conn:
|
||||||
|
repo = ConversationsRepository(conn)
|
||||||
result = self.conversations_collection.update_one(
|
conv = repo.get_any(conversation_id, user_id)
|
||||||
{
|
if conv is None:
|
||||||
"_id": ObjectId(conversation_id),
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
"user": user_id,
|
conv_pg_id = str(conv["id"])
|
||||||
f"queries.{index}": {"$exists": True},
|
repo.update_message_at(conv_pg_id, index, message_payload)
|
||||||
},
|
repo.truncate_after(conv_pg_id, index)
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
f"queries.{index}.prompt": question,
|
|
||||||
f"queries.{index}.response": response,
|
|
||||||
f"queries.{index}.thought": thought,
|
|
||||||
f"queries.{index}.sources": sources,
|
|
||||||
f"queries.{index}.tool_calls": tool_calls,
|
|
||||||
f"queries.{index}.timestamp": current_time,
|
|
||||||
f"queries.{index}.attachments": attachment_ids,
|
|
||||||
f"queries.{index}.model_id": model_id,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.matched_count == 0:
|
|
||||||
raise ValueError("Conversation not found or unauthorized")
|
|
||||||
self.conversations_collection.update_one(
|
|
||||||
{
|
|
||||||
"_id": ObjectId(conversation_id),
|
|
||||||
"user": user_id,
|
|
||||||
f"queries.{index}": {"$exists": True},
|
|
||||||
},
|
|
||||||
{"$push": {"queries": {"$each": [], "$slice": index + 1}}},
|
|
||||||
)
|
|
||||||
return conversation_id
|
return conversation_id
|
||||||
elif conversation_id:
|
elif conversation_id:
|
||||||
# Append new message to existing conversation
|
with db_session() as conn:
|
||||||
|
repo = ConversationsRepository(conn)
|
||||||
result = self.conversations_collection.update_one(
|
conv = repo.get_any(conversation_id, user_id)
|
||||||
{"_id": ObjectId(conversation_id), "user": user_id},
|
if conv is None:
|
||||||
{
|
raise ValueError("Conversation not found or unauthorized")
|
||||||
"$push": {
|
conv_pg_id = str(conv["id"])
|
||||||
"queries": {
|
# append_message expects 'metadata' key either way; normalise.
|
||||||
"prompt": question,
|
append_payload = dict(message_payload)
|
||||||
"response": response,
|
append_payload.setdefault("metadata", metadata or {})
|
||||||
"thought": thought,
|
repo.append_message(conv_pg_id, append_payload)
|
||||||
"sources": sources,
|
|
||||||
"tool_calls": tool_calls,
|
|
||||||
"timestamp": current_time,
|
|
||||||
"attachments": attachment_ids,
|
|
||||||
"model_id": model_id,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.matched_count == 0:
|
|
||||||
raise ValueError("Conversation not found or unauthorized")
|
|
||||||
return conversation_id
|
return conversation_id
|
||||||
else:
|
else:
|
||||||
# Create new conversation
|
|
||||||
|
|
||||||
messages_summary = [
|
messages_summary = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -148,71 +137,70 @@ class ConversationService:
|
|||||||
]
|
]
|
||||||
|
|
||||||
completion = llm.gen(
|
completion = llm.gen(
|
||||||
model=model_id, messages=messages_summary, max_tokens=30
|
model=model_id, messages=messages_summary, max_tokens=500
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_data = {
|
if not completion or not completion.strip():
|
||||||
"user": user_id,
|
completion = question[:50] if question else "New Conversation"
|
||||||
"date": current_time,
|
|
||||||
"name": completion,
|
|
||||||
"queries": [
|
|
||||||
{
|
|
||||||
"prompt": question,
|
|
||||||
"response": response,
|
|
||||||
"thought": thought,
|
|
||||||
"sources": sources,
|
|
||||||
"tool_calls": tool_calls,
|
|
||||||
"timestamp": current_time,
|
|
||||||
"attachments": attachment_ids,
|
|
||||||
"model_id": model_id,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
resolved_api_key: Optional[str] = None
|
||||||
|
resolved_agent_id: Optional[str] = None
|
||||||
if api_key:
|
if api_key:
|
||||||
if agent_id:
|
with db_readonly() as conn:
|
||||||
conversation_data["agent_id"] = agent_id
|
agent = AgentsRepository(conn).find_by_key(api_key)
|
||||||
if is_shared_usage:
|
|
||||||
conversation_data["is_shared_usage"] = is_shared_usage
|
|
||||||
conversation_data["shared_token"] = shared_token
|
|
||||||
agent = self.agents_collection.find_one({"key": api_key})
|
|
||||||
if agent:
|
if agent:
|
||||||
conversation_data["api_key"] = agent["key"]
|
resolved_api_key = agent.get("key")
|
||||||
result = self.conversations_collection.insert_one(conversation_data)
|
if agent_id:
|
||||||
return str(result.inserted_id)
|
resolved_agent_id = agent_id
|
||||||
|
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = ConversationsRepository(conn)
|
||||||
|
conv = repo.create(
|
||||||
|
user_id,
|
||||||
|
completion,
|
||||||
|
agent_id=resolved_agent_id,
|
||||||
|
api_key=resolved_api_key,
|
||||||
|
is_shared_usage=bool(resolved_agent_id and is_shared_usage),
|
||||||
|
shared_token=(
|
||||||
|
shared_token
|
||||||
|
if (resolved_agent_id and is_shared_usage)
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conv_pg_id = str(conv["id"])
|
||||||
|
append_payload = dict(message_payload)
|
||||||
|
append_payload.setdefault("metadata", metadata or {})
|
||||||
|
repo.append_message(conv_pg_id, append_payload)
|
||||||
|
return conv_pg_id
|
||||||
|
|
||||||
def update_compression_metadata(
|
def update_compression_metadata(
|
||||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""Persist compression flags and append a compression point.
|
||||||
Update conversation with compression metadata.
|
|
||||||
|
|
||||||
Uses $push with $slice to keep only the most recent compression points,
|
Mirrors the Mongo-era ``$set`` + ``$push $slice`` on
|
||||||
preventing unbounded array growth. Since each compression incorporates
|
``compression_metadata`` but goes through the PG repo API.
|
||||||
previous compressions, older points become redundant.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation_id: Conversation ID
|
|
||||||
compression_metadata: Compression point data
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.conversations_collection.update_one(
|
with db_session() as conn:
|
||||||
{"_id": ObjectId(conversation_id)},
|
repo = ConversationsRepository(conn)
|
||||||
{
|
# conversation_id here comes from the streaming pipeline
|
||||||
"$set": {
|
# which has already resolved it; accept either UUID or
|
||||||
"compression_metadata.is_compressed": True,
|
# legacy id for safety.
|
||||||
"compression_metadata.last_compression_at": compression_metadata.get(
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
"timestamp"
|
conv_pg_id = (
|
||||||
),
|
str(conv["id"]) if conv is not None else conversation_id
|
||||||
},
|
)
|
||||||
"$push": {
|
repo.set_compression_flags(
|
||||||
"compression_metadata.compression_points": {
|
conv_pg_id,
|
||||||
"$each": [compression_metadata],
|
is_compressed=True,
|
||||||
"$slice": -settings.COMPRESSION_MAX_HISTORY_POINTS,
|
last_compression_at=compression_metadata.get("timestamp"),
|
||||||
}
|
)
|
||||||
},
|
repo.append_compression_point(
|
||||||
},
|
conv_pg_id,
|
||||||
)
|
compression_metadata,
|
||||||
|
max_points=settings.COMPRESSION_MAX_HISTORY_POINTS,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Updated compression metadata for conversation {conversation_id}"
|
f"Updated compression metadata for conversation {conversation_id}"
|
||||||
)
|
)
|
||||||
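The docstring above says the Postgres path mirrors the Mongo-era ``$push`` with ``$slice`` capping. A standalone sketch of that capping semantics, which is what ``append_compression_point``'s ``max_points`` is assumed to implement:

def push_with_slice(points, new_point, max_points):
    """Append new_point and keep only the most recent max_points entries (max_points >= 1)."""
    return (points + [new_point])[-max_points:]

history = [{"ts": 1}, {"ts": 2}]
history = push_with_slice(history, {"ts": 3}, max_points=2)
assert history == [{"ts": 2}, {"ts": 3}]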
@@ -225,34 +213,34 @@ class ConversationService:
|
|||||||
def append_compression_message(
|
def append_compression_message(
|
||||||
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
self, conversation_id: str, compression_metadata: Dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""Append a synthetic compression summary message to the conversation."""
|
||||||
Append a synthetic compression summary entry into the conversation history.
|
|
||||||
This makes the summary visible in the DB alongside normal queries.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
summary = compression_metadata.get("compressed_summary", "")
|
summary = compression_metadata.get("compressed_summary", "")
|
||||||
if not summary:
|
if not summary:
|
||||||
return
|
return
|
||||||
timestamp = compression_metadata.get("timestamp", datetime.now(timezone.utc))
|
timestamp = compression_metadata.get(
|
||||||
|
"timestamp", datetime.now(timezone.utc)
|
||||||
self.conversations_collection.update_one(
|
)
|
||||||
{"_id": ObjectId(conversation_id)},
|
|
||||||
{
|
with db_session() as conn:
|
||||||
"$push": {
|
repo = ConversationsRepository(conn)
|
||||||
"queries": {
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
"prompt": "[Context Compression Summary]",
|
conv_pg_id = (
|
||||||
"response": summary,
|
str(conv["id"]) if conv is not None else conversation_id
|
||||||
"thought": "",
|
)
|
||||||
"sources": [],
|
repo.append_message(conv_pg_id, {
|
||||||
"tool_calls": [],
|
"prompt": "[Context Compression Summary]",
|
||||||
"timestamp": timestamp,
|
"response": summary,
|
||||||
"attachments": [],
|
"thought": "",
|
||||||
"model_id": compression_metadata.get("model_used"),
|
"sources": [],
|
||||||
}
|
"tool_calls": [],
|
||||||
}
|
"attachments": [],
|
||||||
},
|
"model_id": compression_metadata.get("model_used"),
|
||||||
|
"timestamp": timestamp,
|
||||||
|
})
|
||||||
|
logger.info(
|
||||||
|
f"Appended compression summary to conversation {conversation_id}"
|
||||||
)
|
)
|
||||||
logger.info(f"Appended compression summary to conversation {conversation_id}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error appending compression summary: {str(e)}", exc_info=True
|
f"Error appending compression summary: {str(e)}", exc_info=True
|
||||||
@@ -261,20 +249,30 @@ class ConversationService:
|
|||||||
def get_compression_metadata(
|
def get_compression_metadata(
|
||||||
self, conversation_id: str
|
self, conversation_id: str
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""Fetch the stored compression metadata JSONB blob for a conversation."""
|
||||||
Get compression metadata for a conversation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation_id: Conversation ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Compression metadata dict or None
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
conversation = self.conversations_collection.find_one(
|
with db_readonly() as conn:
|
||||||
{"_id": ObjectId(conversation_id)}, {"compression_metadata": 1}
|
repo = ConversationsRepository(conn)
|
||||||
)
|
conv = repo.get_by_legacy_id(conversation_id)
|
||||||
return conversation.get("compression_metadata") if conversation else None
|
if conv is None:
|
||||||
|
# Fallback to UUID lookup without user scoping — the
|
||||||
|
# caller already holds an authenticated conversation
|
||||||
|
# id from the streaming path. Gate on id shape so a
|
||||||
|
# non-UUID (legacy ObjectId that wasn't backfilled)
|
||||||
|
# doesn't reach CAST — the cast raises and spams the
|
||||||
|
# logs with a stack trace on every call.
|
||||||
|
if not looks_like_uuid(conversation_id):
|
||||||
|
return None
|
||||||
|
result = conn.execute(
|
||||||
|
sql_text(
|
||||||
|
"SELECT compression_metadata FROM conversations "
|
||||||
|
"WHERE id = CAST(:id AS uuid)"
|
||||||
|
),
|
||||||
|
{"id": conversation_id},
|
||||||
|
)
|
||||||
|
row = result.fetchone()
|
||||||
|
return row[0] if row is not None else None
|
||||||
|
return conv.get("compression_metadata") if conv else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error getting compression metadata: {str(e)}", exc_info=True
|
f"Error getting compression metadata: {str(e)}", exc_info=True
|
||||||
|
|||||||
File diff suppressed because it is too large
@@ -1,10 +1,10 @@
|
|||||||
import base64
|
import base64
|
||||||
import datetime
|
import html
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import (
|
from flask import (
|
||||||
Blueprint,
|
Blueprint,
|
||||||
current_app,
|
current_app,
|
||||||
@@ -15,26 +15,34 @@ from flask import (
|
|||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
from application.api.user.tasks import (
|
from application.api.user.tasks import (
|
||||||
ingest_connector_task,
|
ingest_connector_task,
|
||||||
)
|
)
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
from application.core.settings import settings
|
|
||||||
from application.api import api
|
|
||||||
|
|
||||||
|
|
||||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||||
|
from application.storage.db.repositories.connector_sessions import (
|
||||||
|
ConnectorSessionsRepository,
|
||||||
|
)
|
||||||
|
from application.storage.db.repositories.sources import SourcesRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
|
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
sources_collection = db["sources"]
|
|
||||||
sessions_collection = db["connector_sessions"]
|
|
||||||
|
|
||||||
connector = Blueprint("connector", __name__)
|
connector = Blueprint("connector", __name__)
|
||||||
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
connectors_ns = Namespace("connectors", description="Connector operations", path="/")
|
||||||
api.add_namespace(connectors_ns)
|
api.add_namespace(connectors_ns)
|
||||||
|
|
||||||
|
# Fixed callback status path to prevent open redirect
|
||||||
|
CALLBACK_STATUS_PATH = "/api/connectors/callback-status"
|
||||||
|
|
||||||
|
|
||||||
|
def build_callback_redirect(params: dict) -> str:
|
||||||
|
"""Build a safe redirect URL to the callback status page.
|
||||||
|
|
||||||
|
Uses a fixed path and properly URL-encodes all parameters
|
||||||
|
to prevent URL injection and open redirect vulnerabilities.
|
||||||
|
"""
|
||||||
|
return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"
|
||||||
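A usage sketch of the fixed-path redirect builder above, restated so it runs standalone; the parameter values are invented.

from urllib.parse import urlencode

CALLBACK_STATUS_PATH = "/api/connectors/callback-status"

def build_callback_redirect(params: dict) -> str:
    return f"{CALLBACK_STATUS_PATH}?{urlencode(params)}"

url = build_callback_redirect({
    "status": "error",
    "message": "Authentication failed. Please try again.",
    "provider": "google_drive",
})
# The path is a constant and every value is percent-encoded, so callers
# cannot smuggle in a different host or extra path segments.
print(url)
# /api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again.&provider=google_drive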
|
|
||||||
|
|
||||||
|
|
||||||
@connectors_ns.route("/api/connectors/auth")
|
@connectors_ns.route("/api/connectors/auth")
|
||||||
@@ -54,16 +62,14 @@ class ConnectorAuth(Resource):
|
|||||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
user_id = decoded_token.get('sub')
|
user_id = decoded_token.get('sub')
|
||||||
|
|
||||||
now = datetime.datetime.now(datetime.timezone.utc)
|
with db_session() as conn:
|
||||||
result = sessions_collection.insert_one({
|
session_row = ConnectorSessionsRepository(conn).upsert(
|
||||||
"provider": provider,
|
user_id, provider, status="pending",
|
||||||
"user": user_id,
|
)
|
||||||
"status": "pending",
|
session_pg_id = str(session_row["id"])
|
||||||
"created_at": now
|
|
||||||
})
|
|
||||||
state_dict = {
|
state_dict = {
|
||||||
"provider": provider,
|
"provider": provider,
|
||||||
"object_id": str(result.inserted_id)
|
"object_id": session_pg_id,
|
||||||
}
|
}
|
||||||
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()
|
||||||
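A round-trip sketch of the base64 ``state`` blob built above; the callback handler decodes it the same way before resolving the session row. The session id is invented.

import base64
import json

state_dict = {"provider": "google_drive", "object_id": "6a1f0d2e-3c44-4b8a-9f0e-2d5b7c8a9e10"}
state = base64.urlsafe_b64encode(json.dumps(state_dict).encode()).decode()

# The callback reverses the encoding before looking up the session row.
decoded = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
assert decoded == state_dict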
|
|
||||||
@@ -75,8 +81,8 @@ class ConnectorAuth(Resource):
|
|||||||
"state": state
|
"state": state
|
||||||
}), 200)
|
}), 200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error generating connector auth URL: {e}")
|
current_app.logger.error(f"Error generating connector auth URL: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
return make_response(jsonify({"success": False, "error": "Failed to generate authorization URL"}), 500)
|
||||||
|
|
||||||
|
|
||||||
@connectors_ns.route("/api/connectors/callback")
|
@connectors_ns.route("/api/connectors/callback")
|
||||||
@@ -93,18 +99,37 @@ class ConnectorsCallback(Resource):
|
|||||||
error = request.args.get('error')
|
error = request.args.get('error')
|
||||||
|
|
||||||
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
state_dict = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||||
provider = state_dict["provider"]
|
provider = state_dict.get("provider")
|
||||||
state_object_id = state_dict["object_id"]
|
state_object_id = state_dict.get("object_id")
|
||||||
|
|
||||||
|
# Validate provider
|
||||||
|
if not provider or not isinstance(provider, str) or not ConnectorCreator.is_supported(provider):
|
||||||
|
return redirect(build_callback_redirect({
|
||||||
|
"status": "error",
|
||||||
|
"message": "Invalid provider"
|
||||||
|
}))
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
if error == "access_denied":
|
if error == "access_denied":
|
||||||
return redirect(f"/api/connectors/callback-status?status=cancelled&message=Authentication+was+cancelled.+You+can+try+again+if+you'd+like+to+connect+your+account.&provider={provider}")
|
return redirect(build_callback_redirect({
|
||||||
|
"status": "cancelled",
|
||||||
|
"message": "Authentication was cancelled. You can try again if you'd like to connect your account.",
|
||||||
|
"provider": provider
|
||||||
|
}))
|
||||||
else:
|
else:
|
||||||
current_app.logger.warning(f"OAuth error in callback: {error}")
|
current_app.logger.warning(f"OAuth error in callback: {error}")
|
||||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
return redirect(build_callback_redirect({
|
||||||
|
"status": "error",
|
||||||
|
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||||
|
"provider": provider
|
||||||
|
}))
|
||||||
|
|
||||||
if not authorization_code:
|
if not authorization_code:
|
||||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
return redirect(build_callback_redirect({
|
||||||
|
"status": "error",
|
||||||
|
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||||
|
"provider": provider
|
||||||
|
}))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth = ConnectorCreator.create_auth(provider)
|
auth = ConnectorCreator.create_auth(provider)
|
||||||
@@ -113,54 +138,74 @@ class ConnectorsCallback(Resource):
|
|||||||
session_token = str(uuid.uuid4())
|
session_token = str(uuid.uuid4())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
credentials = auth.create_credentials_from_token_info(token_info)
|
if provider == "google_drive":
|
||||||
service = auth.build_drive_service(credentials)
|
credentials = auth.create_credentials_from_token_info(token_info)
|
||||||
user_info = service.about().get(fields="user").execute()
|
service = auth.build_drive_service(credentials)
|
||||||
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
user_info = service.about().get(fields="user").execute()
|
||||||
|
user_email = user_info.get('user', {}).get('emailAddress', 'Connected User')
|
||||||
|
else:
|
||||||
|
user_email = token_info.get('user_info', {}).get('email', 'Connected User')
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.warning(f"Could not get user info: {e}")
|
current_app.logger.warning(f"Could not get user info: {e}")
|
||||||
user_email = 'Connected User'
|
user_email = 'Connected User'
|
||||||
|
|
||||||
sanitized_token_info = {
|
sanitized_token_info = auth.sanitize_token_info(token_info)
|
||||||
"access_token": token_info.get("access_token"),
|
|
||||||
"refresh_token": token_info.get("refresh_token"),
|
|
||||||
"token_uri": token_info.get("token_uri"),
|
|
||||||
"expiry": token_info.get("expiry")
|
|
||||||
}
|
|
||||||
|
|
||||||
sessions_collection.find_one_and_update(
|
# ``object_id`` in the OAuth state is the PG session row
|
||||||
{"_id": ObjectId(state_object_id), "provider": provider},
|
# UUID (new flow) or a legacy Mongo ObjectId (pre-cutover
|
||||||
{
|
# issued state). Try UUID update first; fall back to
|
||||||
"$set": {
|
# legacy id path.
|
||||||
"session_token": session_token,
|
patch = {
|
||||||
"token_info": sanitized_token_info,
|
"session_token": session_token,
|
||||||
"user_email": user_email,
|
"token_info": sanitized_token_info,
|
||||||
"status": "authorized"
|
"user_email": user_email,
|
||||||
}
|
"status": "authorized",
|
||||||
}
|
}
|
||||||
)
|
with db_session() as conn:
|
||||||
|
repo = ConnectorSessionsRepository(conn)
|
||||||
|
if state_object_id:
|
||||||
|
value = str(state_object_id)
|
||||||
|
updated = False
|
||||||
|
if len(value) == 36 and "-" in value:
|
||||||
|
updated = repo.update(value, patch)
|
||||||
|
if not updated:
|
||||||
|
repo.update_by_legacy_id(value, patch)
|
||||||
|
|
||||||
# Redirect to success page with session token and user email
|
# Redirect to success page with session token and user email
|
||||||
return redirect(f"/api/connectors/callback-status?status=success&message=Authentication+successful&provider={provider}&session_token={session_token}&user_email={user_email}")
|
return redirect(build_callback_redirect({
|
||||||
|
"status": "success",
|
||||||
|
"message": "Authentication successful",
|
||||||
|
"provider": provider,
|
||||||
|
"session_token": session_token,
|
||||||
|
"user_email": user_email
|
||||||
|
}))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
current_app.logger.error(f"Error exchanging code for tokens: {str(e)}", exc_info=True)
|
||||||
return redirect(f"/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.&provider={provider}")
|
return redirect(build_callback_redirect({
|
||||||
|
"status": "error",
|
||||||
|
"message": "Authentication failed. Please try again and make sure to grant all requested permissions.",
|
||||||
|
"provider": provider
|
||||||
|
}))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error handling connector callback: {e}")
|
current_app.logger.error(f"Error handling connector callback: {e}")
|
||||||
return redirect("/api/connectors/callback-status?status=error&message=Authentication+failed.+Please+try+again+and+make+sure+to+grant+all+requested+permissions.")
|
return redirect(build_callback_redirect({
|
||||||
|
"status": "error",
|
||||||
|
"message": "Authentication failed. Please try again and make sure to grant all requested permissions."
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
@connectors_ns.route("/api/connectors/files")
|
@connectors_ns.route("/api/connectors/files")
|
||||||
class ConnectorFiles(Resource):
|
class ConnectorFiles(Resource):
|
||||||
@api.expect(api.model("ConnectorFilesModel", {
|
@api.expect(api.model("ConnectorFilesModel", {
|
||||||
"provider": fields.String(required=True),
|
"provider": fields.String(required=True),
|
||||||
"session_token": fields.String(required=True),
|
"session_token": fields.String(required=True),
|
||||||
"folder_id": fields.String(required=False),
|
"folder_id": fields.String(required=False),
|
||||||
"limit": fields.Integer(required=False),
|
"limit": fields.Integer(required=False),
|
||||||
"page_token": fields.String(required=False),
|
"page_token": fields.String(required=False),
|
||||||
"search_query": fields.String(required=False)
|
"search_query": fields.String(required=False),
|
||||||
}))
|
}))
|
||||||
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
@api.doc(description="List files from a connector provider (supports pagination and search)")
|
||||||
def post(self):
|
def post(self):
|
||||||
@@ -168,11 +213,8 @@ class ConnectorFiles(Resource):
|
|||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
provider = data.get('provider')
|
provider = data.get('provider')
|
||||||
session_token = data.get('session_token')
|
session_token = data.get('session_token')
|
||||||
folder_id = data.get('folder_id')
|
|
||||||
limit = data.get('limit', 10)
|
limit = data.get('limit', 10)
|
||||||
page_token = data.get('page_token')
|
|
||||||
search_query = data.get('search_query')
|
|
||||||
|
|
||||||
if not provider or not session_token:
|
if not provider or not session_token:
|
||||||
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
return make_response(jsonify({"success": False, "error": "provider and session_token are required"}), 400)
|
||||||
|
|
||||||
@@ -180,20 +222,20 @@ class ConnectorFiles(Resource):
|
|||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
user = decoded_token.get('sub')
|
user = decoded_token.get('sub')
|
||||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
with db_readonly() as conn:
|
||||||
if not session:
|
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||||
|
session_token,
|
||||||
|
)
|
||||||
|
if not session or session.get("user_id") != user:
|
||||||
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
return make_response(jsonify({"success": False, "error": "Invalid or unauthorized session"}), 401)
|
||||||
|
|
||||||
loader = ConnectorCreator.create_connector(provider, session_token)
|
loader = ConnectorCreator.create_connector(provider, session_token)
|
||||||
|
|
||||||
|
generic_keys = {'provider', 'session_token'}
|
||||||
input_config = {
|
input_config = {
|
||||||
'limit': limit,
|
k: v for k, v in data.items() if k not in generic_keys
|
||||||
'list_only': True,
|
|
||||||
'session_token': session_token,
|
|
||||||
'folder_id': folder_id,
|
|
||||||
'page_token': page_token
|
|
||||||
}
|
}
|
||||||
if search_query:
|
input_config['list_only'] = True
|
||||||
input_config['search_query'] = search_query
|
|
||||||
|
|
||||||
documents = loader.load_data(input_config)
|
documents = loader.load_data(input_config)
|
||||||
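A sketch of the pass-through ``input_config`` built above from the request body; everything except the generic ``provider`` and ``session_token`` keys is forwarded to the connector loader. The payload values are invented.

data = {
    "provider": "google_drive",
    "session_token": "session-token-123",
    "folder_id": "root",
    "limit": 25,
    "search_query": "roadmap",
}
generic_keys = {"provider", "session_token"}
input_config = {k: v for k, v in data.items() if k not in generic_keys}
input_config["list_only"] = True
assert input_config == {
    "folder_id": "root",
    "limit": 25,
    "search_query": "roadmap",
    "list_only": True,
}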
|
|
||||||
@@ -228,8 +270,8 @@ class ConnectorFiles(Resource):
|
|||||||
"has_more": has_more
|
"has_more": has_more
|
||||||
}), 200)
|
}), 200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error loading connector files: {e}")
|
current_app.logger.error(f"Error loading connector files: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False, "error": f"Failed to load files: {str(e)}"}), 500)
|
return make_response(jsonify({"success": False, "error": "Failed to load files"}), 500)
|
||||||
|
|
||||||
|
|
||||||
@connectors_ns.route("/api/connectors/validate-session")
|
@connectors_ns.route("/api/connectors/validate-session")
|
||||||
@@ -249,8 +291,11 @@ class ConnectorValidateSession(Resource):
|
|||||||
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
|
||||||
user = decoded_token.get('sub')
|
user = decoded_token.get('sub')
|
||||||
|
|
||||||
session = sessions_collection.find_one({"session_token": session_token, "user": user})
|
with db_readonly() as conn:
|
||||||
if not session or "token_info" not in session:
|
session = ConnectorSessionsRepository(conn).get_by_session_token(
|
||||||
|
session_token,
|
||||||
|
)
|
||||||
|
if not session or session.get("user_id") != user or not session.get("token_info"):
|
||||||
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
return make_response(jsonify({"success": False, "error": "Invalid or expired session"}), 401)
|
||||||
|
|
||||||
token_info = session["token_info"]
|
token_info = session["token_info"]
|
||||||
@@ -260,16 +305,12 @@ class ConnectorValidateSession(Resource):
|
|||||||
if is_expired and token_info.get('refresh_token'):
|
if is_expired and token_info.get('refresh_token'):
|
||||||
try:
|
try:
|
||||||
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
refreshed_token_info = auth.refresh_access_token(token_info.get('refresh_token'))
|
||||||
sanitized_token_info = {
|
sanitized_token_info = auth.sanitize_token_info(refreshed_token_info)
|
||||||
"access_token": refreshed_token_info.get("access_token"),
|
with db_session() as conn:
|
||||||
"refresh_token": refreshed_token_info.get("refresh_token"),
|
repo = ConnectorSessionsRepository(conn)
|
||||||
"token_uri": refreshed_token_info.get("token_uri"),
|
row = repo.get_by_session_token(session_token)
|
||||||
"expiry": refreshed_token_info.get("expiry")
|
if row:
|
||||||
}
|
repo.update(str(row["id"]), {"token_info": sanitized_token_info})
|
||||||
sessions_collection.update_one(
|
|
||||||
{"session_token": session_token},
|
|
||||||
{"$set": {"token_info": sanitized_token_info}}
|
|
||||||
)
|
|
||||||
token_info = sanitized_token_info
|
token_info = sanitized_token_info
|
||||||
is_expired = False
|
is_expired = False
|
||||||
except Exception as refresh_error:
|
except Exception as refresh_error:
|
||||||
@@ -282,15 +323,21 @@ class ConnectorValidateSession(Resource):
|
|||||||
"error": "Session token has expired. Please reconnect."
|
"error": "Session token has expired. Please reconnect."
|
||||||
}), 401)
|
}), 401)
|
||||||
|
|
||||||
return make_response(jsonify({
|
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
|
||||||
|
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
|
||||||
|
|
||||||
|
response_data = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"expired": False,
|
"expired": False,
|
||||||
"user_email": session.get('user_email', 'Connected User'),
|
"user_email": session.get('user_email', 'Connected User'),
|
||||||
"access_token": token_info.get('access_token')
|
"access_token": token_info.get('access_token'),
|
||||||
}), 200)
|
**provider_extras,
|
||||||
|
}
|
||||||
|
|
||||||
|
return make_response(jsonify(response_data), 200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error validating connector session: {e}")
|
current_app.logger.error(f"Error validating connector session: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
return make_response(jsonify({"success": False, "error": "Failed to validate session"}), 500)
|
||||||
|
|
||||||
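A tiny sketch of the base-fields / provider-extras split used to build ``response_data`` above; the token contents are invented.

token_info = {
    "access_token": "abc",
    "refresh_token": "def",
    "token_uri": "https://oauth2.example/token",
    "expiry": "2030-01-01T00:00:00Z",
    "account_id": "12345",  # a provider-specific extra field
}
_base_fields = {"access_token", "refresh_token", "token_uri", "expiry"}
provider_extras = {k: v for k, v in token_info.items() if k not in _base_fields}
response_data = {"success": True, "expired": False, **provider_extras}
assert response_data == {"success": True, "expired": False, "account_id": "12345"}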
|
|
||||||
@connectors_ns.route("/api/connectors/disconnect")
|
@connectors_ns.route("/api/connectors/disconnect")
|
||||||
@@ -307,12 +354,15 @@ class ConnectorDisconnect(Resource):
|
|||||||
|
|
||||||
|
|
||||||
if session_token:
|
if session_token:
|
||||||
sessions_collection.delete_one({"session_token": session_token})
|
with db_session() as conn:
|
||||||
|
ConnectorSessionsRepository(conn).delete_by_session_token(
|
||||||
|
session_token,
|
||||||
|
)
|
||||||
|
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error disconnecting connector session: {e}")
|
current_app.logger.error(f"Error disconnecting connector session: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
return make_response(jsonify({"success": False, "error": "Failed to disconnect session"}), 500)
|
||||||
|
|
||||||
|
|
||||||
@connectors_ns.route("/api/connectors/sync")
|
@connectors_ns.route("/api/connectors/sync")
|
||||||
@@ -345,32 +395,28 @@ class ConnectorSync(Resource):
|
|||||||
}),
|
}),
|
||||||
400
|
400
|
||||||
)
|
)
|
||||||
source = sources_collection.find_one({"_id": ObjectId(source_id)})
|
user_id = decoded_token.get('sub')
|
||||||
|
with db_readonly() as conn:
|
||||||
|
source = SourcesRepository(conn).get_any(source_id, user_id)
|
||||||
if not source:
|
if not source:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({
|
jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": "Source not found"
|
"error": "Source not found"
|
||||||
}),
|
}),
|
||||||
404
|
404
|
||||||
)
|
)
|
||||||
|
|
||||||
if source.get('user') != decoded_token.get('sub'):
|
# ``get_any`` already scopes by ``user_id``; an extra guard
|
||||||
return make_response(
|
# here would be dead code.
|
||||||
jsonify({
|
|
||||||
"success": False,
|
|
||||||
"error": "Unauthorized access to source"
|
|
||||||
}),
|
|
||||||
403
|
|
||||||
)
|
|
||||||
|
|
||||||
remote_data = {}
|
remote_data = source.get('remote_data') or {}
|
||||||
try:
|
if isinstance(remote_data, str):
|
||||||
if source.get('remote_data'):
|
try:
|
||||||
remote_data = json.loads(source.get('remote_data'))
|
remote_data = json.loads(remote_data)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
current_app.logger.error(f"Invalid remote_data format for source {source_id}")
|
||||||
remote_data = {}
|
remote_data = {}
|
||||||
|
|
||||||
source_type = remote_data.get('provider')
|
source_type = remote_data.get('provider')
|
||||||
if not source_type:
|
if not source_type:
|
||||||
@@ -398,7 +444,7 @@ class ConnectorSync(Resource):
|
|||||||
recursive=recursive,
|
recursive=recursive,
|
||||||
retriever=source.get('retriever', 'classic'),
|
retriever=source.get('retriever', 'classic'),
|
||||||
operation_mode="sync",
|
operation_mode="sync",
|
||||||
doc_id=source_id,
|
doc_id=str(source.get('id') or source_id),
|
||||||
sync_frequency=source.get('sync_frequency', 'never')
|
sync_frequency=source.get('sync_frequency', 'never')
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -418,8 +464,8 @@ class ConnectorSync(Resource):
|
|||||||
return make_response(
|
return make_response(
|
||||||
jsonify({
|
jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(err)
|
"error": "Failed to sync connector source"
|
||||||
}),
|
}),
|
||||||
400
|
400
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -430,17 +476,32 @@ class ConnectorCallbackStatus(Resource):
|
|||||||
def get(self):
|
def get(self):
|
||||||
"""Return HTML page with connector authentication status"""
|
"""Return HTML page with connector authentication status"""
|
||||||
try:
|
try:
|
||||||
status = request.args.get('status', 'error')
|
# Validate and sanitize status to a known value
|
||||||
message = request.args.get('message', '')
|
status_raw = request.args.get('status', 'error')
|
||||||
provider = request.args.get('provider', 'connector')
|
status = status_raw if status_raw in ('success', 'error', 'cancelled') else 'error'
|
||||||
|
|
||||||
|
# Escape all user-controlled values for HTML context
|
||||||
|
message = html.escape(request.args.get('message', ''))
|
||||||
|
provider_raw = request.args.get('provider', 'connector')
|
||||||
|
provider = html.escape(provider_raw.replace('_', ' ').title())
|
||||||
session_token = request.args.get('session_token', '')
|
session_token = request.args.get('session_token', '')
|
||||||
user_email = request.args.get('user_email', '')
|
user_email = html.escape(request.args.get('user_email', ''))
|
||||||
|
|
||||||
|
def safe_js_string(value: str) -> str:
|
||||||
|
"""Safely encode a string for embedding in inline JavaScript."""
|
||||||
|
js_encoded = json.dumps(value)
|
||||||
|
return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')
|
||||||
|
|
||||||
|
js_status = safe_js_string(status)
|
||||||
|
js_session_token = safe_js_string(session_token)
|
||||||
|
js_user_email = safe_js_string(user_email)
|
||||||
|
js_provider_type = safe_js_string(provider_raw)
|
||||||
|
|
||||||
html_content = f"""
|
html_content = f"""
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<title>{provider.replace('_', ' ').title()} Authentication</title>
|
<title>{provider} Authentication</title>
|
||||||
<style>
|
<style>
|
||||||
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
body {{ font-family: Arial, sans-serif; text-align: center; padding: 40px; }}
|
||||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||||
@@ -450,13 +511,14 @@ class ConnectorCallbackStatus(Resource):
|
|||||||
</style>
|
</style>
|
||||||
<script>
|
<script>
|
||||||
window.onload = function() {{
|
window.onload = function() {{
|
||||||
const status = "{status}";
|
const status = {js_status};
|
||||||
const sessionToken = "{session_token}";
|
const sessionToken = {js_session_token};
|
||||||
const userEmail = "{user_email}";
|
const userEmail = {js_user_email};
|
||||||
|
const providerType = {js_provider_type};
|
||||||
|
|
||||||
if (status === "success" && window.opener) {{
|
if (status === "success" && window.opener) {{
|
||||||
window.opener.postMessage({{
|
window.opener.postMessage({{
|
||||||
type: '{provider}_auth_success',
|
type: providerType + '_auth_success',
|
||||||
session_token: sessionToken,
|
session_token: sessionToken,
|
||||||
user_email: userEmail
|
user_email: userEmail
|
||||||
}}, '*');
|
}}, '*');
|
||||||
@@ -470,17 +532,17 @@ class ConnectorCallbackStatus(Resource):
|
|||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="container">
|
<div class="container">
|
||||||
<h2>{provider.replace('_', ' ').title()} Authentication</h2>
|
<h2>{provider} Authentication</h2>
|
||||||
<div class="{status}">
|
<div class="{status}">
|
||||||
<p>{message}</p>
|
<p>{message}</p>
|
||||||
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
{f'<p>Connected as: {user_email}</p>' if status == 'success' else ''}
|
||||||
</div>
|
</div>
|
||||||
<p><small>You can close this window. {f"Your {provider.replace('_', ' ').title()} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
<p><small>You can close this window. {f"Your {provider} is now connected and ready to use." if status == 'success' else "Feel free to close this window."}</small></p>
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
return make_response(html_content, 200, {'Content-Type': 'text/html'})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.error(f"Error rendering callback status page: {e}")
|
current_app.logger.error(f"Error rendering callback status page: {e}")
|
||||||
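A standalone sketch of the inline-script escaping done by ``safe_js_string`` in the callback-status page above; the payload is an invented XSS probe.

import json

def safe_js_string(value: str) -> str:
    js_encoded = json.dumps(value)
    return js_encoded.replace('</', '<\\/').replace('<!--', '<\\!--')

probe = '"></script><script>alert(1)</script>'
print(safe_js_string(probe))
# "\"><\/script><script>alert(1)<\/script>"
# json.dumps adds the surrounding quotes and escapes the embedded quote;
# splitting "</" stops the value from closing the enclosing <script> tag.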
|
|||||||
@@ -3,18 +3,16 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
from flask import Blueprint, request, send_from_directory, jsonify
|
from flask import Blueprint, request, send_from_directory, jsonify
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
from bson.objectid import ObjectId
|
|
||||||
import logging
|
import logging
|
||||||
from application.core.mongo_db import MongoDB
|
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.sources import SourcesRepository
|
||||||
|
from application.storage.db.session import db_session
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
mongo = MongoDB.get_client()
|
|
||||||
db = mongo[settings.MONGO_DB_NAME]
|
|
||||||
conversations_collection = db["conversations"]
|
|
||||||
sources_collection = db["sources"]
|
|
||||||
|
|
||||||
current_dir = os.path.dirname(
|
current_dir = os.path.dirname(
|
||||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -26,12 +24,20 @@ internal = Blueprint("internal", __name__)
|
|||||||
|
|
||||||
@internal.before_request
|
@internal.before_request
|
||||||
def verify_internal_key():
|
def verify_internal_key():
|
||||||
"""Verify INTERNAL_KEY for all internal endpoint requests."""
|
"""Verify INTERNAL_KEY for all internal endpoint requests.
|
||||||
if settings.INTERNAL_KEY:
|
|
||||||
internal_key = request.headers.get("X-Internal-Key")
|
Deny by default: if INTERNAL_KEY is not configured, reject all requests.
|
||||||
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
"""
|
||||||
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
if not settings.INTERNAL_KEY:
|
||||||
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
logger.warning(
|
||||||
|
f"Internal API request rejected from {request.remote_addr}: "
|
||||||
|
"INTERNAL_KEY is not configured"
|
||||||
|
)
|
||||||
|
return jsonify({"error": "Unauthorized", "message": "Internal API is not configured"}), 401
|
||||||
|
internal_key = request.headers.get("X-Internal-Key")
|
||||||
|
if not internal_key or internal_key != settings.INTERNAL_KEY:
|
||||||
|
logger.warning(f"Unauthorized internal API access attempt from {request.remote_addr}")
|
||||||
|
return jsonify({"error": "Unauthorized", "message": "Invalid or missing internal key"}), 401
|
||||||
|
|
||||||
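For callers, the deny-by-default guard above means every internal-API request must carry the shared key header. A hedged client-side sketch; only the ``X-Internal-Key`` header name comes from the diff, while the host, port, query parameters, and key value are placeholders.

import requests

resp = requests.get(
    "http://localhost:7091/api/download",              # placeholder host/port
    headers={"X-Internal-Key": "change-me"},           # must match settings.INTERNAL_KEY
    params={"path": "indexes/example/index.faiss"},    # placeholder query parameters
)
print(resp.status_code)  # 401 when the key is missing, wrong, or not configured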
|
|
||||||
@internal.route("/api/download", methods=["get"])
|
@internal.route("/api/download", methods=["get"])
|
||||||
@@ -48,20 +54,21 @@ def upload_index_files():
|
|||||||
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
"""Upload two files(index.faiss, index.pkl) to the user's folder."""
|
||||||
if "user" not in request.form:
|
if "user" not in request.form:
|
||||||
return {"status": "no user"}
|
return {"status": "no user"}
|
||||||
user = request.form["user"]
|
user = request.form["user"]
|
||||||
if "name" not in request.form:
|
if "name" not in request.form:
|
||||||
return {"status": "no name"}
|
return {"status": "no name"}
|
||||||
job_name = request.form["name"]
|
job_name = request.form["name"]
|
||||||
tokens = request.form["tokens"]
|
tokens = request.form["tokens"]
|
||||||
retriever = request.form["retriever"]
|
retriever = request.form["retriever"]
|
||||||
id = request.form["id"]
|
source_id = request.form["id"]
|
||||||
type = request.form["type"]
|
type = request.form["type"]
|
||||||
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
remote_data = request.form["remote_data"] if "remote_data" in request.form else None
|
||||||
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
|
sync_frequency = request.form["sync_frequency"] if "sync_frequency" in request.form else None
|
||||||
|
|
||||||
file_path = request.form.get("file_path")
|
file_path = request.form.get("file_path")
|
||||||
directory_structure = request.form.get("directory_structure")
|
directory_structure = request.form.get("directory_structure")
|
||||||
|
file_name_map = request.form.get("file_name_map")
|
||||||
|
|
||||||
if directory_structure:
|
if directory_structure:
|
||||||
try:
|
try:
|
||||||
directory_structure = json.loads(directory_structure)
|
directory_structure = json.loads(directory_structure)
|
||||||
@@ -70,10 +77,18 @@ def upload_index_files():
|
|||||||
directory_structure = {}
|
directory_structure = {}
|
||||||
else:
|
else:
|
||||||
directory_structure = {}
|
directory_structure = {}
|
||||||
|
if file_name_map:
|
||||||
|
try:
|
||||||
|
file_name_map = json.loads(file_name_map)
|
||||||
|
except Exception:
|
||||||
|
logger.error("Error parsing file_name_map")
|
||||||
|
file_name_map = None
|
||||||
|
else:
|
||||||
|
file_name_map = None
|
||||||
|
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
index_base_path = f"indexes/{id}"
|
index_base_path = f"indexes/{source_id}"
|
||||||
|
|
||||||
if settings.VECTOR_STORE == "faiss":
|
if settings.VECTOR_STORE == "faiss":
|
||||||
if "file_faiss" not in request.files:
|
if "file_faiss" not in request.files:
|
||||||
logger.error("No file_faiss part")
|
logger.error("No file_faiss part")
|
||||||
@@ -94,44 +109,48 @@ def upload_index_files():
|
|||||||
storage.save_file(file_faiss, faiss_storage_path)
|
storage.save_file(file_faiss, faiss_storage_path)
|
||||||
storage.save_file(file_pkl, pkl_storage_path)
|
storage.save_file(file_pkl, pkl_storage_path)
|
||||||
|
|
||||||
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
|
update_fields = {
|
||||||
|
"name": job_name,
|
||||||
|
"type": type,
|
||||||
|
"language": job_name,
|
||||||
|
"date": now,
|
||||||
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
|
"tokens": tokens,
|
||||||
|
"retriever": retriever,
|
||||||
|
"remote_data": remote_data,
|
||||||
|
"sync_frequency": sync_frequency,
|
||||||
|
"file_path": file_path,
|
||||||
|
"directory_structure": directory_structure,
|
||||||
|
}
|
||||||
|
if file_name_map is not None:
|
||||||
|
update_fields["file_name_map"] = file_name_map
|
||||||
|
|
||||||
existing_entry = sources_collection.find_one({"_id": ObjectId(id)})
|
with db_session() as conn:
|
||||||
if existing_entry:
|
repo = SourcesRepository(conn)
|
||||||
sources_collection.update_one(
|
existing = None
|
||||||
{"_id": ObjectId(id)},
|
if looks_like_uuid(source_id):
|
||||||
{
|
existing = repo.get(source_id, user)
|
||||||
"$set": {
|
if existing is None:
|
||||||
"user": user,
|
existing = repo.get_by_legacy_id(source_id, user)
|
||||||
"name": job_name,
|
if existing is not None:
|
||||||
"language": job_name,
|
repo.update(str(existing["id"]), user, update_fields)
|
||||||
"date": datetime.datetime.now(),
|
else:
|
||||||
"model": settings.EMBEDDINGS_NAME,
|
repo.create(
|
||||||
"type": type,
|
job_name,
|
||||||
"tokens": tokens,
|
source_id=source_id if looks_like_uuid(source_id) else None,
|
||||||
"retriever": retriever,
|
user_id=user,
|
||||||
"remote_data": remote_data,
|
type=type,
|
||||||
"sync_frequency": sync_frequency,
|
tokens=tokens,
|
||||||
"file_path": file_path,
|
retriever=retriever,
|
||||||
"directory_structure": directory_structure,
|
remote_data=remote_data,
|
||||||
}
|
sync_frequency=sync_frequency,
|
||||||
},
|
file_path=file_path,
|
||||||
)
|
directory_structure=directory_structure,
|
||||||
else:
|
file_name_map=file_name_map,
|
||||||
sources_collection.insert_one(
|
language=job_name,
|
||||||
{
|
model=settings.EMBEDDINGS_NAME,
|
||||||
"_id": ObjectId(id),
|
date=now,
|
||||||
"user": user,
|
legacy_mongo_id=None if looks_like_uuid(source_id) else str(source_id),
|
||||||
"name": job_name,
|
)
|
||||||
"language": job_name,
|
|
||||||
"date": datetime.datetime.now(),
|
|
||||||
"model": settings.EMBEDDINGS_NAME,
|
|
||||||
"type": type,
|
|
||||||
"tokens": tokens,
|
|
||||||
"retriever": retriever,
|
|
||||||
"remote_data": remote_data,
|
|
||||||
"sync_frequency": sync_frequency,
|
|
||||||
"file_path": file_path,
|
|
||||||
"directory_structure": directory_structure,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|||||||
@@ -3,5 +3,6 @@
|
|||||||
from .routes import agents_ns
|
from .routes import agents_ns
|
||||||
from .sharing import agents_sharing_ns
|
from .sharing import agents_sharing_ns
|
||||||
from .webhooks import agents_webhooks_ns
|
from .webhooks import agents_webhooks_ns
|
||||||
|
from .folders import agents_folders_ns
|
||||||
|
|
||||||
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns"]
|
__all__ = ["agents_ns", "agents_sharing_ns", "agents_webhooks_ns", "agents_folders_ns"]
|
||||||
|
|||||||
application/api/user/agents/folders.py (new file, 366 lines)
@@ -0,0 +1,366 @@
|
|||||||
|
"""
|
||||||
|
Agent folders management routes.
|
||||||
|
Provides virtual folder organization for agents (Google Drive-like structure).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from flask import current_app, jsonify, make_response, request
|
||||||
|
from flask_restx import Namespace, Resource, fields
|
||||||
|
from sqlalchemy import text as _sql_text
|
||||||
|
|
||||||
|
from application.api import api
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.agent_folders import AgentFoldersRepository
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
|
||||||
|
|
||||||
|
agents_folders_ns = Namespace(
|
||||||
|
"agents_folders", description="Agent folder management", path="/api/agents/folders"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_folder_id(repo: AgentFoldersRepository, folder_id: str, user: str):
|
||||||
|
"""Resolve a folder id that may be either a UUID or legacy Mongo ObjectId."""
|
||||||
|
if not folder_id:
|
||||||
|
return None
|
||||||
|
if looks_like_uuid(folder_id):
|
||||||
|
row = repo.get(folder_id, user)
|
||||||
|
if row is not None:
|
||||||
|
return row
|
||||||
|
return repo.get_by_legacy_id(folder_id, user)
|
||||||
|
|
||||||
|
|
||||||
|
def _folder_error_response(message: str, err: Exception):
|
||||||
|
current_app.logger.error(f"{message}: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False, "message": message}), 400)
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_folder(f: dict) -> dict:
|
||||||
|
created_at = f.get("created_at")
|
||||||
|
updated_at = f.get("updated_at")
|
||||||
|
return {
|
||||||
|
"id": str(f["id"]),
|
||||||
|
"name": f.get("name"),
|
||||||
|
"parent_id": str(f["parent_id"]) if f.get("parent_id") else None,
|
||||||
|
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
|
||||||
|
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@agents_folders_ns.route("/")
|
||||||
|
class AgentFolders(Resource):
|
||||||
|
@api.doc(description="Get all folders for the user")
|
||||||
|
def get(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
try:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
folders = AgentFoldersRepository(conn).list_for_user(user)
|
||||||
|
result = [_serialize_folder(f) for f in folders]
|
||||||
|
return make_response(jsonify({"folders": result}), 200)
|
||||||
|
except Exception as err:
|
||||||
|
return _folder_error_response("Failed to fetch folders", err)
|
||||||
|
|
||||||
|
@api.doc(description="Create a new folder")
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"CreateFolder",
|
||||||
|
{
|
||||||
|
"name": fields.String(required=True, description="Folder name"),
|
||||||
|
"parent_id": fields.String(required=False, description="Parent folder ID"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or not data.get("name"):
|
||||||
|
return make_response(jsonify({"success": False, "message": "Folder name is required"}), 400)
|
||||||
|
|
||||||
|
parent_id_input = data.get("parent_id")
|
||||||
|
description = data.get("description")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = AgentFoldersRepository(conn)
|
||||||
|
pg_parent_id = None
|
||||||
|
if parent_id_input:
|
||||||
|
parent = _resolve_folder_id(repo, parent_id_input, user)
|
||||||
|
if not parent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
pg_parent_id = str(parent["id"])
|
||||||
|
folder = repo.create(
|
||||||
|
user, data["name"],
|
||||||
|
description=description,
|
||||||
|
parent_id=pg_parent_id,
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"id": str(folder["id"]),
|
||||||
|
"name": folder["name"],
|
||||||
|
"parent_id": pg_parent_id,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
201,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
return _folder_error_response("Failed to create folder", err)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_folders_ns.route("/<string:folder_id>")
|
||||||
|
class AgentFolder(Resource):
|
||||||
|
@api.doc(description="Get a specific folder with its agents")
|
||||||
|
def get(self, folder_id):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
try:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
folders_repo = AgentFoldersRepository(conn)
|
||||||
|
folder = _resolve_folder_id(folders_repo, folder_id, user)
|
||||||
|
if not folder:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
pg_folder_id = str(folder["id"])
|
||||||
|
|
||||||
|
agents_rows = conn.execute(
|
||||||
|
_sql_text(
|
||||||
|
"SELECT id, name, description FROM agents "
|
||||||
|
"WHERE user_id = :user_id AND folder_id = CAST(:fid AS uuid) "
|
||||||
|
"ORDER BY created_at DESC"
|
||||||
|
),
|
||||||
|
{"user_id": user, "fid": pg_folder_id},
|
||||||
|
).fetchall()
|
||||||
|
agents_list = [
|
||||||
|
{
|
||||||
|
"id": str(row._mapping["id"]),
|
||||||
|
"name": row._mapping["name"],
|
||||||
|
"description": row._mapping.get("description", "") or "",
|
||||||
|
}
|
||||||
|
for row in agents_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
subfolders = folders_repo.list_children(pg_folder_id, user)
|
||||||
|
subfolders_list = [
|
||||||
|
{"id": str(sf["id"]), "name": sf["name"]}
|
||||||
|
for sf in subfolders
|
||||||
|
]
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"id": pg_folder_id,
|
||||||
|
"name": folder["name"],
|
||||||
|
"parent_id": (
|
||||||
|
str(folder["parent_id"]) if folder.get("parent_id") else None
|
||||||
|
),
|
||||||
|
"agents": agents_list,
|
||||||
|
"subfolders": subfolders_list,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
return _folder_error_response("Failed to fetch folder", err)
|
||||||
|
|
||||||
|
@api.doc(description="Update a folder")
|
||||||
|
def put(self, folder_id):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
if not data:
|
||||||
|
return make_response(jsonify({"success": False, "message": "No data provided"}), 400)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = AgentFoldersRepository(conn)
|
||||||
|
folder = _resolve_folder_id(repo, folder_id, user)
|
||||||
|
if not folder:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
pg_folder_id = str(folder["id"])
|
||||||
|
|
||||||
|
update_fields: dict = {}
|
||||||
|
if "name" in data:
|
||||||
|
update_fields["name"] = data["name"]
|
||||||
|
if "description" in data:
|
||||||
|
update_fields["description"] = data["description"]
|
||||||
|
if "parent_id" in data:
|
||||||
|
parent_input = data.get("parent_id")
|
||||||
|
if parent_input:
|
||||||
|
if parent_input == folder_id or parent_input == pg_folder_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Cannot set folder as its own parent",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
parent = _resolve_folder_id(repo, parent_input, user)
|
||||||
|
if not parent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Parent folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
if str(parent["id"]) == pg_folder_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Cannot set folder as its own parent",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
update_fields["parent_id"] = str(parent["id"])
|
||||||
|
else:
|
||||||
|
update_fields["parent_id"] = None
|
||||||
|
|
||||||
|
if update_fields:
|
||||||
|
repo.update(pg_folder_id, user, update_fields)
|
||||||
|
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
except Exception as err:
|
||||||
|
return _folder_error_response("Failed to update folder", err)
|
||||||
|
|
||||||
|
@api.doc(description="Delete a folder")
|
||||||
|
def delete(self, folder_id):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = AgentFoldersRepository(conn)
|
||||||
|
folder = _resolve_folder_id(repo, folder_id, user)
|
||||||
|
if not folder:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
pg_folder_id = str(folder["id"])
|
||||||
|
# Clear folder assignments from agents; self-FK
|
||||||
|
# ``ON DELETE SET NULL`` handles child folders.
|
||||||
|
AgentsRepository(conn).clear_folder_for_all(pg_folder_id, user)
|
||||||
|
deleted = repo.delete(pg_folder_id, user)
|
||||||
|
if not deleted:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
except Exception as err:
|
||||||
|
return _folder_error_response("Failed to delete folder", err)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_folders_ns.route("/move_agent")
|
||||||
|
class MoveAgentToFolder(Resource):
|
||||||
|
@api.doc(description="Move an agent to a folder or remove from folder")
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"MoveAgent",
|
||||||
|
{
|
||||||
|
"agent_id": fields.String(required=True, description="Agent ID to move"),
|
||||||
|
"folder_id": fields.String(required=False, description="Target folder ID (null to remove from folder)"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or not data.get("agent_id"):
|
||||||
|
return make_response(jsonify({"success": False, "message": "Agent ID is required"}), 400)
|
||||||
|
|
||||||
|
agent_id_input = data["agent_id"]
|
||||||
|
folder_id_input = data.get("folder_id")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
agents_repo = AgentsRepository(conn)
|
||||||
|
agent = agents_repo.get_any(agent_id_input, user)
|
||||||
|
if not agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Agent not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
pg_folder_id = None
|
||||||
|
if folder_id_input:
|
||||||
|
folders_repo = AgentFoldersRepository(conn)
|
||||||
|
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||||
|
if not folder:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
pg_folder_id = str(folder["id"])
|
||||||
|
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
except Exception as err:
|
||||||
|
return _folder_error_response("Failed to move agent", err)
|
||||||
|
|
||||||
|
|
||||||
|
@agents_folders_ns.route("/bulk_move")
|
||||||
|
class BulkMoveAgents(Resource):
|
||||||
|
@api.doc(description="Move multiple agents to a folder")
|
||||||
|
@api.expect(
|
||||||
|
api.model(
|
||||||
|
"BulkMoveAgents",
|
||||||
|
{
|
||||||
|
"agent_ids": fields.List(fields.String, required=True, description="List of agent IDs"),
|
||||||
|
"folder_id": fields.String(required=False, description="Target folder ID"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or not data.get("agent_ids"):
|
||||||
|
return make_response(jsonify({"success": False, "message": "Agent IDs are required"}), 400)
|
||||||
|
|
||||||
|
agent_ids = data["agent_ids"]
|
||||||
|
folder_id_input = data.get("folder_id")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
agents_repo = AgentsRepository(conn)
|
||||||
|
pg_folder_id = None
|
||||||
|
if folder_id_input:
|
||||||
|
folders_repo = AgentFoldersRepository(conn)
|
||||||
|
folder = _resolve_folder_id(folders_repo, folder_id_input, user)
|
||||||
|
if not folder:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Folder not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
pg_folder_id = str(folder["id"])
|
||||||
|
for agent_id_input in agent_ids:
|
||||||
|
agent = agents_repo.get_any(agent_id_input, user)
|
||||||
|
if agent is not None:
|
||||||
|
agents_repo.set_folder(str(agent["id"]), user, pg_folder_id)
|
||||||
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
except Exception as err:
|
||||||
|
return _folder_error_response("Failed to move agents", err)
|
||||||
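For reference, a minimal sketch of how a client might exercise the new folder endpoints. The URL paths come from the routes above; the host, port, auth header, and example IDs are assumptions, not part of this diff.

import requests  # assumption: any HTTP client works; requests is used for brevity

BASE = "http://localhost:7091"                 # assumed local DocsGPT API host
HEADERS = {"Authorization": "Bearer <JWT>"}    # assumed auth scheme

# Create a folder, then move an agent into it (paths taken from the diff above).
folder = requests.post(
    f"{BASE}/api/agents/folders/",
    json={"name": "Research"},
    headers=HEADERS,
).json()
requests.post(
    f"{BASE}/api/agents/folders/move_agent",
    json={"agent_id": "<agent-uuid>", "folder_id": folder["id"]},
    headers=HEADERS,
)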
File diff suppressed because it is too large.
@@ -3,21 +3,17 @@
 import datetime
 import secrets
 
-from bson import DBRef
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
+from sqlalchemy import text as _sql_text
 
 from application.api import api
 from application.core.settings import settings
-from application.api.user.base import (
-    agents_collection,
-    db,
-    ensure_user_doc,
-    resolve_tool_details,
-    user_tools_collection,
-    users_collection,
-)
+from application.api.user.base import resolve_tool_details
+from application.storage.db.base_repository import looks_like_uuid
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.users import UsersRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import generate_image_url
 
 agents_sharing_ns = Namespace(
@@ -25,6 +21,38 @@ agents_sharing_ns = Namespace(
 )
 
 
+def _serialize_agent_basic(agent: dict) -> dict:
+    """Shape a PG agent row into the API response dict."""
+    source_id = agent.get("source_id")
+    return {
+        "id": str(agent["id"]),
+        "user": agent.get("user_id", ""),
+        "name": agent.get("name", ""),
+        "image": (
+            generate_image_url(agent["image"]) if agent.get("image") else ""
+        ),
+        "description": agent.get("description", ""),
+        "source": str(source_id) if source_id else "",
+        "chunks": str(agent["chunks"]) if agent.get("chunks") is not None else "0",
+        "retriever": agent.get("retriever", "classic") or "classic",
+        "prompt_id": str(agent["prompt_id"]) if agent.get("prompt_id") else "default",
+        "tools": agent.get("tools", []) or [],
+        "tool_details": resolve_tool_details(agent.get("tools", []) or []),
+        "agent_type": agent.get("agent_type", "") or "",
+        "status": agent.get("status", "") or "",
+        "json_schema": agent.get("json_schema"),
+        "limited_token_mode": agent.get("limited_token_mode", False),
+        "token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
+        "limited_request_mode": agent.get("limited_request_mode", False),
+        "request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
+        "created_at": agent.get("created_at", ""),
+        "updated_at": agent.get("updated_at", ""),
+        "shared": bool(agent.get("shared", False)),
+        "shared_token": agent.get("shared_token", "") or "",
+        "shared_metadata": agent.get("shared_metadata", {}) or {},
+    }
+
+
 @agents_sharing_ns.route("/shared_agent")
 class SharedAgent(Resource):
     @api.doc(
@@ -41,70 +69,33 @@ class SharedAgent(Resource):
                 jsonify({"success": False, "message": "Token or ID is required"}), 400
             )
         try:
-            query = {
-                "shared_publicly": True,
-                "shared_token": shared_token,
-            }
-            shared_agent = agents_collection.find_one(query)
+            with db_readonly() as conn:
+                shared_agent = AgentsRepository(conn).find_by_shared_token(
+                    shared_token,
+                )
             if not shared_agent:
                 return make_response(
                     jsonify({"success": False, "message": "Shared agent not found"}),
                     404,
                 )
-            agent_id = str(shared_agent["_id"])
-            data = {
-                "id": agent_id,
-                "user": shared_agent.get("user", ""),
-                "name": shared_agent.get("name", ""),
-                "image": (
-                    generate_image_url(shared_agent["image"])
-                    if shared_agent.get("image")
-                    else ""
-                ),
-                "description": shared_agent.get("description", ""),
-                "source": (
-                    str(source_doc["_id"])
-                    if isinstance(shared_agent.get("source"), DBRef)
-                    and (source_doc := db.dereference(shared_agent.get("source")))
-                    else ""
-                ),
-                "chunks": shared_agent.get("chunks", "0"),
-                "retriever": shared_agent.get("retriever", "classic"),
-                "prompt_id": shared_agent.get("prompt_id", "default"),
-                "tools": shared_agent.get("tools", []),
-                "tool_details": resolve_tool_details(shared_agent.get("tools", [])),
-                "agent_type": shared_agent.get("agent_type", ""),
-                "status": shared_agent.get("status", ""),
-                "json_schema": shared_agent.get("json_schema"),
-                "limited_token_mode": shared_agent.get("limited_token_mode", False),
-                "token_limit": shared_agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
-                "limited_request_mode": shared_agent.get("limited_request_mode", False),
-                "request_limit": shared_agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
-                "created_at": shared_agent.get("createdAt", ""),
-                "updated_at": shared_agent.get("updatedAt", ""),
-                "shared": shared_agent.get("shared_publicly", False),
-                "shared_token": shared_agent.get("shared_token", ""),
-                "shared_metadata": shared_agent.get("shared_metadata", {}),
-            }
+            agent_id = str(shared_agent["id"])
+            data = _serialize_agent_basic(shared_agent)
 
             if data["tools"]:
                 enriched_tools = []
-                for tool in data["tools"]:
-                    tool_data = user_tools_collection.find_one({"_id": ObjectId(tool)})
-                    if tool_data:
-                        enriched_tools.append(tool_data.get("name", ""))
+                for detail in data["tool_details"]:
+                    enriched_tools.append(detail.get("name", ""))
                 data["tools"] = enriched_tools
             decoded_token = getattr(request, "decoded_token", None)
             if decoded_token:
                 user_id = decoded_token.get("sub")
-                owner_id = shared_agent.get("user")
+                owner_id = shared_agent.get("user_id")
 
                 if user_id != owner_id:
-                    ensure_user_doc(user_id)
-                    users_collection.update_one(
-                        {"user_id": user_id},
-                        {"$addToSet": {"agent_preferences.shared_with_me": agent_id}},
-                    )
+                    with db_session() as conn:
+                        users_repo = UsersRepository(conn)
+                        users_repo.upsert(user_id)
+                        users_repo.add_shared(user_id, agent_id)
             return make_response(jsonify(data), 200)
         except Exception as err:
             current_app.logger.error(f"Error retrieving shared agent: {err}")
@@ -121,52 +112,73 @@ class SharedAgents(Resource):
             return make_response(jsonify({"success": False}), 401)
         user_id = decoded_token.get("sub")
 
-        user_doc = ensure_user_doc(user_id)
-        shared_with_ids = user_doc.get("agent_preferences", {}).get(
-            "shared_with_me", []
-        )
-        shared_object_ids = [ObjectId(id) for id in shared_with_ids]
-
-        shared_agents_cursor = agents_collection.find(
-            {"_id": {"$in": shared_object_ids}, "shared_publicly": True}
-        )
-        shared_agents = list(shared_agents_cursor)
-
-        found_ids_set = {str(agent["_id"]) for agent in shared_agents}
-        stale_ids = [id for id in shared_with_ids if id not in found_ids_set]
-        if stale_ids:
-            users_collection.update_one(
-                {"user_id": user_id},
-                {"$pullAll": {"agent_preferences.shared_with_me": stale_ids}},
-            )
-        pinned_ids = set(user_doc.get("agent_preferences", {}).get("pinned", []))
-
-        list_shared_agents = [
-            {
-                "id": str(agent["_id"]),
-                "name": agent.get("name", ""),
-                "description": agent.get("description", ""),
-                "image": (
-                    generate_image_url(agent["image"]) if agent.get("image") else ""
-                ),
-                "tools": agent.get("tools", []),
-                "tool_details": resolve_tool_details(agent.get("tools", [])),
-                "agent_type": agent.get("agent_type", ""),
-                "status": agent.get("status", ""),
-                "json_schema": agent.get("json_schema"),
-                "limited_token_mode": agent.get("limited_token_mode", False),
-                "token_limit": agent.get("token_limit", settings.DEFAULT_AGENT_LIMITS["token_limit"]),
-                "limited_request_mode": agent.get("limited_request_mode", False),
-                "request_limit": agent.get("request_limit", settings.DEFAULT_AGENT_LIMITS["request_limit"]),
-                "created_at": agent.get("createdAt", ""),
-                "updated_at": agent.get("updatedAt", ""),
-                "pinned": str(agent["_id"]) in pinned_ids,
-                "shared": agent.get("shared_publicly", False),
-                "shared_token": agent.get("shared_token", ""),
-                "shared_metadata": agent.get("shared_metadata", {}),
-            }
-            for agent in shared_agents
-        ]
-
+        with db_session() as conn:
+            users_repo = UsersRepository(conn)
+            user_doc = users_repo.upsert(user_id)
+            shared_with_ids = (
+                user_doc.get("agent_preferences", {}).get("shared_with_me", [])
+                if isinstance(user_doc.get("agent_preferences"), dict)
+                else []
+            )
+            # Keep only UUID-shaped ids; ObjectId leftovers are stripped below.
+            uuid_ids = [sid for sid in shared_with_ids if looks_like_uuid(sid)]
+            non_uuid_ids = [sid for sid in shared_with_ids if not looks_like_uuid(sid)]
+
+            if uuid_ids:
+                result = conn.execute(
+                    _sql_text(
+                        "SELECT * FROM agents "
+                        "WHERE id = ANY(CAST(:ids AS uuid[])) "
+                        "AND shared = true"
+                    ),
+                    {"ids": uuid_ids},
+                )
+                shared_agents = [dict(row._mapping) for row in result.fetchall()]
+            else:
+                shared_agents = []
+
+            found_ids_set = {str(agent["id"]) for agent in shared_agents}
+            stale_ids = [sid for sid in uuid_ids if sid not in found_ids_set]
+            stale_ids.extend(non_uuid_ids)
+            if stale_ids:
+                users_repo.remove_shared_bulk(user_id, stale_ids)
+
+            pinned_ids = set(
+                user_doc.get("agent_preferences", {}).get("pinned", [])
+                if isinstance(user_doc.get("agent_preferences"), dict)
+                else []
+            )
+
+        list_shared_agents = []
+        for agent in shared_agents:
+            agent_id_str = str(agent["id"])
+            list_shared_agents.append(
+                {
+                    "id": agent_id_str,
+                    "name": agent.get("name", ""),
+                    "description": agent.get("description", ""),
+                    "image": (
+                        generate_image_url(agent["image"]) if agent.get("image") else ""
+                    ),
+                    "tools": agent.get("tools", []) or [],
+                    "tool_details": resolve_tool_details(
+                        agent.get("tools", []) or []
+                    ),
+                    "agent_type": agent.get("agent_type", "") or "",
+                    "status": agent.get("status", "") or "",
+                    "json_schema": agent.get("json_schema"),
+                    "limited_token_mode": agent.get("limited_token_mode", False),
+                    "token_limit": agent.get("token_limit") or settings.DEFAULT_AGENT_LIMITS["token_limit"],
+                    "limited_request_mode": agent.get("limited_request_mode", False),
+                    "request_limit": agent.get("request_limit") or settings.DEFAULT_AGENT_LIMITS["request_limit"],
+                    "created_at": agent.get("created_at", ""),
+                    "updated_at": agent.get("updated_at", ""),
+                    "pinned": agent_id_str in pinned_ids,
+                    "shared": bool(agent.get("shared", False)),
+                    "shared_token": agent.get("shared_token", "") or "",
+                    "shared_metadata": agent.get("shared_metadata", {}) or {},
+                }
+            )
+
         return make_response(jsonify(list_shared_agents), 200)
     except Exception as err:
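The diff above leans on `looks_like_uuid` from `application.storage.db.base_repository` to separate UUID-shaped ids from legacy Mongo ObjectIds. That helper's body is not part of this diff; the following is only a sketch of the check it presumably performs, not the project's actual implementation.

import uuid

def looks_like_uuid(value) -> bool:
    """Return True if value parses as a UUID (sketch; the real helper may differ)."""
    try:
        uuid.UUID(str(value))
        return True
    except (ValueError, TypeError):
        return False

print(looks_like_uuid("662f4f2e9b1d4c3a8f2e1a77aa55bb00"))  # True: 32-hex UUID
print(looks_like_uuid("64b7f0c2e4b0a1d2c3e4f5a6"))          # False: 24-char ObjectId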
@@ -220,44 +232,43 @@ class ShareAgent(Resource):
                 ),
                 400,
             )
+        shared_token = None
         try:
-            try:
-                agent_oid = ObjectId(agent_id)
-            except Exception:
-                return make_response(
-                    jsonify({"success": False, "message": "Invalid agent ID"}), 400
-                )
-            agent = agents_collection.find_one({"_id": agent_oid, "user": user})
-            if not agent:
-                return make_response(
-                    jsonify({"success": False, "message": "Agent not found"}), 404
-                )
-            if shared:
-                shared_metadata = {
-                    "shared_by": username,
-                    "shared_at": datetime.datetime.now(datetime.timezone.utc),
-                }
-                shared_token = secrets.token_urlsafe(32)
-                agents_collection.update_one(
-                    {"_id": agent_oid, "user": user},
-                    {
-                        "$set": {
-                            "shared_publicly": shared,
-                            "shared_metadata": shared_metadata,
-                            "shared_token": shared_token,
-                        }
-                    },
-                )
-            else:
-                agents_collection.update_one(
-                    {"_id": agent_oid, "user": user},
-                    {"$set": {"shared_publicly": shared, "shared_token": None}},
-                    {"$unset": {"shared_metadata": ""}},
-                )
+            with db_session() as conn:
+                repo = AgentsRepository(conn)
+                agent = repo.get_any(agent_id, user)
+                if not agent:
+                    return make_response(
+                        jsonify({"success": False, "message": "Agent not found"}), 404
+                    )
+                if shared:
+                    shared_metadata = {
+                        "shared_by": username,
+                        "shared_at": datetime.datetime.now(
+                            datetime.timezone.utc
+                        ).isoformat(),
+                    }
+                    shared_token = secrets.token_urlsafe(32)
+                    repo.update(
+                        str(agent["id"]), user,
+                        {
+                            "shared": True,
+                            "shared_token": shared_token,
+                            "shared_metadata": shared_metadata,
+                        },
+                    )
+                else:
+                    repo.update(
+                        str(agent["id"]), user,
+                        {
+                            "shared": False,
+                            "shared_token": None,
+                            "shared_metadata": None,
+                        },
+                    )
         except Exception as err:
-            current_app.logger.error(f"Error sharing/unsharing agent: {err}")
-            return make_response(jsonify({"success": False, "error": str(err)}), 400)
-        shared_token = shared_token if shared else None
+            current_app.logger.error(f"Error sharing/unsharing agent: {err}", exc_info=True)
+            return make_response(jsonify({"success": False, "error": "Failed to update agent sharing status"}), 400)
         return make_response(
             jsonify({"success": True, "shared_token": shared_token}), 200
         )
@@ -2,14 +2,15 @@
 
 import secrets
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import Namespace, Resource
 
 from application.api import api
-from application.api.user.base import agents_collection, require_agent
+from application.api.user.base import require_agent
 from application.api.user.tasks import process_agent_webhook
 from application.core.settings import settings
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.session import db_readonly, db_session
 
 
 agents_webhooks_ns = Namespace(
@@ -34,9 +35,8 @@ class AgentWebhook(Resource):
                 jsonify({"success": False, "message": "ID is required"}), 400
             )
         try:
-            agent = agents_collection.find_one(
-                {"_id": ObjectId(agent_id), "user": user}
-            )
+            with db_readonly() as conn:
+                agent = AgentsRepository(conn).get_any(agent_id, user)
             if not agent:
                 return make_response(
                     jsonify({"success": False, "message": "Agent not found"}), 404
@@ -44,10 +44,11 @@ class AgentWebhook(Resource):
             webhook_token = agent.get("incoming_webhook_token")
             if not webhook_token:
                 webhook_token = secrets.token_urlsafe(32)
-                agents_collection.update_one(
-                    {"_id": ObjectId(agent_id), "user": user},
-                    {"$set": {"incoming_webhook_token": webhook_token}},
-                )
+                with db_session() as conn:
+                    AgentsRepository(conn).update(
+                        str(agent["id"]), user,
+                        {"incoming_webhook_token": webhook_token},
+                    )
             base_url = settings.API_URL.rstrip("/")
             full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
         except Exception as err:
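The webhook URL returned to the caller is simply the per-agent token appended to the configured API base, as the two context lines above show. A minimal sketch, assuming an illustrative value for `settings.API_URL`:

import secrets

API_URL = "https://docsgpt.example.com/"   # assumed stand-in for settings.API_URL
webhook_token = secrets.token_urlsafe(32)  # same token scheme as in the diff
base_url = API_URL.rstrip("/")
full_webhook_url = f"{base_url}/api/webhooks/agents/{webhook_token}"
print(full_webhook_url)  # e.g. https://docsgpt.example.com/api/webhooks/agents/<token>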
@@ -2,26 +2,84 @@
 
 import datetime
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
+from sqlalchemy import text as _sql_text
 
 from application.api import api
 from application.api.user.base import (
-    agents_collection,
-    conversations_collection,
     generate_date_range,
     generate_hourly_range,
     generate_minute_range,
-    token_usage_collection,
-    user_logs_collection,
 )
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.token_usage import TokenUsageRepository
+from application.storage.db.repositories.user_logs import UserLogsRepository
+from application.storage.db.session import db_readonly
 
 
 analytics_ns = Namespace(
     "analytics", description="Analytics and reporting operations", path="/api"
 )
 
 
+_FILTER_BUCKETS = {
+    "last_hour": ("minute", "%Y-%m-%d %H:%M:00", "YYYY-MM-DD HH24:MI:00"),
+    "last_24_hour": ("hour", "%Y-%m-%d %H:00", "YYYY-MM-DD HH24:00"),
+    "last_7_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
+    "last_15_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
+    "last_30_days": ("day", "%Y-%m-%d", "YYYY-MM-DD"),
+}
+
+
+def _range_for_filter(filter_option: str):
+    """Return ``(start_date, end_date, bucket_unit, pg_fmt)`` for the filter.
+
+    Returns ``None`` on invalid filter.
+    """
+    if filter_option not in _FILTER_BUCKETS:
+        return None
+    end_date = datetime.datetime.now(datetime.timezone.utc)
+    bucket_unit, _py_fmt, pg_fmt = _FILTER_BUCKETS[filter_option]
+
+    if filter_option == "last_hour":
+        start_date = end_date - datetime.timedelta(hours=1)
+    elif filter_option == "last_24_hour":
+        start_date = end_date - datetime.timedelta(hours=24)
+    else:
+        days = {
+            "last_7_days": 6,
+            "last_15_days": 14,
+            "last_30_days": 29,
+        }[filter_option]
+        start_date = end_date - datetime.timedelta(days=days)
+        start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
+        end_date = end_date.replace(
+            hour=23, minute=59, second=59, microsecond=999999
+        )
+    return start_date, end_date, bucket_unit, pg_fmt
+
+
+def _intervals_for_filter(filter_option, start_date, end_date):
+    if filter_option == "last_hour":
+        return generate_minute_range(start_date, end_date)
+    if filter_option == "last_24_hour":
+        return generate_hourly_range(start_date, end_date)
+    return generate_date_range(start_date, end_date)
+
+
+def _resolve_api_key(conn, api_key_id, user_id):
+    """Look up the ``agents.key`` value for a given agent id.
+
+    Scoped by ``user_id`` so an authenticated caller can't probe another
+    user's agents. Accepts either UUID or legacy Mongo ObjectId shape.
+    """
+    if not api_key_id:
+        return None
+    agent = AgentsRepository(conn).get_any(api_key_id, user_id)
+    return (agent or {}).get("key") if agent else None
+
+
 @analytics_ns.route("/get_message_analytics")
 class GetMessageAnalytics(Resource):
     get_message_analytics_model = api.model(
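A quick sketch of what the added `_range_for_filter` helper computes for a day-bucketed filter, following the `_FILTER_BUCKETS` table above; the concrete dates printed depend on when it runs and are illustrative only.

import datetime

# Mirrors _range_for_filter("last_7_days") from the hunk above (sketch only).
end = datetime.datetime.now(datetime.timezone.utc)
start = (end - datetime.timedelta(days=6)).replace(hour=0, minute=0, second=0, microsecond=0)
end = end.replace(hour=23, minute=59, second=59, microsecond=999999)
bucket_unit, _py_fmt, pg_fmt = ("day", "%Y-%m-%d", "YYYY-MM-DD")
# A 7-calendar-day window, bucketed per day, with the to_char format for Postgres.
print(start, end, bucket_unit, pg_fmt)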
@@ -32,13 +90,7 @@ class GetMessageAnalytics(Resource):
                 required=False,
                 description="Filter option for analytics",
                 default="last_30_days",
-                enum=[
-                    "last_hour",
-                    "last_24_hour",
-                    "last_7_days",
-                    "last_15_days",
-                    "last_30_days",
-                ],
+                enum=list(_FILTER_BUCKETS.keys()),
             ),
         },
     )
@@ -50,88 +102,54 @@ class GetMessageAnalytics(Resource):
         if not decoded_token:
             return make_response(jsonify({"success": False}), 401)
         user = decoded_token.get("sub")
-        data = request.get_json()
+        data = request.get_json() or {}
         api_key_id = data.get("api_key_id")
         filter_option = data.get("filter_option", "last_30_days")
 
+        window = _range_for_filter(filter_option)
+        if window is None:
+            return make_response(
+                jsonify({"success": False, "message": "Invalid option"}), 400
+            )
+        start_date, end_date, _bucket_unit, pg_fmt = window
+
         try:
-            api_key = (
-                agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
-                    "key"
-                ]
-                if api_key_id
-                else None
-            )
-        except Exception as err:
-            current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
-            return make_response(jsonify({"success": False}), 400)
-        end_date = datetime.datetime.now(datetime.timezone.utc)
-
-        if filter_option == "last_hour":
-            start_date = end_date - datetime.timedelta(hours=1)
-            group_format = "%Y-%m-%d %H:%M:00"
-        elif filter_option == "last_24_hour":
-            start_date = end_date - datetime.timedelta(hours=24)
-            group_format = "%Y-%m-%d %H:00"
-        else:
-            if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
-                filter_days = (
-                    6
-                    if filter_option == "last_7_days"
-                    else 14 if filter_option == "last_15_days" else 29
-                )
-            else:
-                return make_response(
-                    jsonify({"success": False, "message": "Invalid option"}), 400
-                )
-            start_date = end_date - datetime.timedelta(days=filter_days)
-            start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
-            end_date = end_date.replace(
-                hour=23, minute=59, second=59, microsecond=999999
-            )
-            group_format = "%Y-%m-%d"
-        try:
-            match_stage = {
-                "$match": {
-                    "user": user,
-                }
-            }
-            if api_key:
-                match_stage["$match"]["api_key"] = api_key
-            pipeline = [
-                match_stage,
-                {"$unwind": "$queries"},
-                {
-                    "$match": {
-                        "queries.timestamp": {"$gte": start_date, "$lte": end_date}
-                    }
-                },
-                {
-                    "$group": {
-                        "_id": {
-                            "$dateToString": {
-                                "format": group_format,
-                                "date": "$queries.timestamp",
-                            }
-                        },
-                        "count": {"$sum": 1},
-                    }
-                },
-                {"$sort": {"_id": 1}},
-            ]
-
-            message_data = conversations_collection.aggregate(pipeline)
-
-            if filter_option == "last_hour":
-                intervals = generate_minute_range(start_date, end_date)
-            elif filter_option == "last_24_hour":
-                intervals = generate_hourly_range(start_date, end_date)
-            else:
-                intervals = generate_date_range(start_date, end_date)
+            with db_readonly() as conn:
+                api_key = _resolve_api_key(conn, api_key_id, user)
+
+                # Count messages per bucket, filtered by the conversation's
+                # owner (user_id) and optionally the agent api_key. The
+                # ``user_id`` filter is always applied post-cutover to
+                # prevent cross-tenant leakage on admin dashboards.
+                clauses = [
+                    "c.user_id = :user_id",
+                    "m.timestamp >= :start",
+                    "m.timestamp <= :end",
+                ]
+                params: dict = {
+                    "user_id": user,
+                    "start": start_date,
+                    "end": end_date,
+                    "fmt": pg_fmt,
+                }
+                if api_key:
+                    clauses.append("c.api_key = :api_key")
+                    params["api_key"] = api_key
+                where = " AND ".join(clauses)
+                sql = (
+                    "SELECT to_char(m.timestamp AT TIME ZONE 'UTC', :fmt) AS bucket, "
+                    "COUNT(*) AS count "
+                    "FROM conversation_messages m "
+                    "JOIN conversations c ON c.id = m.conversation_id "
+                    f"WHERE {where} "
+                    "GROUP BY bucket ORDER BY bucket ASC"
+                )
+                rows = conn.execute(_sql_text(sql), params).fetchall()
+
+            intervals = _intervals_for_filter(filter_option, start_date, end_date)
             daily_messages = {interval: 0 for interval in intervals}
-            for entry in message_data:
-                daily_messages[entry["_id"]] = entry["count"]
+            for row in rows:
+                daily_messages[row._mapping["bucket"]] = int(row._mapping["count"])
         except Exception as err:
             current_app.logger.error(
                 f"Error getting message analytics: {err}", exc_info=True
@@ -152,13 +170,7 @@ class GetTokenAnalytics(Resource):
                 required=False,
                 description="Filter option for analytics",
                 default="last_30_days",
-                enum=[
-                    "last_hour",
-                    "last_24_hour",
-                    "last_7_days",
-                    "last_15_days",
-                    "last_30_days",
-                ],
+                enum=list(_FILTER_BUCKETS.keys()),
             ),
         },
     )
@@ -170,123 +182,36 @@ class GetTokenAnalytics(Resource):
         if not decoded_token:
             return make_response(jsonify({"success": False}), 401)
         user = decoded_token.get("sub")
-        data = request.get_json()
+        data = request.get_json() or {}
        api_key_id = data.get("api_key_id")
         filter_option = data.get("filter_option", "last_30_days")
 
+        window = _range_for_filter(filter_option)
+        if window is None:
+            return make_response(
+                jsonify({"success": False, "message": "Invalid option"}), 400
+            )
+        start_date, end_date, bucket_unit, _pg_fmt = window
+
-        try:
-            api_key = (
-                agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
-                    "key"
-                ]
-                if api_key_id
-                else None
-            )
-        except Exception as err:
-            current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
-            return make_response(jsonify({"success": False}), 400)
-        end_date = datetime.datetime.now(datetime.timezone.utc)
-
-        if filter_option == "last_hour":
-            start_date = end_date - datetime.timedelta(hours=1)
-            group_format = "%Y-%m-%d %H:%M:00"
-            group_stage = {
-                "$group": {
-                    "_id": {
-                        "minute": {
-                            "$dateToString": {
-                                "format": group_format,
-                                "date": "$timestamp",
-                            }
-                        }
-                    },
-                    "total_tokens": {
-                        "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
-                    },
-                }
-            }
-        elif filter_option == "last_24_hour":
-            start_date = end_date - datetime.timedelta(hours=24)
-            group_format = "%Y-%m-%d %H:00"
-            group_stage = {
-                "$group": {
-                    "_id": {
-                        "hour": {
-                            "$dateToString": {
-                                "format": group_format,
-                                "date": "$timestamp",
-                            }
-                        }
-                    },
-                    "total_tokens": {
-                        "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
-                    },
-                }
-            }
-        else:
-            if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
-                filter_days = (
-                    6
-                    if filter_option == "last_7_days"
-                    else (14 if filter_option == "last_15_days" else 29)
-                )
-            else:
-                return make_response(
-                    jsonify({"success": False, "message": "Invalid option"}), 400
-                )
-            start_date = end_date - datetime.timedelta(days=filter_days)
-            start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
-            end_date = end_date.replace(
-                hour=23, minute=59, second=59, microsecond=999999
-            )
-            group_format = "%Y-%m-%d"
-            group_stage = {
-                "$group": {
-                    "_id": {
-                        "day": {
-                            "$dateToString": {
-                                "format": group_format,
-                                "date": "$timestamp",
-                            }
-                        }
-                    },
-                    "total_tokens": {
-                        "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]}
-                    },
-                }
-            }
         try:
-            match_stage = {
-                "$match": {
-                    "user_id": user,
-                    "timestamp": {"$gte": start_date, "$lte": end_date},
-                }
-            }
-            if api_key:
-                match_stage["$match"]["api_key"] = api_key
-            token_usage_data = token_usage_collection.aggregate(
-                [
-                    match_stage,
-                    group_stage,
-                    {"$sort": {"_id": 1}},
-                ]
-            )
-
-            if filter_option == "last_hour":
-                intervals = generate_minute_range(start_date, end_date)
-            elif filter_option == "last_24_hour":
-                intervals = generate_hourly_range(start_date, end_date)
-            else:
-                intervals = generate_date_range(start_date, end_date)
+            with db_readonly() as conn:
+                api_key = _resolve_api_key(conn, api_key_id, user)
+                # ``bucketed_totals`` applies user_id / api_key filters
+                # directly — no need to reshape a Mongo pipeline.
+                rows = TokenUsageRepository(conn).bucketed_totals(
+                    bucket_unit=bucket_unit,
+                    user_id=user,
+                    api_key=api_key,
+                    timestamp_gte=start_date,
+                    timestamp_lt=end_date,
+                )
+
+            intervals = _intervals_for_filter(filter_option, start_date, end_date)
             daily_token_usage = {interval: 0 for interval in intervals}
-            for entry in token_usage_data:
-                if filter_option == "last_hour":
-                    daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"]
-                elif filter_option == "last_24_hour":
-                    daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"]
-                else:
-                    daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"]
+            for entry in rows:
+                daily_token_usage[entry["bucket"]] = int(
+                    entry["prompt_tokens"] + entry["generated_tokens"]
+                )
         except Exception as err:
             current_app.logger.error(
                 f"Error getting token analytics: {err}", exc_info=True
@@ -307,13 +232,7 @@ class GetFeedbackAnalytics(Resource):
                 required=False,
                 description="Filter option for analytics",
                 default="last_30_days",
-                enum=[
-                    "last_hour",
-                    "last_24_hour",
-                    "last_7_days",
-                    "last_15_days",
-                    "last_30_days",
-                ],
+                enum=list(_FILTER_BUCKETS.keys()),
             ),
         },
     )
@@ -325,128 +244,64 @@ class GetFeedbackAnalytics(Resource):
         if not decoded_token:
             return make_response(jsonify({"success": False}), 401)
         user = decoded_token.get("sub")
-        data = request.get_json()
+        data = request.get_json() or {}
         api_key_id = data.get("api_key_id")
         filter_option = data.get("filter_option", "last_30_days")
 
+        window = _range_for_filter(filter_option)
+        if window is None:
+            return make_response(
+                jsonify({"success": False, "message": "Invalid option"}), 400
+            )
+        start_date, end_date, _bucket_unit, pg_fmt = window
+
         try:
-            api_key = (
-                agents_collection.find_one({"_id": ObjectId(api_key_id), "user": user})[
-                    "key"
-                ]
-                if api_key_id
-                else None
-            )
-        except Exception as err:
-            current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
-            return make_response(jsonify({"success": False}), 400)
-        end_date = datetime.datetime.now(datetime.timezone.utc)
-
-        if filter_option == "last_hour":
-            start_date = end_date - datetime.timedelta(hours=1)
-            group_format = "%Y-%m-%d %H:%M:00"
-            date_field = {
-                "$dateToString": {
-                    "format": group_format,
-                    "date": "$queries.feedback_timestamp",
-                }
-            }
-        elif filter_option == "last_24_hour":
-            start_date = end_date - datetime.timedelta(hours=24)
-            group_format = "%Y-%m-%d %H:00"
-            date_field = {
-                "$dateToString": {
-                    "format": group_format,
-                    "date": "$queries.feedback_timestamp",
-                }
-            }
-        else:
-            if filter_option in ["last_7_days", "last_15_days", "last_30_days"]:
-                filter_days = (
-                    6
-                    if filter_option == "last_7_days"
-                    else (14 if filter_option == "last_15_days" else 29)
-                )
-            else:
-                return make_response(
-                    jsonify({"success": False, "message": "Invalid option"}), 400
-                )
-            start_date = end_date - datetime.timedelta(days=filter_days)
-            start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0)
-            end_date = end_date.replace(
-                hour=23, minute=59, second=59, microsecond=999999
-            )
-            group_format = "%Y-%m-%d"
-            date_field = {
-                "$dateToString": {
-                    "format": group_format,
-                    "date": "$queries.feedback_timestamp",
-                }
-            }
-        try:
-            match_stage = {
-                "$match": {
-                    "queries.feedback_timestamp": {
-                        "$gte": start_date,
-                        "$lte": end_date,
-                    },
-                    "queries.feedback": {"$exists": True},
-                }
-            }
-            if api_key:
-                match_stage["$match"]["api_key"] = api_key
-            pipeline = [
-                match_stage,
-                {"$unwind": "$queries"},
-                {"$match": {"queries.feedback": {"$exists": True}}},
-                {
-                    "$group": {
-                        "_id": {"time": date_field, "feedback": "$queries.feedback"},
-                        "count": {"$sum": 1},
-                    }
-                },
-                {
-                    "$group": {
-                        "_id": "$_id.time",
-                        "positive": {
-                            "$sum": {
-                                "$cond": [
-                                    {"$eq": ["$_id.feedback", "LIKE"]},
-                                    "$count",
-                                    0,
-                                ]
-                            }
-                        },
-                        "negative": {
-                            "$sum": {
-                                "$cond": [
-                                    {"$eq": ["$_id.feedback", "DISLIKE"]},
-                                    "$count",
-                                    0,
-                                ]
-                            }
-                        },
-                    }
-                },
-                {"$sort": {"_id": 1}},
-            ]
-
-            feedback_data = conversations_collection.aggregate(pipeline)
-
-            if filter_option == "last_hour":
-                intervals = generate_minute_range(start_date, end_date)
-            elif filter_option == "last_24_hour":
-                intervals = generate_hourly_range(start_date, end_date)
-            else:
-                intervals = generate_date_range(start_date, end_date)
+            with db_readonly() as conn:
+                api_key = _resolve_api_key(conn, api_key_id, user)
+
+                # Feedback lives inside the ``conversation_messages.feedback``
+                # JSONB as ``{"text": "like"|"dislike", "timestamp": "..."}``.
+                # There is no scalar ``feedback_timestamp`` column — extract
+                # the timestamp from the JSONB and cast it to timestamptz for
+                # the range filter + bucket grouping.
+                clauses = [
+                    "c.user_id = :user_id",
+                    "m.feedback IS NOT NULL",
+                    "(m.feedback->>'timestamp')::timestamptz >= :start",
+                    "(m.feedback->>'timestamp')::timestamptz <= :end",
+                ]
+                params: dict = {
+                    "user_id": user,
+                    "start": start_date,
+                    "end": end_date,
+                    "fmt": pg_fmt,
+                }
+                if api_key:
+                    clauses.append("c.api_key = :api_key")
+                    params["api_key"] = api_key
+                where = " AND ".join(clauses)
+                sql = (
+                    "SELECT to_char("
+                    "(m.feedback->>'timestamp')::timestamptz AT TIME ZONE 'UTC', :fmt"
+                    ") AS bucket, "
+                    "SUM(CASE WHEN m.feedback->>'text' = 'like' THEN 1 ELSE 0 END) AS positive, "
+                    "SUM(CASE WHEN m.feedback->>'text' = 'dislike' THEN 1 ELSE 0 END) AS negative "
+                    "FROM conversation_messages m "
+                    "JOIN conversations c ON c.id = m.conversation_id "
+                    f"WHERE {where} "
+                    "GROUP BY bucket ORDER BY bucket ASC"
+                )
+                rows = conn.execute(_sql_text(sql), params).fetchall()
+
+            intervals = _intervals_for_filter(filter_option, start_date, end_date)
             daily_feedback = {
                 interval: {"positive": 0, "negative": 0} for interval in intervals
             }
-            for entry in feedback_data:
-                daily_feedback[entry["_id"]] = {
-                    "positive": entry["positive"],
-                    "negative": entry["negative"],
-                }
+            for row in rows:
+                bucket = row._mapping["bucket"]
+                daily_feedback[bucket] = {
+                    "positive": int(row._mapping["positive"] or 0),
+                    "negative": int(row._mapping["negative"] or 0),
+                }
         except Exception as err:
             current_app.logger.error(
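To make the new feedback bucketing concrete, here is a small plain-Python sketch of the grouping the SQL above performs, using the JSONB shape described in the diff's comments. The field names follow the diff; the sample rows are invented for illustration.

from collections import defaultdict
from datetime import datetime, timezone

# Sample values shaped like conversation_messages.feedback (invented data).
rows = [
    {"text": "like", "timestamp": "2024-05-01T10:15:00+00:00"},
    {"text": "dislike", "timestamp": "2024-05-01T18:02:00+00:00"},
    {"text": "like", "timestamp": "2024-05-02T09:30:00+00:00"},
]

daily = defaultdict(lambda: {"positive": 0, "negative": 0})
for fb in rows:
    # Bucket by UTC day, mirroring to_char(... , 'YYYY-MM-DD') in the query.
    bucket = datetime.fromisoformat(fb["timestamp"]).astimezone(timezone.utc).strftime("%Y-%m-%d")
    key = "positive" if fb["text"] == "like" else "negative"
    daily[bucket][key] += 1
print(dict(daily))  # {'2024-05-01': {'positive': 1, 'negative': 1}, '2024-05-02': {'positive': 1, 'negative': 0}}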
@@ -484,47 +339,89 @@ class GetUserLogs(Resource):
 if not decoded_token:
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
-data = request.get_json()
+data = request.get_json() or {}
 page = int(data.get("page", 1))
 api_key_id = data.get("api_key_id")
 page_size = int(data.get("page_size", 10))
-skip = (page - 1) * page_size
 
 try:
-api_key = (
-agents_collection.find_one({"_id": ObjectId(api_key_id)})["key"]
-if api_key_id
-else None
-)
+with db_readonly() as conn:
+api_key = _resolve_api_key(conn, api_key_id, user)
+logs_repo = UserLogsRepository(conn)
+if api_key:
+# ``find_by_api_key`` filters on ``data->>'api_key'``
+# — the PG shape of the legacy top-level ``api_key``
+# filter. Paginate client-side using offset/limit.
+all_rows = logs_repo.find_by_api_key(api_key)
+offset = (page - 1) * page_size
+window = all_rows[offset: offset + page_size + 1]
+items = window
+else:
+items, has_more_flag = logs_repo.list_paginated(
+user_id=user,
+page=page,
+page_size=page_size,
+)
+# list_paginated already trims to page_size and
+# returns has_more separately.
+results = [
+{
+"id": str(item.get("id") or item.get("_id")),
+"action": (item.get("data") or {}).get("action"),
+"level": (item.get("data") or {}).get("level"),
+"user": item.get("user_id"),
+"question": (item.get("data") or {}).get("question"),
+"sources": (item.get("data") or {}).get("sources"),
+"retriever_params": (item.get("data") or {}).get(
+"retriever_params"
+),
+"timestamp": (
+item["timestamp"].isoformat()
+if hasattr(item.get("timestamp"), "isoformat")
+else item.get("timestamp")
+),
+}
+for item in items
+]
+return make_response(
+jsonify(
+{
+"success": True,
+"logs": results,
+"page": page,
+"page_size": page_size,
+"has_more": has_more_flag,
+}
+),
+200,
+)
+
+has_more = len(items) > page_size
+items = items[:page_size]
+results = [
+{
+"id": str(item.get("id") or item.get("_id")),
+"action": (item.get("data") or {}).get("action"),
+"level": (item.get("data") or {}).get("level"),
+"user": item.get("user_id"),
+"question": (item.get("data") or {}).get("question"),
+"sources": (item.get("data") or {}).get("sources"),
+"retriever_params": (item.get("data") or {}).get(
+"retriever_params"
+),
+"timestamp": (
+item["timestamp"].isoformat()
+if hasattr(item.get("timestamp"), "isoformat")
+else item.get("timestamp")
+),
+}
+for item in items
+]
 except Exception as err:
-current_app.logger.error(f"Error getting API key: {err}", exc_info=True)
+current_app.logger.error(
+f"Error getting user logs: {err}", exc_info=True
+)
 return make_response(jsonify({"success": False}), 400)
-query = {"user": user}
-if api_key:
-query = {"api_key": api_key}
-items_cursor = (
-user_logs_collection.find(query)
-.sort("timestamp", -1)
-.skip(skip)
-.limit(page_size + 1)
-)
-items = list(items_cursor)
-
-results = [
-{
-"id": str(item.get("_id")),
-"action": item.get("action"),
-"level": item.get("level"),
-"user": item.get("user"),
-"question": item.get("question"),
-"sources": item.get("sources"),
-"retriever_params": item.get("retriever_params"),
-"timestamp": item.get("timestamp"),
-}
-for item in items[:page_size]
-]
-
-has_more = len(items) > page_size
-
 return make_response(
 jsonify(
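The new GetUserLogs code above calls a _resolve_api_key helper that is defined elsewhere in this changeset and not shown in this excerpt. A minimal, purely illustrative sketch of what such a helper could look like, assuming an AgentsRepository.get_any(id, user_id) lookup that returns a row dict with a "key" column (the name and signature come from the call site; the body is an assumption, not the actual implementation):

from application.storage.db.repositories.agents import AgentsRepository

def _resolve_api_key(conn, api_key_id, user):
    # Illustrative sketch only: translate an optional agent id into its API
    # key, scoped to the requesting user. Returns None when no id was given
    # or the agent cannot be found for this user.
    if not api_key_id:
        return None
    agent = AgentsRepository(conn).get_any(str(api_key_id), user)
    return agent.get("key") if agent else None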
@@ -1,15 +1,39 @@
 """File attachments and media routes."""
 
 import os
+import tempfile
+from pathlib import Path
 
+import uuid
+
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
 
 from application.api import api
-from application.api.user.base import agents_collection, storage
-from application.api.user.tasks import store_attachment
+from application.cache import get_redis_instance
 from application.core.settings import settings
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.session import db_readonly
+from application.stt.constants import (
+SUPPORTED_AUDIO_EXTENSIONS,
+SUPPORTED_AUDIO_MIME_TYPES,
+)
+from application.stt.upload_limits import (
+AudioFileTooLargeError,
+build_stt_file_size_limit_message,
+enforce_audio_file_size_limit,
+is_audio_filename,
+)
+from application.stt.live_session import (
+apply_live_stt_hypothesis,
+create_live_stt_session,
+delete_live_stt_session,
+finalize_live_stt_session,
+get_live_stt_transcript_text,
+load_live_stt_session,
+save_live_stt_session,
+)
+from application.stt.stt_creator import STTCreator
 from application.tts.tts_creator import TTSCreator
 from application.utils import safe_filename
 
@@ -19,6 +43,73 @@ attachments_ns = Namespace(
 )
 
 
+def _resolve_authenticated_user():
+decoded_token = getattr(request, "decoded_token", None)
+api_key = request.form.get("api_key") or request.args.get("api_key")
+
+if decoded_token:
+return safe_filename(decoded_token.get("sub"))
+
+if api_key:
+with db_readonly() as conn:
+agent = AgentsRepository(conn).find_by_key(api_key)
+if not agent:
+return make_response(
+jsonify({"success": False, "message": "Invalid API key"}), 401
+)
+return safe_filename(agent.get("user_id"))
+
+return None
+
+
+def _get_uploaded_file_size(file) -> int:
+try:
+current_position = file.stream.tell()
+file.stream.seek(0, os.SEEK_END)
+size_bytes = file.stream.tell()
+file.stream.seek(current_position)
+return size_bytes
+except Exception:
+return 0
+
+
+def _is_supported_audio_mimetype(mimetype: str) -> bool:
+if not mimetype:
+return True
+normalized = mimetype.split(";")[0].strip().lower()
+return normalized.startswith("audio/") or normalized in SUPPORTED_AUDIO_MIME_TYPES
+
+
+def _enforce_uploaded_audio_size_limit(file, filename: str) -> None:
+if not is_audio_filename(filename):
+return
+size_bytes = _get_uploaded_file_size(file)
+if size_bytes:
+enforce_audio_file_size_limit(size_bytes)
+
+
+def _get_store_attachment_user_error(exc: Exception) -> str:
+if isinstance(exc, AudioFileTooLargeError):
+return build_stt_file_size_limit_message()
+return "Failed to process file"
+
+
+def _require_live_stt_redis():
+redis_client = get_redis_instance()
+if redis_client:
+return redis_client
+return make_response(
+jsonify({"success": False, "message": "Live transcription is unavailable"}),
+503,
+)
+
+
+def _parse_bool_form_value(value: str | None) -> bool:
+if value is None:
+return False
+return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
 @attachments_ns.route("/store_attachment")
 class StoreAttachment(Resource):
 @api.expect(
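The resources that follow all distinguish the two return shapes of _resolve_authenticated_user (a sanitized user id string versus a ready-made error Response) with a hasattr probe on status_code. A minimal sketch of that calling pattern, lifted directly from the hunks below:

def _example_endpoint_body():
    # Dispatch pattern used by StoreAttachment and the STT resources:
    # a Flask Response (which carries .status_code) is returned as-is,
    # a falsy value means no credentials were supplied, and a string is
    # the sanitized user id that the handler should operate on.
    auth_user = _resolve_authenticated_user()
    if hasattr(auth_user, "status_code"):   # invalid API key -> 401 response
        return auth_user
    if not auth_user:                        # no token and no api_key
        return make_response(
            jsonify({"success": False, "message": "Authentication required"}), 401
        )
    return auth_user                         # plain user id string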
@@ -36,8 +127,9 @@ class StoreAttachment(Resource):
 description="Stores one or multiple attachments without vectorization or training. Supports user or API key authentication."
 )
 def post(self):
-decoded_token = getattr(request, "decoded_token", None)
-api_key = request.form.get("api_key") or request.args.get("api_key")
+auth_user = _resolve_authenticated_user()
+if hasattr(auth_user, "status_code"):
+return auth_user
 
 files = request.files.getlist("file")
 if not files:
@@ -51,30 +143,25 @@ class StoreAttachment(Resource):
 400,
 )
 
-user = None
-if decoded_token:
-user = safe_filename(decoded_token.get("sub"))
-elif api_key:
-agent = agents_collection.find_one({"key": api_key})
-if not agent:
-return make_response(
-jsonify({"success": False, "message": "Invalid API key"}), 401
-)
-user = safe_filename(agent.get("user"))
-else:
+user = auth_user
+if not user:
 return make_response(
 jsonify({"success": False, "message": "Authentication required"}), 401
 )
 
 try:
+from application.api.user.tasks import store_attachment
+from application.api.user.base import storage
+
 tasks = []
 errors = []
 original_file_count = len(files)
 
 for idx, file in enumerate(files):
 try:
-attachment_id = ObjectId()
+attachment_id = uuid.uuid4()
 original_filename = safe_filename(os.path.basename(file.filename))
+_enforce_uploaded_audio_size_limit(file, original_filename)
 relative_path = f"{settings.UPLOAD_FOLDER}/{user}/attachments/{str(attachment_id)}/{original_filename}"
 
 metadata = storage.save_file(file, relative_path)
@@ -90,20 +177,33 @@ class StoreAttachment(Resource):
 "task_id": task.id,
 "filename": original_filename,
 "attachment_id": str(attachment_id),
+"upload_index": idx,
 })
 except Exception as file_err:
 current_app.logger.error(f"Error processing file {idx} ({file.filename}): {file_err}", exc_info=True)
 errors.append({
+"upload_index": idx,
 "filename": file.filename,
-"error": str(file_err)
+"error": _get_store_attachment_user_error(file_err),
 })
 
 if not tasks:
-error_msg = "No valid files to upload"
-if errors:
-error_msg += f". Errors: {errors}"
+if errors and all(
+error.get("error") == build_stt_file_size_limit_message()
+for error in errors
+):
+return make_response(
+jsonify(
+{
+"success": False,
+"message": build_stt_file_size_limit_message(),
+"errors": errors,
+}
+),
+413,
+)
 return make_response(
-jsonify({"status": "error", "message": error_msg, "errors": errors}),
+jsonify({"status": "error", "message": "No valid files to upload"}),
 400,
 )
@@ -135,14 +235,392 @@ class StoreAttachment(Resource):
 )
 except Exception as err:
 current_app.logger.error(f"Error storing attachment: {err}", exc_info=True)
-return make_response(jsonify({"success": False, "error": str(err)}), 400)
+return make_response(jsonify({"success": False, "error": "Failed to store attachment"}), 400)
 
 
+@attachments_ns.route("/stt")
+class SpeechToText(Resource):
+@api.expect(
+api.model(
+"SpeechToTextModel",
+{
+"file": fields.Raw(required=True, description="Audio file"),
+"language": fields.String(
+required=False, description="Optional transcription language hint"
+),
+},
+)
+)
+@api.doc(description="Transcribe an uploaded audio file")
+def post(self):
+auth_user = _resolve_authenticated_user()
+if hasattr(auth_user, "status_code"):
+return auth_user
+if not auth_user:
+return make_response(
+jsonify({"success": False, "message": "Authentication required"}),
+401,
+)
+
+file = request.files.get("file")
+if not file or file.filename == "":
+return make_response(
+jsonify({"success": False, "message": "Missing file"}),
+400,
+)
+
+filename = safe_filename(os.path.basename(file.filename))
+suffix = Path(filename).suffix.lower()
+if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
+return make_response(
+jsonify({"success": False, "message": "Unsupported audio format"}),
+400,
+)
+
+if not _is_supported_audio_mimetype(file.mimetype or ""):
+return make_response(
+jsonify({"success": False, "message": "Unsupported audio MIME type"}),
+400,
+)
+
+try:
+_enforce_uploaded_audio_size_limit(file, filename)
+except AudioFileTooLargeError:
+return make_response(
+jsonify(
+{
+"success": False,
+"message": build_stt_file_size_limit_message(),
+}
+),
+413,
+)
+
+temp_path = None
+try:
+with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
+file.save(temp_file.name)
+temp_path = Path(temp_file.name)
+
+stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
+transcript = stt_instance.transcribe(
+temp_path,
+language=request.form.get("language") or settings.STT_LANGUAGE,
+timestamps=settings.STT_ENABLE_TIMESTAMPS,
+diarize=settings.STT_ENABLE_DIARIZATION,
+)
+return make_response(jsonify({"success": True, **transcript}), 200)
+except Exception as err:
+current_app.logger.error(f"Error transcribing audio: {err}", exc_info=True)
+return make_response(
+jsonify({"success": False, "message": "Failed to transcribe audio"}),
+400,
+)
+finally:
+if temp_path and temp_path.exists():
+temp_path.unlink()
+
+
+@attachments_ns.route("/stt/live/start")
+class LiveSpeechToTextStart(Resource):
+@api.doc(description="Start a live speech-to-text session")
+def post(self):
+auth_user = _resolve_authenticated_user()
+if hasattr(auth_user, "status_code"):
+return auth_user
+if not auth_user:
+return make_response(
+jsonify({"success": False, "message": "Authentication required"}),
+401,
+)
+
+redis_client = _require_live_stt_redis()
+if hasattr(redis_client, "status_code"):
+return redis_client
+
+payload = request.get_json(silent=True) or {}
+session_state = create_live_stt_session(
+user=auth_user,
+language=payload.get("language") or settings.STT_LANGUAGE,
+)
+save_live_stt_session(redis_client, session_state)
+
+return make_response(
+jsonify(
+{
+"success": True,
+"session_id": session_state["session_id"],
+"language": session_state.get("language"),
+"committed_text": "",
+"mutable_text": "",
+"previous_hypothesis": "",
+"latest_hypothesis": "",
+"finalized_text": "",
+"pending_text": "",
+"transcript_text": "",
+}
+),
+200,
+)
+
+
+@attachments_ns.route("/stt/live/chunk")
+class LiveSpeechToTextChunk(Resource):
+@api.expect(
+api.model(
+"LiveSpeechToTextChunkModel",
+{
+"session_id": fields.String(
+required=True, description="Live transcription session ID"
+),
+"chunk_index": fields.Integer(
+required=True, description="Sequential chunk index"
+),
+"is_silence": fields.Boolean(
+required=False,
+description="Whether the latest capture window was mostly silence",
+),
+"file": fields.Raw(required=True, description="Audio chunk"),
+},
+)
+)
+@api.doc(description="Transcribe a chunk for a live speech-to-text session")
+def post(self):
+auth_user = _resolve_authenticated_user()
+if hasattr(auth_user, "status_code"):
+return auth_user
+if not auth_user:
+return make_response(
+jsonify({"success": False, "message": "Authentication required"}),
+401,
+)
+
+redis_client = _require_live_stt_redis()
+if hasattr(redis_client, "status_code"):
+return redis_client
+
+session_id = request.form.get("session_id", "").strip()
+if not session_id:
+return make_response(
+jsonify({"success": False, "message": "Missing session_id"}),
+400,
+)
+
+session_state = load_live_stt_session(redis_client, session_id)
+if not session_state:
+return make_response(
+jsonify(
+{
+"success": False,
+"message": "Live transcription session not found",
+}
+),
+404,
+)
+
+if safe_filename(str(session_state.get("user", ""))) != auth_user:
+return make_response(
+jsonify({"success": False, "message": "Forbidden"}),
+403,
+)
+
+chunk_index_raw = request.form.get("chunk_index", "").strip()
+if chunk_index_raw == "":
+return make_response(
+jsonify({"success": False, "message": "Missing chunk_index"}),
+400,
+)
+
+try:
+chunk_index = int(chunk_index_raw)
+except ValueError:
+return make_response(
+jsonify({"success": False, "message": "Invalid chunk_index"}),
+400,
+)
+is_silence = _parse_bool_form_value(request.form.get("is_silence"))
+
+file = request.files.get("file")
+if not file or file.filename == "":
+return make_response(
+jsonify({"success": False, "message": "Missing file"}),
+400,
+)
+
+filename = safe_filename(os.path.basename(file.filename))
+suffix = Path(filename).suffix.lower()
+if suffix not in SUPPORTED_AUDIO_EXTENSIONS:
+return make_response(
+jsonify({"success": False, "message": "Unsupported audio format"}),
+400,
+)
+
+if not _is_supported_audio_mimetype(file.mimetype or ""):
+return make_response(
+jsonify({"success": False, "message": "Unsupported audio MIME type"}),
+400,
+)
+
+try:
+_enforce_uploaded_audio_size_limit(file, filename)
+except AudioFileTooLargeError:
+return make_response(
+jsonify(
+{
+"success": False,
+"message": build_stt_file_size_limit_message(),
+}
+),
+413,
+)
+
+temp_path = None
+try:
+with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
+file.save(temp_file.name)
+temp_path = Path(temp_file.name)
+
+session_language = session_state.get("language") or settings.STT_LANGUAGE
+stt_instance = STTCreator.create_stt(settings.STT_PROVIDER)
+transcript = stt_instance.transcribe(
+temp_path,
+language=session_language,
+timestamps=False,
+diarize=False,
+)
+if not session_state.get("language") and transcript.get("language"):
+session_state["language"] = transcript["language"]
+
+try:
+apply_live_stt_hypothesis(
+session_state,
+str(transcript.get("text", "")),
+chunk_index,
+is_silence=is_silence,
+)
+except ValueError:
+current_app.logger.warning(
+"Invalid live transcription chunk",
+exc_info=True,
+)
+return make_response(
+jsonify(
+{
+"success": False,
+"message": "Invalid live transcription chunk",
+}
+),
+409,
+)
+save_live_stt_session(redis_client, session_state)
+
+return make_response(
+jsonify(
+{
+"success": True,
+"session_id": session_id,
+"chunk_index": chunk_index,
+"chunk_text": transcript.get("text", ""),
+"is_silence": is_silence,
+"language": session_state.get("language"),
+"committed_text": session_state.get("committed_text", ""),
+"mutable_text": session_state.get("mutable_text", ""),
+"previous_hypothesis": session_state.get(
+"previous_hypothesis", ""
+),
+"latest_hypothesis": session_state.get(
+"latest_hypothesis", ""
+),
+"finalized_text": session_state.get("committed_text", ""),
+"pending_text": session_state.get("mutable_text", ""),
+"transcript_text": get_live_stt_transcript_text(session_state),
+}
+),
+200,
+)
+except Exception as err:
+current_app.logger.error(
+f"Error transcribing live audio chunk: {err}", exc_info=True
+)
+return make_response(
+jsonify({"success": False, "message": "Failed to transcribe audio"}),
+400,
+)
+finally:
+if temp_path and temp_path.exists():
+temp_path.unlink()
+
+
+@attachments_ns.route("/stt/live/finish")
+class LiveSpeechToTextFinish(Resource):
+@api.doc(description="Finish a live speech-to-text session")
+def post(self):
+auth_user = _resolve_authenticated_user()
+if hasattr(auth_user, "status_code"):
+return auth_user
+if not auth_user:
+return make_response(
+jsonify({"success": False, "message": "Authentication required"}),
+401,
+)
+
+redis_client = _require_live_stt_redis()
+if hasattr(redis_client, "status_code"):
+return redis_client
+
+payload = request.get_json(silent=True) or {}
+session_id = str(payload.get("session_id", "")).strip()
+if not session_id:
+return make_response(
+jsonify({"success": False, "message": "Missing session_id"}),
+400,
+)
+
+session_state = load_live_stt_session(redis_client, session_id)
+if not session_state:
+return make_response(
+jsonify(
+{
+"success": False,
+"message": "Live transcription session not found",
+}
+),
+404,
+)
+
+if safe_filename(str(session_state.get("user", ""))) != auth_user:
+return make_response(
+jsonify({"success": False, "message": "Forbidden"}),
+403,
+)
+
+final_text = finalize_live_stt_session(session_state)
+delete_live_stt_session(redis_client, session_id)
+
+return make_response(
+jsonify(
+{
+"success": True,
+"session_id": session_id,
+"language": session_state.get("language"),
+"text": final_text,
+}
+),
+200,
+)
+
+
 @attachments_ns.route("/images/<path:image_path>")
 class ServeImage(Resource):
 @api.doc(description="Serve an image from storage")
 def get(self, image_path):
+if ".." in image_path or image_path.startswith("/") or "\x00" in image_path:
+return make_response(
+jsonify({"success": False, "message": "Invalid image path"}), 400
+)
 try:
+from application.api.user.base import storage
+
 file_obj = storage.get_file(image_path)
 extension = image_path.split(".")[-1].lower()
 content_type = f"image/{extension}"
@@ -157,6 +635,10 @@ class ServeImage(Resource):
 return make_response(
 jsonify({"success": False, "message": "Image not found"}), 404
 )
+except ValueError:
+return make_response(
+jsonify({"success": False, "message": "Invalid image path"}), 400
+)
 except Exception as e:
 current_app.logger.error(f"Error serving image: {e}")
 return make_response(
@@ -8,13 +8,15 @@ import uuid
 from functools import wraps
 from typing import Optional, Tuple
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, Response
-from pymongo import ReturnDocument
 from werkzeug.utils import secure_filename
 
-from application.core.mongo_db import MongoDB
+from sqlalchemy import text as _sql_text
+
 from application.core.settings import settings
+from application.storage.db.base_repository import looks_like_uuid, row_to_dict
+from application.storage.db.repositories.users import UsersRepository
+from application.storage.db.session import db_readonly, db_session
 from application.storage.storage_creator import StorageCreator
 from application.vectorstore.vector_creator import VectorCreator
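The row_to_dict helper imported here is used by resolve_tool_details further down but is not part of this excerpt. A plausible one-line sketch, assuming SQLAlchemy Core result rows (illustrative, not the repository's actual code):

def row_to_dict(row) -> dict:
    # SQLAlchemy Core rows expose their columns via ._mapping; copying it
    # yields a plain mutable dict, which is how the repositories hand rows
    # back to the route handlers in this changeset.
    return dict(row._mapping)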
@@ -22,32 +24,6 @@ from application.vectorstore.vector_creator import VectorCreator
 storage = StorageCreator.get_storage()
 
 
-mongo = MongoDB.get_client()
-db = mongo[settings.MONGO_DB_NAME]
-
-
-conversations_collection = db["conversations"]
-sources_collection = db["sources"]
-prompts_collection = db["prompts"]
-feedback_collection = db["feedback"]
-agents_collection = db["agents"]
-token_usage_collection = db["token_usage"]
-shared_conversations_collections = db["shared_conversations"]
-users_collection = db["users"]
-user_logs_collection = db["user_logs"]
-user_tools_collection = db["user_tools"]
-attachments_collection = db["attachments"]
-
-
-try:
-agents_collection.create_index(
-[("shared", 1)],
-name="shared_index",
-background=True,
-)
-users_collection.create_index("user_id", unique=True)
-except Exception as e:
-print("Error creating indexes:", e)
 current_dir = os.path.dirname(
 os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 )
@@ -79,58 +55,95 @@ def generate_date_range(start_date, end_date):
 
 def ensure_user_doc(user_id):
 """
-Ensure user document exists with proper agent preferences structure.
+Ensure a Postgres ``users`` row exists for ``user_id``.
+
+Returns the row as a dict with the shape legacy callers expect — in
+particular ``user_id`` and ``agent_preferences`` (with ``pinned`` and
+``shared_with_me`` list keys always present).
 
 Args:
 user_id: The user ID to ensure
 
 Returns:
-The user document
+The user document as a dict.
 """
-default_prefs = {
-"pinned": [],
-"shared_with_me": [],
-}
-
-user_doc = users_collection.find_one_and_update(
-{"user_id": user_id},
-{"$setOnInsert": {"agent_preferences": default_prefs}},
-upsert=True,
-return_document=ReturnDocument.AFTER,
-)
-
-prefs = user_doc.get("agent_preferences", {})
-updates = {}
-if "pinned" not in prefs:
-updates["agent_preferences.pinned"] = []
-if "shared_with_me" not in prefs:
-updates["agent_preferences.shared_with_me"] = []
-if updates:
-users_collection.update_one({"user_id": user_id}, {"$set": updates})
-user_doc = users_collection.find_one({"user_id": user_id})
+with db_session() as conn:
+user_doc = UsersRepository(conn).upsert(user_id)
+
+prefs = user_doc.get("agent_preferences") or {}
+if not isinstance(prefs, dict):
+prefs = {}
+prefs.setdefault("pinned", [])
+prefs.setdefault("shared_with_me", [])
+user_doc["agent_preferences"] = prefs
 return user_doc
 
 
 def resolve_tool_details(tool_ids):
 """
-Resolve tool IDs to their details.
+Resolve tool IDs to their display details.
+
+Accepts either Postgres UUIDs or legacy Mongo ObjectId strings (mixed
+lists are supported — each id is looked up via ``get_any``, which
+resolves to whichever column matches). Unknown ids are silently
+skipped.
 
 Args:
-tool_ids: List of tool IDs
+tool_ids: List of tool IDs (UUIDs or legacy Mongo ObjectId strings).
 
 Returns:
-List of tool details with id, name, and display_name
+List of tool details with ``id``, ``name``, and ``display_name``.
 """
-tools = user_tools_collection.find(
-{"_id": {"$in": [ObjectId(tid) for tid in tool_ids]}}
-)
+if not tool_ids:
+return []
+
+uuid_ids: list[str] = []
+legacy_ids: list[str] = []
+for tid in tool_ids:
+if not tid:
+continue
+tid_str = str(tid)
+if looks_like_uuid(tid_str):
+uuid_ids.append(tid_str)
+else:
+legacy_ids.append(tid_str)
+
+if not uuid_ids and not legacy_ids:
+return []
+
+rows: list[dict] = []
+with db_readonly() as conn:
+if uuid_ids:
+result = conn.execute(
+_sql_text(
+"SELECT * FROM user_tools "
+"WHERE id = ANY(CAST(:ids AS uuid[]))"
+),
+{"ids": uuid_ids},
+)
+rows.extend(row_to_dict(r) for r in result.fetchall())
+if legacy_ids:
+result = conn.execute(
+_sql_text(
+"SELECT * FROM user_tools "
+"WHERE legacy_mongo_id = ANY(:ids)"
+),
+{"ids": legacy_ids},
+)
+rows.extend(row_to_dict(r) for r in result.fetchall())
+
 return [
 {
-"id": str(tool["_id"]),
-"name": tool.get("name", ""),
-"display_name": tool.get("displayName", tool.get("name", "")),
+"id": str(tool.get("id") or tool.get("legacy_mongo_id") or ""),
+"name": tool.get("name", "") or "",
+"display_name": (
+tool.get("custom_name")
+or tool.get("display_name")
+or tool.get("name", "")
+or ""
+),
 }
-for tool in tools
+for tool in rows
 ]
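resolve_tool_details routes each incoming id through looks_like_uuid before deciding which column to query. That helper lives in application.storage.db.base_repository and is not shown in this excerpt; a plausible sketch, assuming a plain RFC 4122 shape check (illustrative only):

import re

_UUID_RE = re.compile(
    r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
    re.IGNORECASE,
)

def looks_like_uuid(value: str) -> bool:
    # Shape check only: keeps non-UUID strings (e.g. 24-character Mongo
    # ObjectIds) away from CAST(:id AS uuid), which would raise and abort
    # the surrounding transaction.
    return bool(_UUID_RE.match(value or ""))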
@@ -200,14 +213,15 @@ def require_agent(func):
 
 @wraps(func)
 def wrapper(*args, **kwargs):
+from application.storage.db.repositories.agents import AgentsRepository
+
 webhook_token = kwargs.get("webhook_token")
 if not webhook_token:
 return make_response(
 jsonify({"success": False, "message": "Webhook token missing"}), 400
 )
-agent = agents_collection.find_one(
-{"incoming_webhook_token": webhook_token}, {"_id": 1}
-)
+with db_readonly() as conn:
+agent = AgentsRepository(conn).find_by_webhook_token(webhook_token)
 if not agent:
 current_app.logger.warning(
 f"Webhook attempt with invalid token: {webhook_token}"
|
|||||||
jsonify({"success": False, "message": "Agent not found"}), 404
|
jsonify({"success": False, "message": "Agent not found"}), 404
|
||||||
)
|
)
|
||||||
kwargs["agent"] = agent
|
kwargs["agent"] = agent
|
||||||
kwargs["agent_id_str"] = str(agent["_id"])
|
kwargs["agent_id_str"] = str(agent["id"])
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -2,12 +2,13 @@
 
 import datetime
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
 
 from application.api import api
-from application.api.user.base import attachments_collection, conversations_collection
+from application.storage.db.repositories.attachments import AttachmentsRepository
+from application.storage.db.repositories.conversations import ConversationsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields
 
 conversations_ns = Namespace(
@@ -30,10 +31,13 @@ class DeleteConversation(Resource):
 return make_response(
 jsonify({"success": False, "message": "ID is required"}), 400
 )
+user_id = decoded_token["sub"]
 try:
-conversations_collection.delete_one(
-{"_id": ObjectId(conversation_id), "user": decoded_token["sub"]}
-)
+with db_session() as conn:
+repo = ConversationsRepository(conn)
+conv = repo.get_any(conversation_id, user_id)
+if conv is not None:
+repo.delete(str(conv["id"]), user_id)
 except Exception as err:
 current_app.logger.error(
 f"Error deleting conversation: {err}", exc_info=True
@@ -53,7 +57,8 @@ class DeleteAllConversations(Resource):
 return make_response(jsonify({"success": False}), 401)
 user_id = decoded_token.get("sub")
 try:
-conversations_collection.delete_many({"user": user_id})
+with db_session() as conn:
+ConversationsRepository(conn).delete_all_for_user(user_id)
 except Exception as err:
 current_app.logger.error(
 f"Error deleting all conversations: {err}", exc_info=True
@@ -71,26 +76,21 @@ class GetConversations(Resource):
 decoded_token = request.decoded_token
 if not decoded_token:
 return make_response(jsonify({"success": False}), 401)
+user_id = decoded_token.get("sub")
 try:
-conversations = (
-conversations_collection.find(
-{
-"$or": [
-{"api_key": {"$exists": False}},
-{"agent_id": {"$exists": True}},
-],
-"user": decoded_token.get("sub"),
-}
+with db_readonly() as conn:
+conversations = ConversationsRepository(conn).list_for_user(
+user_id, limit=30
 )
-.sort("date", -1)
-.limit(30)
-)
 
 list_conversations = [
 {
-"id": str(conversation["_id"]),
+"id": str(conversation["id"]),
 "name": conversation["name"],
-"agent_id": conversation.get("agent_id", None),
+"agent_id": (
+str(conversation["agent_id"])
+if conversation.get("agent_id")
+else None
+),
 "is_shared_usage": conversation.get("is_shared_usage", False),
 "shared_token": conversation.get("shared_token", None),
 }
@@ -119,38 +119,67 @@ class GetSingleConversation(Resource):
 return make_response(
 jsonify({"success": False, "message": "ID is required"}), 400
 )
+user_id = decoded_token.get("sub")
 try:
-conversation = conversations_collection.find_one(
-{"_id": ObjectId(conversation_id), "user": decoded_token.get("sub")}
-)
+with db_readonly() as conn:
+repo = ConversationsRepository(conn)
+conversation = repo.get_any(conversation_id, user_id)
 if not conversation:
 return make_response(jsonify({"status": "not found"}), 404)
-# Process queries to include attachment names
-queries = conversation["queries"]
-for query in queries:
-if "attachments" in query and query["attachments"]:
-attachment_details = []
-for attachment_id in query["attachments"]:
-try:
-attachment = attachments_collection.find_one(
-{"_id": ObjectId(attachment_id)}
-)
-if attachment:
-attachment_details.append(
-{
-"id": str(attachment["_id"]),
-"fileName": attachment.get(
-"filename", "Unknown file"
-),
-}
+conv_pg_id = str(conversation["id"])
+messages = repo.get_messages(conv_pg_id)
+
+# Resolve attachment details (id, fileName) for each message.
+attachments_repo = AttachmentsRepository(conn)
+queries = []
+for msg in messages:
+query = {
+"prompt": msg.get("prompt"),
+"response": msg.get("response"),
+"thought": msg.get("thought"),
+"sources": msg.get("sources") or [],
+"tool_calls": msg.get("tool_calls") or [],
+"timestamp": msg.get("timestamp"),
+"model_id": msg.get("model_id"),
+}
+if msg.get("metadata"):
+query["metadata"] = msg["metadata"]
+# Feedback on conversation_messages is a JSONB blob with
+# shape {"text": <str>, "timestamp": <iso>}. The legacy
+# frontend consumed a flat scalar feedback string, so
+# unwrap the ``text`` field for compat.
+feedback = msg.get("feedback")
+if feedback is not None:
+if isinstance(feedback, dict):
+query["feedback"] = feedback.get("text")
+if feedback.get("timestamp"):
+query["feedback_timestamp"] = feedback["timestamp"]
+else:
+query["feedback"] = feedback
+attachments = msg.get("attachments") or []
+if attachments:
+attachment_details = []
+for attachment_id in attachments:
+try:
+att = attachments_repo.get_any(
+str(attachment_id), user_id
 )
-except Exception as e:
-current_app.logger.error(
-f"Error retrieving attachment {attachment_id}: {e}",
-exc_info=True,
-)
-query["attachments"] = attachment_details
+if att:
+attachment_details.append(
+{
+"id": str(att["id"]),
+"fileName": att.get(
+"filename", "Unknown file"
+),
+}
+)
+except Exception as e:
+current_app.logger.error(
+f"Error retrieving attachment {attachment_id}: {e}",
+exc_info=True,
+)
+query["attachments"] = attachment_details
+queries.append(query)
 except Exception as err:
 current_app.logger.error(
 f"Error retrieving conversation: {err}", exc_info=True
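As a concrete illustration of the mapping above, one conversation_messages row whose feedback is the dict shape would come back to the client roughly as follows (all field values are invented for illustration):

{
    "prompt": "What is DocsGPT?",
    "response": "DocsGPT is an open-source documentation assistant.",
    "thought": None,
    "sources": [],
    "tool_calls": [],
    "timestamp": "2024-05-01T12:00:00+00:00",
    "model_id": None,
    "feedback": "like",
    "feedback_timestamp": "2024-05-01T12:05:00+00:00",
    "attachments": [
        {"id": "11111111-1111-1111-1111-111111111111", "fileName": "notes.pdf"}
    ],
}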
@@ -158,7 +187,9 @@ class GetSingleConversation(Resource):
 return make_response(jsonify({"success": False}), 400)
 data = {
 "queries": queries,
-"agent_id": conversation.get("agent_id"),
+"agent_id": (
+str(conversation["agent_id"]) if conversation.get("agent_id") else None
+),
 "is_shared_usage": conversation.get("is_shared_usage", False),
 "shared_token": conversation.get("shared_token", None),
 }
@@ -190,11 +221,13 @@ class UpdateConversationName(Resource):
 missing_fields = check_required_fields(data, required_fields)
 if missing_fields:
 return missing_fields
+user_id = decoded_token.get("sub")
 try:
-conversations_collection.update_one(
-{"_id": ObjectId(data["id"]), "user": decoded_token.get("sub")},
-{"$set": {"name": data["name"]}},
-)
+with db_session() as conn:
+repo = ConversationsRepository(conn)
+conv = repo.get_any(data["id"], user_id)
+if conv is not None:
+repo.rename(str(conv["id"]), user_id, data["name"])
 except Exception as err:
 current_app.logger.error(
 f"Error updating conversation name: {err}", exc_info=True
@@ -237,43 +270,33 @@ class SubmitFeedback(Resource):
 missing_fields = check_required_fields(data, required_fields)
 if missing_fields:
 return missing_fields
+user_id = decoded_token.get("sub")
+feedback_value = data["feedback"]
+question_index = int(data["question_index"])
+# Normalize string feedback to lowercase so analytics queries
+# (which match 'like'/'dislike') count rows correctly. Tolerate
+# legacy uppercase clients on ingest. Non-string values pass through.
+if isinstance(feedback_value, str):
+feedback_value = feedback_value.lower()
+feedback_payload = (
+None
+if feedback_value is None
+else {
+"text": feedback_value,
+"timestamp": datetime.datetime.now(
+datetime.timezone.utc
+).isoformat(),
+}
+)
 try:
-if data["feedback"] is None:
-# Remove feedback and feedback_timestamp if feedback is null
-conversations_collection.update_one(
-{
-"_id": ObjectId(data["conversation_id"]),
-"user": decoded_token.get("sub"),
-f"queries.{data['question_index']}": {"$exists": True},
-},
-{
-"$unset": {
-f"queries.{data['question_index']}.feedback": "",
-f"queries.{data['question_index']}.feedback_timestamp": "",
-}
-},
-)
-else:
-# Set feedback and feedback_timestamp if feedback has a value
-
-conversations_collection.update_one(
-{
-"_id": ObjectId(data["conversation_id"]),
-"user": decoded_token.get("sub"),
-f"queries.{data['question_index']}": {"$exists": True},
-},
-{
-"$set": {
-f"queries.{data['question_index']}.feedback": data[
-"feedback"
-],
-f"queries.{data['question_index']}.feedback_timestamp": datetime.datetime.now(
-datetime.timezone.utc
-),
-}
-},
-)
+with db_session() as conn:
+repo = ConversationsRepository(conn)
+conv = repo.get_any(data["conversation_id"], user_id)
+if conv is None:
+return make_response(
+jsonify({"success": False, "message": "Not found"}), 404
+)
+repo.set_feedback(str(conv["id"]), question_index, feedback_payload)
 except Exception as err:
 current_app.logger.error(f"Error submitting feedback: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
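For reference, the normalization above turns the legacy scalar feedback value into the JSONB blob that GetSingleConversation later unwraps. A small illustrative walk-through of the three shapes, with invented values:

# Incoming request body (legacy client, uppercase tolerated on ingest):
data = {"conversation_id": "11111111-1111-1111-1111-111111111111",
        "question_index": 0, "feedback": "LIKE"}

# What SubmitFeedback stores on the message row (feedback_payload):
{"text": "like", "timestamp": "2024-05-01T12:00:00+00:00"}

# What GetSingleConversation hands back to the frontend after unwrapping:
{"feedback": "like", "feedback_timestamp": "2024-05-01T12:00:00+00:00"}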
@@ -2,12 +2,13 @@
 
 import os
 
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource
 
 from application.api import api
-from application.api.user.base import current_dir, prompts_collection
+from application.api.user.base import current_dir
+from application.storage.db.repositories.prompts import PromptsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields
 
 prompts_ns = Namespace(
@@ -40,15 +41,9 @@ class CreatePrompt(Resource):
 return missing_fields
 user = decoded_token.get("sub")
 try:
-resp = prompts_collection.insert_one(
-{
-"name": data["name"],
-"content": data["content"],
-"user": user,
-}
-)
-new_id = str(resp.inserted_id)
+with db_session() as conn:
+prompt = PromptsRepository(conn).create(user, data["name"], data["content"])
+new_id = str(prompt["id"])
 except Exception as err:
 current_app.logger.error(f"Error creating prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -64,17 +59,17 @@ class GetPrompts(Resource):
 return make_response(jsonify({"success": False}), 401)
 user = decoded_token.get("sub")
 try:
-prompts = prompts_collection.find({"user": user})
+with db_readonly() as conn:
+prompts = PromptsRepository(conn).list_for_user(user)
 list_prompts = [
 {"id": "default", "name": "default", "type": "public"},
 {"id": "creative", "name": "creative", "type": "public"},
 {"id": "strict", "name": "strict", "type": "public"},
 ]
 
 for prompt in prompts:
 list_prompts.append(
 {
-"id": str(prompt["_id"]),
+"id": str(prompt["id"]),
 "name": prompt["name"],
 "type": "private",
 }
@@ -119,9 +114,12 @@ class GetSinglePrompt(Resource):
 ) as f:
 chat_reduce_strict = f.read()
 return make_response(jsonify({"content": chat_reduce_strict}), 200)
-prompt = prompts_collection.find_one(
-{"_id": ObjectId(prompt_id), "user": user}
-)
+with db_readonly() as conn:
+prompt = PromptsRepository(conn).get_any(prompt_id, user)
+if not prompt:
+return make_response(
+jsonify({"success": False, "message": "Prompt not found"}), 404
+)
 except Exception as err:
 current_app.logger.error(f"Error retrieving prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -148,7 +146,15 @@ class DeletePrompt(Resource):
 if missing_fields:
 return missing_fields
 try:
-prompts_collection.delete_one({"_id": ObjectId(data["id"]), "user": user})
+with db_session() as conn:
+repo = PromptsRepository(conn)
+prompt = repo.get_any(data["id"], user)
+if not prompt:
+return make_response(
+jsonify({"success": False, "message": "Prompt not found"}),
+404,
+)
+repo.delete(str(prompt["id"]), user)
 except Exception as err:
 current_app.logger.error(f"Error deleting prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -181,10 +187,15 @@ class UpdatePrompt(Resource):
 if missing_fields:
 return missing_fields
 try:
-prompts_collection.update_one(
-{"_id": ObjectId(data["id"]), "user": user},
-{"$set": {"name": data["name"], "content": data["content"]}},
-)
+with db_session() as conn:
+repo = PromptsRepository(conn)
+prompt = repo.get_any(data["id"], user)
+if not prompt:
+return make_response(
+jsonify({"success": False, "message": "Prompt not found"}),
+404,
+)
+repo.update(str(prompt["id"]), user, data["name"], data["content"])
 except Exception as err:
 current_app.logger.error(f"Error updating prompt: {err}", exc_info=True)
 return make_response(jsonify({"success": False}), 400)
@@ -5,8 +5,7 @@ Main user API routes - registers all namespace modules.
 from flask import Blueprint
 
 from application.api import api
-from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns
-
+from .agents import agents_ns, agents_sharing_ns, agents_webhooks_ns, agents_folders_ns
 from .analytics import analytics_ns
 from .attachments import attachments_ns
 from .conversations import conversations_ns
@@ -15,6 +14,7 @@ from .prompts import prompts_ns
 from .sharing import sharing_ns
 from .sources import sources_chunks_ns, sources_ns, sources_upload_ns
 from .tools import tools_mcp_ns, tools_ns
+from .workflows import workflows_ns
 
 
 user = Blueprint("user", __name__)
@@ -31,10 +31,11 @@ api.add_namespace(conversations_ns)
 # Models
 api.add_namespace(models_ns)
 
-# Agents (main, sharing, webhooks)
+# Agents (main, sharing, webhooks, folders)
 api.add_namespace(agents_ns)
 api.add_namespace(agents_sharing_ns)
 api.add_namespace(agents_webhooks_ns)
+api.add_namespace(agents_folders_ns)
 
 # Prompts
 api.add_namespace(prompts_ns)
@@ -50,3 +51,6 @@ api.add_namespace(sources_upload_ns)
 # Tools (main, MCP)
 api.add_namespace(tools_ns)
 api.add_namespace(tools_mcp_ns)
+
+# Workflows
+api.add_namespace(workflows_ns)
@@ -2,26 +2,126 @@
 
 import uuid
 
-from bson.binary import Binary, UuidRepresentation
-from bson.dbref import DBRef
-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, inputs, Namespace, Resource
+from sqlalchemy import text as _sql_text
 
 from application.api import api
-from application.api.user.base import (
-agents_collection,
-attachments_collection,
-conversations_collection,
-shared_conversations_collections,
+from application.storage.db.base_repository import looks_like_uuid
+from application.storage.db.repositories.agents import AgentsRepository
+from application.storage.db.repositories.attachments import AttachmentsRepository
+from application.storage.db.repositories.conversations import ConversationsRepository
+from application.storage.db.repositories.shared_conversations import (
+SharedConversationsRepository,
 )
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields
 
 
 sharing_ns = Namespace(
 "sharing", description="Conversation sharing operations", path="/api"
 )
 
 
+def _resolve_prompt_pg_id(conn, prompt_id_raw, user_id):
+"""Translate an incoming prompt id (UUID or legacy Mongo ObjectId) to a PG UUID.
+
+Scoped by ``user_id`` so a caller can't link another user's prompt
+into their share record. Returns ``None`` for sentinel values
+(``"default"``) or unresolved ids.
+"""
+if not prompt_id_raw or prompt_id_raw == "default":
+return None
+value = str(prompt_id_raw)
+# Already UUID — trust it but still require ownership. A shape-gate
+# (rather than a loose ``len == 36 and '-' in value`` check) keeps
+# non-UUID input out of ``CAST(:pid AS uuid)``; the cast would raise
+# and poison the readonly transaction otherwise.
+if looks_like_uuid(value):
+row = conn.execute(
+_sql_text(
+"SELECT id FROM prompts WHERE id = CAST(:pid AS uuid) "
+"AND user_id = :uid"
+),
+{"pid": value, "uid": user_id},
+).fetchone()
+return str(row[0]) if row else None
+# Legacy Mongo ObjectId fallback.
+row = conn.execute(
+_sql_text(
+"SELECT id FROM prompts WHERE legacy_mongo_id = :pid "
+"AND user_id = :uid"
+),
+{"pid": value, "uid": user_id},
+).fetchone()
+return str(row[0]) if row else None
+
+
+def _resolve_source_pg_id(conn, source_raw):
+"""Translate a source id (UUID or legacy Mongo ObjectId) to a PG UUID."""
+if not source_raw:
+return None
+value = str(source_raw)
+# See ``_resolve_prompt_pg_id`` for the shape-gate rationale.
+if looks_like_uuid(value):
+row = conn.execute(
+_sql_text(
+"SELECT id FROM sources WHERE id = CAST(:sid AS uuid)"
+),
+{"sid": value},
+).fetchone()
+return str(row[0]) if row else None
+row = conn.execute(
+_sql_text("SELECT id FROM sources WHERE legacy_mongo_id = :sid"),
+{"sid": value},
+).fetchone()
+return str(row[0]) if row else None
+
+
+def _find_reusable_share_agent(
+conn, user_id, *, prompt_pg_id, chunks, source_pg_id, retriever,
+):
+"""Find an existing share-as-agent key row matching these parameters.
+
+Mirrors the legacy Mongo ``agents_collection.find_one`` pre-existence
+check. Used to reuse an api key across repeated shares of the same
+conversation with the same prompt/chunks/source/retriever.
+"""
+clauses = ["user_id = :uid", "key IS NOT NULL"]
+params: dict = {"uid": user_id}
+if prompt_pg_id is None:
+clauses.append("prompt_id IS NULL")
+else:
+clauses.append("prompt_id = CAST(:pid AS uuid)")
+params["pid"] = prompt_pg_id
+if chunks is None:
+clauses.append("chunks IS NULL")
+else:
+clauses.append("chunks = :chunks")
+params["chunks"] = int(chunks)
+if source_pg_id is None:
+clauses.append("source_id IS NULL")
+else:
+clauses.append("source_id = CAST(:sid AS uuid)")
+params["sid"] = source_pg_id
+if retriever is None:
+clauses.append("retriever IS NULL")
+else:
+clauses.append("retriever = :retr")
+params["retr"] = retriever
+sql = (
+"SELECT * FROM agents WHERE "
++ " AND ".join(clauses)
++ " LIMIT 1"
+)
+row = conn.execute(_sql_text(sql), params).fetchone()
+if row is None:
+return None
+mapping = dict(row._mapping)
|
||||||
|
mapping["id"] = str(mapping["id"]) if mapping.get("id") else None
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
@sharing_ns.route("/share")
|
@sharing_ns.route("/share")
|
||||||
class ShareConversation(Resource):
|
class ShareConversation(Resource):
|
||||||
share_conversation_model = api.model(
|
share_conversation_model = api.model(
|
||||||
@@ -56,146 +156,94 @@ class ShareConversation(Resource):
|
|||||||
conversation_id = data["conversation_id"]
|
conversation_id = data["conversation_id"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conversation = conversations_collection.find_one(
|
with db_session() as conn:
|
||||||
{"_id": ObjectId(conversation_id)}
|
conv_repo = ConversationsRepository(conn)
|
||||||
)
|
shared_repo = SharedConversationsRepository(conn)
|
||||||
if conversation is None:
|
agents_repo = AgentsRepository(conn)
|
||||||
return make_response(
|
|
||||||
jsonify(
|
|
||||||
{
|
|
||||||
"status": "error",
|
|
||||||
"message": "Conversation does not exist",
|
|
||||||
}
|
|
||||||
),
|
|
||||||
404,
|
|
||||||
)
|
|
||||||
current_n_queries = len(conversation["queries"])
|
|
||||||
explicit_binary = Binary.from_uuid(
|
|
||||||
uuid.uuid4(), UuidRepresentation.STANDARD
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_promptable:
|
conversation = conv_repo.get_any(conversation_id, user)
|
||||||
prompt_id = data.get("prompt_id", "default")
|
if conversation is None:
|
||||||
chunks = data.get("chunks", "2")
|
return make_response(
|
||||||
|
jsonify(
|
||||||
name = conversation["name"] + "(shared)"
|
|
||||||
new_api_key_data = {
|
|
||||||
"prompt_id": prompt_id,
|
|
||||||
"chunks": chunks,
|
|
||||||
"user": user,
|
|
||||||
}
|
|
||||||
|
|
||||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
|
||||||
new_api_key_data["source"] = DBRef(
|
|
||||||
"sources", ObjectId(data["source"])
|
|
||||||
)
|
|
||||||
if "retriever" in data:
|
|
||||||
new_api_key_data["retriever"] = data["retriever"]
|
|
||||||
pre_existing_api_document = agents_collection.find_one(new_api_key_data)
|
|
||||||
if pre_existing_api_document:
|
|
||||||
api_uuid = pre_existing_api_document["key"]
|
|
||||||
pre_existing = shared_conversations_collections.find_one(
|
|
||||||
{
|
|
||||||
"conversation_id": ObjectId(conversation_id),
|
|
||||||
"isPromptable": is_promptable,
|
|
||||||
"first_n_queries": current_n_queries,
|
|
||||||
"user": user,
|
|
||||||
"api_key": api_uuid,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if pre_existing is not None:
|
|
||||||
return make_response(
|
|
||||||
jsonify(
|
|
||||||
{
|
|
||||||
"success": True,
|
|
||||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
200,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
shared_conversations_collections.insert_one(
|
|
||||||
{
|
{
|
||||||
"uuid": explicit_binary,
|
"status": "error",
|
||||||
"conversation_id": ObjectId(conversation_id),
|
"message": "Conversation does not exist",
|
||||||
"isPromptable": is_promptable,
|
|
||||||
"first_n_queries": current_n_queries,
|
|
||||||
"user": user,
|
|
||||||
"api_key": api_uuid,
|
|
||||||
}
|
}
|
||||||
)
|
),
|
||||||
return make_response(
|
404,
|
||||||
jsonify(
|
)
|
||||||
{
|
conv_pg_id = str(conversation["id"])
|
||||||
"success": True,
|
current_n_queries = conv_repo.message_count(conv_pg_id)
|
||||||
"identifier": str(explicit_binary.as_uuid()),
|
|
||||||
}
|
|
||||||
),
|
|
||||||
201,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
api_uuid = str(uuid.uuid4())
|
|
||||||
new_api_key_data["key"] = api_uuid
|
|
||||||
new_api_key_data["name"] = name
|
|
||||||
|
|
||||||
if "source" in data and ObjectId.is_valid(data["source"]):
|
if is_promptable:
|
||||||
new_api_key_data["source"] = DBRef(
|
prompt_id_raw = data.get("prompt_id", "default")
|
||||||
"sources", ObjectId(data["source"])
|
chunks_raw = data.get("chunks", "2")
|
||||||
|
try:
|
||||||
|
chunks_int = int(chunks_raw) if chunks_raw not in (None, "") else None
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
chunks_int = None
|
||||||
|
|
||||||
|
prompt_pg_id = _resolve_prompt_pg_id(conn, prompt_id_raw, user)
|
||||||
|
source_pg_id = _resolve_source_pg_id(conn, data.get("source"))
|
||||||
|
retriever = data.get("retriever")
|
||||||
|
|
||||||
|
reusable = _find_reusable_share_agent(
|
||||||
|
conn, user,
|
||||||
|
prompt_pg_id=prompt_pg_id,
|
||||||
|
chunks=chunks_int,
|
||||||
|
source_pg_id=source_pg_id,
|
||||||
|
retriever=retriever,
|
||||||
|
)
|
||||||
|
if reusable:
|
||||||
|
api_uuid = reusable.get("key")
|
||||||
|
else:
|
||||||
|
api_uuid = str(uuid.uuid4())
|
||||||
|
name = (conversation.get("name") or "") + "(shared)"
|
||||||
|
agents_repo.create(
|
||||||
|
user,
|
||||||
|
name,
|
||||||
|
"published",
|
||||||
|
key=api_uuid,
|
||||||
|
retriever=retriever,
|
||||||
|
chunks=chunks_int,
|
||||||
|
prompt_id=prompt_pg_id,
|
||||||
|
source_id=source_pg_id,
|
||||||
)
|
)
|
||||||
if "retriever" in data:
|
|
||||||
new_api_key_data["retriever"] = data["retriever"]
|
share = shared_repo.get_or_create(
|
||||||
agents_collection.insert_one(new_api_key_data)
|
conv_pg_id,
|
||||||
shared_conversations_collections.insert_one(
|
user,
|
||||||
{
|
is_promptable=True,
|
||||||
"uuid": explicit_binary,
|
first_n_queries=current_n_queries,
|
||||||
"conversation_id": ObjectId(conversation_id),
|
api_key=api_uuid,
|
||||||
"isPromptable": is_promptable,
|
prompt_id=prompt_pg_id,
|
||||||
"first_n_queries": current_n_queries,
|
chunks=chunks_int,
|
||||||
"user": user,
|
|
||||||
"api_key": api_uuid,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"success": True,
|
"success": True,
|
||||||
"identifier": str(explicit_binary.as_uuid()),
|
"identifier": str(share["uuid"]),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
201,
|
201 if reusable is None else 200,
|
||||||
)
|
)
|
||||||
pre_existing = shared_conversations_collections.find_one(
|
|
||||||
{
|
# Non-promptable share path.
|
||||||
"conversation_id": ObjectId(conversation_id),
|
share = shared_repo.get_or_create(
|
||||||
"isPromptable": is_promptable,
|
conv_pg_id,
|
||||||
"first_n_queries": current_n_queries,
|
user,
|
||||||
"user": user,
|
is_promptable=False,
|
||||||
}
|
first_n_queries=current_n_queries,
|
||||||
)
|
api_key=None,
|
||||||
if pre_existing is not None:
|
)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"success": True,
|
"success": True,
|
||||||
"identifier": str(pre_existing["uuid"].as_uuid()),
|
"identifier": str(share["uuid"]),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
200,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
shared_conversations_collections.insert_one(
|
|
||||||
{
|
|
||||||
"uuid": explicit_binary,
|
|
||||||
"conversation_id": ObjectId(conversation_id),
|
|
||||||
"isPromptable": is_promptable,
|
|
||||||
"first_n_queries": current_n_queries,
|
|
||||||
"user": user,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return make_response(
|
|
||||||
jsonify(
|
|
||||||
{"success": True, "identifier": str(explicit_binary.as_uuid())}
|
|
||||||
),
|
|
||||||
201,
|
201,
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@@ -210,22 +258,13 @@ class GetPubliclySharedConversations(Resource):
|
|||||||
@api.doc(description="Get publicly shared conversations by identifier")
|
@api.doc(description="Get publicly shared conversations by identifier")
|
||||||
def get(self, identifier: str):
|
def get(self, identifier: str):
|
||||||
try:
|
try:
|
||||||
query_uuid = Binary.from_uuid(
|
with db_readonly() as conn:
|
||||||
uuid.UUID(identifier), UuidRepresentation.STANDARD
|
shared_repo = SharedConversationsRepository(conn)
|
||||||
)
|
conv_repo = ConversationsRepository(conn)
|
||||||
shared = shared_conversations_collections.find_one({"uuid": query_uuid})
|
attach_repo = AttachmentsRepository(conn)
|
||||||
conversation_queries = []
|
|
||||||
|
|
||||||
if (
|
shared = shared_repo.find_by_uuid(identifier)
|
||||||
shared
|
if not shared or not shared.get("conversation_id"):
|
||||||
and "conversation_id" in shared
|
|
||||||
):
|
|
||||||
# conversation_id is now stored as an ObjectId, not a DBRef
|
|
||||||
conversation_id = shared["conversation_id"]
|
|
||||||
conversation = conversations_collection.find_one(
|
|
||||||
{"_id": conversation_id}
|
|
||||||
)
|
|
||||||
if conversation is None:
|
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
@@ -235,22 +274,60 @@ class GetPubliclySharedConversations(Resource):
|
|||||||
),
|
),
|
||||||
404,
|
404,
|
||||||
)
|
)
|
||||||
conversation_queries = conversation["queries"][
|
conv_pg_id = str(shared["conversation_id"])
|
||||||
: (shared["first_n_queries"])
|
owner_user = shared.get("user_id")
|
||||||
]
|
|
||||||
|
|
||||||
for query in conversation_queries:
|
conversation = conv_repo.get_owned(conv_pg_id, owner_user) if owner_user else None
|
||||||
if "attachments" in query and query["attachments"]:
|
if conversation is None:
|
||||||
|
# Fall back to any-user lookup in case shared row's
|
||||||
|
# user_id is missing — still keyed by PG UUID.
|
||||||
|
row = conn.execute(
|
||||||
|
_sql_text(
|
||||||
|
"SELECT * FROM conversations WHERE id = CAST(:id AS uuid)"
|
||||||
|
),
|
||||||
|
{"id": conv_pg_id},
|
||||||
|
).fetchone()
|
||||||
|
if row is None:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "might have broken url or the conversation does not exist",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
conversation = dict(row._mapping)
|
||||||
|
|
||||||
|
messages = conv_repo.get_messages(conv_pg_id)
|
||||||
|
first_n = shared.get("first_n_queries") or 0
|
||||||
|
conversation_queries = []
|
||||||
|
for msg in messages[:first_n]:
|
||||||
|
query = {
|
||||||
|
"prompt": msg.get("prompt"),
|
||||||
|
"response": msg.get("response"),
|
||||||
|
"thought": msg.get("thought"),
|
||||||
|
"sources": msg.get("sources") or [],
|
||||||
|
"tool_calls": msg.get("tool_calls") or [],
|
||||||
|
"timestamp": (
|
||||||
|
msg["timestamp"].isoformat()
|
||||||
|
if hasattr(msg.get("timestamp"), "isoformat")
|
||||||
|
else msg.get("timestamp")
|
||||||
|
),
|
||||||
|
"feedback": msg.get("feedback"),
|
||||||
|
}
|
||||||
|
attachments = msg.get("attachments") or []
|
||||||
|
if attachments:
|
||||||
attachment_details = []
|
attachment_details = []
|
||||||
for attachment_id in query["attachments"]:
|
for attachment_id in attachments:
|
||||||
try:
|
try:
|
||||||
attachment = attachments_collection.find_one(
|
attachment = attach_repo.get_any(
|
||||||
{"_id": ObjectId(attachment_id)}
|
str(attachment_id), owner_user,
|
||||||
)
|
) if owner_user else None
|
||||||
if attachment:
|
if attachment:
|
||||||
attachment_details.append(
|
attachment_details.append(
|
||||||
{
|
{
|
||||||
"id": str(attachment["_id"]),
|
"id": str(attachment["id"]),
|
||||||
"fileName": attachment.get(
|
"fileName": attachment.get(
|
||||||
"filename", "Unknown file"
|
"filename", "Unknown file"
|
||||||
),
|
),
|
||||||
@@ -262,26 +339,23 @@ class GetPubliclySharedConversations(Resource):
|
|||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
query["attachments"] = attachment_details
|
query["attachments"] = attachment_details
|
||||||
else:
|
conversation_queries.append(query)
|
||||||
return make_response(
|
|
||||||
jsonify(
|
created = conversation.get("created_at") or conversation.get("date")
|
||||||
{
|
date_iso = (
|
||||||
"success": False,
|
created.isoformat()
|
||||||
"error": "might have broken url or the conversation does not exist",
|
if hasattr(created, "isoformat")
|
||||||
}
|
else (str(created) if created is not None else None)
|
||||||
),
|
|
||||||
404,
|
|
||||||
)
|
)
|
||||||
date = conversation["_id"].generation_time.isoformat()
|
res = {
|
||||||
res = {
|
"success": True,
|
||||||
"success": True,
|
"queries": conversation_queries,
|
||||||
"queries": conversation_queries,
|
"title": conversation.get("name"),
|
||||||
"title": conversation["name"],
|
"timestamp": date_iso,
|
||||||
"timestamp": date,
|
}
|
||||||
}
|
if shared.get("is_promptable") and shared.get("api_key"):
|
||||||
if shared["isPromptable"] and "api_key" in shared:
|
res["api_key"] = shared["api_key"]
|
||||||
res["api_key"] = shared["api_key"]
|
return make_response(jsonify(res), 200)
|
||||||
return make_response(jsonify(res), 200)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error getting shared conversation: {err}", exc_info=True
|
f"Error getting shared conversation: {err}", exc_info=True
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""Source document management chunk management."""
|
"""Source document management chunk management."""
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import current_app, jsonify, make_response, request
|
from flask import current_app, jsonify, make_response, request
|
||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import get_vector_store, sources_collection
|
from application.api.user.base import get_vector_store
|
||||||
|
from application.storage.db.repositories.sources import SourcesRepository
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
from application.utils import check_required_fields, num_tokens_from_string
|
from application.utils import check_required_fields, num_tokens_from_string
|
||||||
|
|
||||||
sources_chunks_ns = Namespace(
|
sources_chunks_ns = Namespace(
|
||||||
@@ -13,6 +14,15 @@ sources_chunks_ns = Namespace(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_source(doc_id: str, user: str):
|
||||||
|
"""Resolve a source (UUID or legacy ObjectId) for the caller.
|
||||||
|
|
||||||
|
Returns the row dict (with PG UUID in ``id``) or ``None`` if missing.
|
||||||
|
"""
|
||||||
|
with db_readonly() as conn:
|
||||||
|
return SourcesRepository(conn).get_any(doc_id, user)
|
||||||
|
|
||||||
|
|
||||||
@sources_chunks_ns.route("/get_chunks")
|
@sources_chunks_ns.route("/get_chunks")
|
||||||
class GetChunks(Resource):
|
class GetChunks(Resource):
|
||||||
@api.doc(
|
@api.doc(
|
||||||
@@ -36,31 +46,34 @@ class GetChunks(Resource):
|
|||||||
path = request.args.get("path")
|
path = request.args.get("path")
|
||||||
search_term = request.args.get("search", "").strip().lower()
|
search_term = request.args.get("search", "").strip().lower()
|
||||||
|
|
||||||
if not ObjectId.is_valid(doc_id):
|
if not doc_id:
|
||||||
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
|
try:
|
||||||
|
doc = _resolve_source(doc_id, user)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"error": "Document not found or access denied"}), 404
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
)
|
)
|
||||||
|
resolved_id = str(doc["id"])
|
||||||
try:
|
try:
|
||||||
store = get_vector_store(doc_id)
|
store = get_vector_store(resolved_id)
|
||||||
chunks = store.get_chunks()
|
chunks = store.get_chunks()
|
||||||
|
|
||||||
filtered_chunks = []
|
filtered_chunks = []
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
metadata = chunk.get("metadata", {})
|
metadata = chunk.get("metadata", {})
|
||||||
|
|
||||||
# Filter by path if provided
|
|
||||||
|
|
||||||
if path:
|
if path:
|
||||||
chunk_source = metadata.get("source", "")
|
chunk_source = metadata.get("source", "")
|
||||||
# Check if the chunk's source matches the requested path
|
chunk_file_path = metadata.get("file_path", "")
|
||||||
|
source_match = chunk_source and chunk_source.endswith(path)
|
||||||
|
file_path_match = chunk_file_path and chunk_file_path.endswith(path)
|
||||||
|
|
||||||
if not chunk_source or not chunk_source.endswith(path):
|
if not (source_match or file_path_match):
|
||||||
continue
|
continue
|
||||||
# Filter by search term if provided
|
|
||||||
|
|
||||||
if search_term:
|
if search_term:
|
||||||
text_match = search_term in chunk.get("text", "").lower()
|
text_match = search_term in chunk.get("text", "").lower()
|
||||||
title_match = search_term in metadata.get("title", "").lower()
|
title_match = search_term in metadata.get("title", "").lower()
|
||||||
@@ -127,15 +140,17 @@ class AddChunk(Resource):
|
|||||||
token_count = num_tokens_from_string(text)
|
token_count = num_tokens_from_string(text)
|
||||||
metadata["token_count"] = token_count
|
metadata["token_count"] = token_count
|
||||||
|
|
||||||
if not ObjectId.is_valid(doc_id):
|
try:
|
||||||
|
doc = _resolve_source(doc_id, user)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"error": "Document not found or access denied"}), 404
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
store = get_vector_store(doc_id)
|
store = get_vector_store(str(doc["id"]))
|
||||||
chunk_id = store.add_chunk(text, metadata)
|
chunk_id = store.add_chunk(text, metadata)
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
jsonify({"message": "Chunk added successfully", "chunk_id": chunk_id}),
|
||||||
@@ -160,15 +175,17 @@ class DeleteChunk(Resource):
|
|||||||
doc_id = request.args.get("id")
|
doc_id = request.args.get("id")
|
||||||
chunk_id = request.args.get("chunk_id")
|
chunk_id = request.args.get("chunk_id")
|
||||||
|
|
||||||
if not ObjectId.is_valid(doc_id):
|
try:
|
||||||
|
doc = _resolve_source(doc_id, user)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"error": "Document not found or access denied"}), 404
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
store = get_vector_store(doc_id)
|
store = get_vector_store(str(doc["id"]))
|
||||||
deleted = store.delete_chunk(chunk_id)
|
deleted = store.delete_chunk(chunk_id)
|
||||||
if deleted:
|
if deleted:
|
||||||
return make_response(
|
return make_response(
|
||||||
@@ -227,15 +244,17 @@ class UpdateChunk(Resource):
|
|||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
metadata["token_count"] = token_count
|
metadata["token_count"] = token_count
|
||||||
if not ObjectId.is_valid(doc_id):
|
try:
|
||||||
|
doc = _resolve_source(doc_id, user)
|
||||||
|
except Exception as e:
|
||||||
|
current_app.logger.error(f"Error resolving source: {e}", exc_info=True)
|
||||||
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
return make_response(jsonify({"error": "Invalid doc_id"}), 400)
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"error": "Document not found or access denied"}), 404
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
store = get_vector_store(doc_id)
|
store = get_vector_store(str(doc["id"]))
|
||||||
|
|
||||||
chunks = store.get_chunks()
|
chunks = store.get_chunks()
|
||||||
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
existing_chunk = next((c for c in chunks if c["doc_id"] == chunk_id), None)
|
||||||
|
|||||||
@@ -3,13 +3,14 @@
|
|||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import current_app, jsonify, make_response, redirect, request
|
from flask import current_app, jsonify, make_response, redirect, request
|
||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import sources_collection
|
from application.api.user.tasks import sync_source
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
|
from application.storage.db.repositories.sources import SourcesRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
from application.utils import check_required_fields
|
from application.utils import check_required_fields
|
||||||
from application.vectorstore.vector_creator import VectorCreator
|
from application.vectorstore.vector_creator import VectorCreator
|
||||||
@@ -20,6 +21,21 @@ sources_ns = Namespace(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_provider_from_remote_data(remote_data):
|
||||||
|
if not remote_data:
|
||||||
|
return None
|
||||||
|
if isinstance(remote_data, dict):
|
||||||
|
return remote_data.get("provider")
|
||||||
|
if isinstance(remote_data, str):
|
||||||
|
try:
|
||||||
|
remote_data_obj = json.loads(remote_data)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if isinstance(remote_data_obj, dict):
|
||||||
|
return remote_data_obj.get("provider")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@sources_ns.route("/sources")
|
@sources_ns.route("/sources")
|
||||||
class CombinedJson(Resource):
|
class CombinedJson(Resource):
|
||||||
@api.doc(description="Provide JSON file with combined available indexes")
|
@api.doc(description="Provide JSON file with combined available indexes")
|
||||||
@@ -40,10 +56,20 @@ class CombinedJson(Resource):
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for index in sources_collection.find({"user": user}).sort("date", -1):
|
with db_readonly() as conn:
|
||||||
|
indexes = SourcesRepository(conn).list_for_user(user)
|
||||||
|
# list_for_user sorts by created_at DESC; legacy shape sorted by
|
||||||
|
# "date" DESC. Both are monotonic on creation so the ordering is
|
||||||
|
# equivalent for dev; re-sort defensively.
|
||||||
|
indexes = sorted(
|
||||||
|
indexes, key=lambda r: r.get("date") or r.get("created_at") or "",
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
for index in indexes:
|
||||||
|
provider = _get_provider_from_remote_data(index.get("remote_data"))
|
||||||
data.append(
|
data.append(
|
||||||
{
|
{
|
||||||
"id": str(index["_id"]),
|
"id": str(index["id"]),
|
||||||
"name": index.get("name"),
|
"name": index.get("name"),
|
||||||
"date": index.get("date"),
|
"date": index.get("date"),
|
||||||
"model": settings.EMBEDDINGS_NAME,
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
@@ -51,10 +77,9 @@ class CombinedJson(Resource):
|
|||||||
"tokens": index.get("tokens", ""),
|
"tokens": index.get("tokens", ""),
|
||||||
"retriever": index.get("retriever", "classic"),
|
"retriever": index.get("retriever", "classic"),
|
||||||
"syncFrequency": index.get("sync_frequency", ""),
|
"syncFrequency": index.get("sync_frequency", ""),
|
||||||
|
"provider": provider,
|
||||||
"is_nested": bool(index.get("directory_structure")),
|
"is_nested": bool(index.get("directory_structure")),
|
||||||
"type": index.get(
|
"type": index.get("type", "file"),
|
||||||
"type", "file"
|
|
||||||
), # Add type field with default "file"
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@@ -71,59 +96,55 @@ class PaginatedSources(Resource):
|
|||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
sort_field = request.args.get("sort", "date") # Default to 'date'
|
sort_field = request.args.get("sort", "date")
|
||||||
sort_order = request.args.get("order", "desc") # Default to 'desc'
|
sort_order = request.args.get("order", "desc")
|
||||||
page = int(request.args.get("page", 1)) # Default to 1
|
page = max(1, int(request.args.get("page", 1)))
|
||||||
rows_per_page = int(request.args.get("rows", 10)) # Default to 10
|
rows_per_page = max(1, int(request.args.get("rows", 10)))
|
||||||
# add .strip() to remove leading and trailing whitespaces
|
search_term = request.args.get("search", "").strip() or None
|
||||||
|
|
||||||
search_term = request.args.get(
|
|
||||||
"search", ""
|
|
||||||
).strip() # add search for filter documents
|
|
||||||
|
|
||||||
# Prepare query for filtering
|
|
||||||
|
|
||||||
query = {"user": user}
|
|
||||||
if search_term:
|
|
||||||
query["name"] = {
|
|
||||||
"$regex": search_term,
|
|
||||||
"$options": "i", # using case-insensitive search
|
|
||||||
}
|
|
||||||
total_documents = sources_collection.count_documents(query)
|
|
||||||
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
|
||||||
page = min(
|
|
||||||
max(1, page), total_pages
|
|
||||||
) # add this to make sure page inbound is within the range
|
|
||||||
sort_order = 1 if sort_order == "asc" else -1
|
|
||||||
skip = (page - 1) * rows_per_page
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
documents = (
|
with db_readonly() as conn:
|
||||||
sources_collection.find(query)
|
repo = SourcesRepository(conn)
|
||||||
.sort(sort_field, sort_order)
|
total_documents = repo.count_for_user(
|
||||||
.skip(skip)
|
user, search_term=search_term,
|
||||||
.limit(rows_per_page)
|
)
|
||||||
)
|
# Prior in-Python implementation returned ``totalPages = 1``
|
||||||
|
# for empty result sets (``max(1, ceil(0/rows))``); we
|
||||||
|
# preserve that contract so the frontend pager stays stable.
|
||||||
|
total_pages = max(1, math.ceil(total_documents / rows_per_page))
|
||||||
|
effective_page = min(page, total_pages)
|
||||||
|
offset = (effective_page - 1) * rows_per_page
|
||||||
|
window = repo.list_for_user(
|
||||||
|
user,
|
||||||
|
limit=rows_per_page,
|
||||||
|
offset=offset,
|
||||||
|
search_term=search_term,
|
||||||
|
sort_field=sort_field,
|
||||||
|
sort_order=sort_order,
|
||||||
|
)
|
||||||
|
|
||||||
paginated_docs = []
|
paginated_docs = []
|
||||||
for doc in documents:
|
for doc in window:
|
||||||
doc_data = {
|
provider = _get_provider_from_remote_data(doc.get("remote_data"))
|
||||||
"id": str(doc["_id"]),
|
paginated_docs.append(
|
||||||
"name": doc.get("name", ""),
|
{
|
||||||
"date": doc.get("date", ""),
|
"id": str(doc["id"]),
|
||||||
"model": settings.EMBEDDINGS_NAME,
|
"name": doc.get("name", ""),
|
||||||
"location": "local",
|
"date": doc.get("date", ""),
|
||||||
"tokens": doc.get("tokens", ""),
|
"model": settings.EMBEDDINGS_NAME,
|
||||||
"retriever": doc.get("retriever", "classic"),
|
"location": "local",
|
||||||
"syncFrequency": doc.get("sync_frequency", ""),
|
"tokens": doc.get("tokens", ""),
|
||||||
"isNested": bool(doc.get("directory_structure")),
|
"retriever": doc.get("retriever", "classic"),
|
||||||
"type": doc.get("type", "file"),
|
"syncFrequency": doc.get("sync_frequency", ""),
|
||||||
}
|
"provider": provider,
|
||||||
paginated_docs.append(doc_data)
|
"isNested": bool(doc.get("directory_structure")),
|
||||||
|
"type": doc.get("type", "file"),
|
||||||
|
}
|
||||||
|
)
|
||||||
response = {
|
response = {
|
||||||
"total": total_documents,
|
"total": total_documents,
|
||||||
"totalPages": total_pages,
|
"totalPages": total_pages,
|
||||||
"currentPage": page,
|
"currentPage": effective_page,
|
||||||
"paginated": paginated_docs,
|
"paginated": paginated_docs,
|
||||||
}
|
}
|
||||||
return make_response(jsonify(response), 200)
|
return make_response(jsonify(response), 200)
|
||||||
@@ -134,28 +155,6 @@ class PaginatedSources(Resource):
|
|||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
|
||||||
|
|
||||||
@sources_ns.route("/delete_by_ids")
|
|
||||||
class DeleteByIds(Resource):
|
|
||||||
@api.doc(
|
|
||||||
description="Deletes documents from the vector store by IDs",
|
|
||||||
params={"path": "Comma-separated list of IDs"},
|
|
||||||
)
|
|
||||||
def get(self):
|
|
||||||
ids = request.args.get("path")
|
|
||||||
if not ids:
|
|
||||||
return make_response(
|
|
||||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
result = sources_collection.delete_index(ids=ids)
|
|
||||||
if result:
|
|
||||||
return make_response(jsonify({"success": True}), 200)
|
|
||||||
except Exception as err:
|
|
||||||
current_app.logger.error(f"Error deleting indexes: {err}", exc_info=True)
|
|
||||||
return make_response(jsonify({"success": False}), 400)
|
|
||||||
return make_response(jsonify({"success": False}), 400)
|
|
||||||
|
|
||||||
|
|
||||||
@sources_ns.route("/delete_old")
|
@sources_ns.route("/delete_old")
|
||||||
class DeleteOldIndexes(Resource):
|
class DeleteOldIndexes(Resource):
|
||||||
@api.doc(
|
@api.doc(
|
||||||
@@ -166,30 +165,33 @@ class DeleteOldIndexes(Resource):
|
|||||||
decoded_token = request.decoded_token
|
decoded_token = request.decoded_token
|
||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
source_id = request.args.get("source_id")
|
source_id = request.args.get("source_id")
|
||||||
if not source_id:
|
if not source_id:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Missing required fields"}), 400
|
jsonify({"success": False, "message": "Missing required fields"}), 400
|
||||||
)
|
)
|
||||||
doc = sources_collection.find_one(
|
try:
|
||||||
{"_id": ObjectId(source_id), "user": decoded_token.get("sub")}
|
with db_readonly() as conn:
|
||||||
)
|
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(jsonify({"status": "not found"}), 404)
|
return make_response(jsonify({"status": "not found"}), 404)
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
|
resolved_id = str(doc["id"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delete vector index
|
|
||||||
|
|
||||||
if settings.VECTOR_STORE == "faiss":
|
if settings.VECTOR_STORE == "faiss":
|
||||||
index_path = f"indexes/{str(doc['_id'])}"
|
index_path = f"indexes/{resolved_id}"
|
||||||
if storage.file_exists(f"{index_path}/index.faiss"):
|
if storage.file_exists(f"{index_path}/index.faiss"):
|
||||||
storage.delete_file(f"{index_path}/index.faiss")
|
storage.delete_file(f"{index_path}/index.faiss")
|
||||||
if storage.file_exists(f"{index_path}/index.pkl"):
|
if storage.file_exists(f"{index_path}/index.pkl"):
|
||||||
storage.delete_file(f"{index_path}/index.pkl")
|
storage.delete_file(f"{index_path}/index.pkl")
|
||||||
else:
|
else:
|
||||||
vectorstore = VectorCreator.create_vectorstore(
|
vectorstore = VectorCreator.create_vectorstore(
|
||||||
settings.VECTOR_STORE, source_id=str(doc["_id"])
|
settings.VECTOR_STORE, source_id=resolved_id
|
||||||
)
|
)
|
||||||
vectorstore.delete_index()
|
vectorstore.delete_index()
|
||||||
if "file_path" in doc and doc["file_path"]:
|
if "file_path" in doc and doc["file_path"]:
|
||||||
@@ -207,7 +209,14 @@ class DeleteOldIndexes(Resource):
|
|||||||
f"Error deleting files and indexes: {err}", exc_info=True
|
f"Error deleting files and indexes: {err}", exc_info=True
|
||||||
)
|
)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
sources_collection.delete_one({"_id": ObjectId(source_id)})
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
SourcesRepository(conn).delete(resolved_id, user)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error deleting source row: {err}", exc_info=True
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
@@ -240,7 +249,7 @@ class ManageSync(Resource):
|
|||||||
if not decoded_token:
|
if not decoded_token:
|
||||||
return make_response(jsonify({"success": False}), 401)
|
return make_response(jsonify({"success": False}), 401)
|
||||||
user = decoded_token.get("sub")
|
user = decoded_token.get("sub")
|
||||||
data = request.get_json()
|
data = request.get_json() or {}
|
||||||
required_fields = ["source_id", "sync_frequency"]
|
required_fields = ["source_id", "sync_frequency"]
|
||||||
missing_fields = check_required_fields(data, required_fields)
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
@@ -252,15 +261,16 @@ class ManageSync(Resource):
|
|||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
jsonify({"success": False, "message": "Invalid frequency"}), 400
|
||||||
)
|
)
|
||||||
update_data = {"$set": {"sync_frequency": sync_frequency}}
|
|
||||||
try:
|
try:
|
||||||
sources_collection.update_one(
|
with db_session() as conn:
|
||||||
{
|
repo = SourcesRepository(conn)
|
||||||
"_id": ObjectId(source_id),
|
doc = repo.get_any(source_id, user)
|
||||||
"user": user,
|
if doc is None:
|
||||||
},
|
return make_response(
|
||||||
update_data,
|
jsonify({"success": False, "message": "Source not found"}),
|
||||||
)
|
404,
|
||||||
|
)
|
||||||
|
repo.update(str(doc["id"]), user, {"sync_frequency": sync_frequency})
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error updating sync frequency: {err}", exc_info=True
|
f"Error updating sync frequency: {err}", exc_info=True
|
||||||
@@ -269,6 +279,73 @@ class ManageSync(Resource):
|
|||||||
return make_response(jsonify({"success": True}), 200)
|
return make_response(jsonify({"success": True}), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@sources_ns.route("/sync_source")
|
||||||
|
class SyncSource(Resource):
|
||||||
|
sync_source_model = api.model(
|
||||||
|
"SyncSourceModel",
|
||||||
|
{"source_id": fields.String(required=True, description="Source ID")},
|
||||||
|
)
|
||||||
|
|
||||||
|
@api.expect(sync_source_model)
|
||||||
|
@api.doc(description="Trigger an immediate sync for a source")
|
||||||
|
def post(self):
|
||||||
|
decoded_token = request.decoded_token
|
||||||
|
if not decoded_token:
|
||||||
|
return make_response(jsonify({"success": False}), 401)
|
||||||
|
user = decoded_token.get("sub")
|
||||||
|
data = request.get_json()
|
||||||
|
required_fields = ["source_id"]
|
||||||
|
missing_fields = check_required_fields(data, required_fields)
|
||||||
|
if missing_fields:
|
||||||
|
return missing_fields
|
||||||
|
source_id = data["source_id"]
|
||||||
|
try:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
doc = SourcesRepository(conn).get_any(source_id, user)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(f"Error looking up source: {err}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid source ID"}), 400
|
||||||
|
)
|
||||||
|
if not doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Source not found"}), 404
|
||||||
|
)
|
||||||
|
source_type = doc.get("type", "")
|
||||||
|
if source_type and source_type.startswith("connector"):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Connector sources must be synced via /api/connectors/sync",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
source_data = doc.get("remote_data")
|
||||||
|
if not source_data:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Source is not syncable"}), 400
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
task = sync_source.delay(
|
||||||
|
source_data=source_data,
|
||||||
|
job_name=doc.get("name", ""),
|
||||||
|
user=user,
|
||||||
|
loader=source_type,
|
||||||
|
sync_frequency=doc.get("sync_frequency", "never"),
|
||||||
|
retriever=doc.get("retriever", "classic"),
|
||||||
|
doc_id=str(doc["id"]),
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
current_app.logger.error(
|
||||||
|
f"Error starting sync for source {source_id}: {err}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return make_response(jsonify({"success": False}), 400)
|
||||||
|
return make_response(jsonify({"success": True, "task_id": task.id}), 200)
|
||||||
|
|
||||||
|
|
||||||
@sources_ns.route("/directory_structure")
|
@sources_ns.route("/directory_structure")
|
||||||
class DirectoryStructure(Resource):
|
class DirectoryStructure(Resource):
|
||||||
@api.doc(
|
@api.doc(
|
||||||
@@ -284,10 +361,9 @@ class DirectoryStructure(Resource):
|
|||||||
|
|
||||||
if not doc_id:
|
if not doc_id:
|
||||||
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
return make_response(jsonify({"error": "Document ID is required"}), 400)
|
||||||
if not ObjectId.is_valid(doc_id):
|
|
||||||
return make_response(jsonify({"error": "Invalid document ID"}), 400)
|
|
||||||
try:
|
try:
|
||||||
doc = sources_collection.find_one({"_id": ObjectId(doc_id), "user": user})
|
with db_readonly() as conn:
|
||||||
|
doc = SourcesRepository(conn).get_any(doc_id, user)
|
||||||
if not doc:
|
if not doc:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"error": "Document not found or access denied"}), 404
|
jsonify({"error": "Document not found or access denied"}), 404
|
||||||
@@ -301,6 +377,8 @@ class DirectoryStructure(Resource):
|
|||||||
if isinstance(remote_data, str) and remote_data:
|
if isinstance(remote_data, str) and remote_data:
|
||||||
remote_data_obj = json.loads(remote_data)
|
remote_data_obj = json.loads(remote_data)
|
||||||
provider = remote_data_obj.get("provider")
|
provider = remote_data_obj.get("provider")
|
||||||
|
elif isinstance(remote_data, dict):
|
||||||
|
provider = remote_data.get("provider")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_app.logger.warning(
|
current_app.logger.warning(
|
||||||
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
f"Failed to parse remote_data for doc {doc_id}: {e}"
|
||||||
@@ -320,4 +398,7 @@ class DirectoryStructure(Resource):
|
|||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error retrieving directory structure: {e}", exc_info=True
|
f"Error retrieving directory structure: {e}", exc_info=True
|
||||||
)
|
)
|
||||||
return make_response(jsonify({"success": False, "error": str(e)}), 500)
|
return make_response(
|
||||||
|
jsonify({"success": False, "error": "Failed to retrieve directory structure"}),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|||||||
@@ -5,16 +5,23 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
from bson.objectid import ObjectId
|
|
||||||
from flask import current_app, jsonify, make_response, request
|
from flask import current_app, jsonify, make_response, request
|
||||||
from flask_restx import fields, Namespace, Resource
|
from flask_restx import fields, Namespace, Resource
|
||||||
|
|
||||||
from application.api import api
|
from application.api import api
|
||||||
from application.api.user.base import sources_collection
|
|
||||||
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
from application.api.user.tasks import ingest, ingest_connector_task, ingest_remote
|
||||||
from application.core.settings import settings
|
from application.core.settings import settings
|
||||||
from application.parser.connectors.connector_creator import ConnectorCreator
|
from application.parser.connectors.connector_creator import ConnectorCreator
|
||||||
|
from application.parser.file.constants import SUPPORTED_SOURCE_EXTENSIONS
|
||||||
|
from application.storage.db.repositories.sources import SourcesRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
from application.storage.storage_creator import StorageCreator
|
from application.storage.storage_creator import StorageCreator
|
||||||
|
from application.stt.upload_limits import (
|
||||||
|
AudioFileTooLargeError,
|
||||||
|
build_stt_file_size_limit_message,
|
||||||
|
enforce_audio_file_size_limit,
|
||||||
|
is_audio_filename,
|
||||||
|
)
|
||||||
from application.utils import check_required_fields, safe_filename
|
from application.utils import check_required_fields, safe_filename
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +30,12 @@ sources_upload_ns = Namespace(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_audio_path_size_limit(file_path: str, filename: str) -> None:
|
||||||
|
if not is_audio_filename(filename):
|
||||||
|
return
|
||||||
|
enforce_audio_file_size_limit(os.path.getsize(file_path))
|
||||||
|
|
||||||
|
|
||||||
@sources_upload_ns.route("/upload")
|
@sources_upload_ns.route("/upload")
|
||||||
class UploadFile(Resource):
|
class UploadFile(Resource):
|
||||||
@api.expect(
|
@api.expect(
|
||||||
@@ -64,19 +77,28 @@ class UploadFile(Resource):
|
|||||||
safe_user = safe_filename(user)
|
safe_user = safe_filename(user)
|
||||||
dir_name = safe_filename(job_name)
|
dir_name = safe_filename(job_name)
|
||||||
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
base_path = f"{settings.UPLOAD_FOLDER}/{safe_user}/{dir_name}"
|
||||||
|
file_name_map = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
|
|
||||||
for file in files:
|
for file in files:
|
||||||
original_filename = file.filename
|
original_filename = os.path.basename(file.filename)
|
||||||
safe_file = safe_filename(original_filename)
|
safe_file = safe_filename(original_filename)
|
||||||
|
if original_filename:
|
||||||
|
file_name_map[safe_file] = original_filename
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
temp_file_path = os.path.join(temp_dir, safe_file)
|
temp_file_path = os.path.join(temp_dir, safe_file)
|
||||||
file.save(temp_file_path)
|
file.save(temp_file_path)
|
||||||
|
_enforce_audio_path_size_limit(temp_file_path, safe_file)
|
||||||
|
|
||||||
if zipfile.is_zipfile(temp_file_path):
|
# Only extract actual .zip files, not Office formats (.docx, .xlsx, .pptx)
|
||||||
|
# which are technically zip archives but should be processed as-is
|
||||||
|
is_office_format = safe_file.lower().endswith(
|
||||||
|
(".docx", ".xlsx", ".pptx", ".odt", ".ods", ".odp", ".epub")
|
||||||
|
)
|
||||||
|
if zipfile.is_zipfile(temp_file_path) and not is_office_format:
|
||||||
try:
|
try:
|
||||||
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
|
with zipfile.ZipFile(temp_file_path, "r") as zip_ref:
|
||||||
zip_ref.extractall(path=temp_dir)
|
zip_ref.extractall(path=temp_dir)
|
||||||
@@ -94,6 +116,10 @@ class UploadFile(Resource):
|
|||||||
os.path.join(root, extracted_file), temp_dir
|
os.path.join(root, extracted_file), temp_dir
|
||||||
)
|
)
|
||||||
storage_path = f"{base_path}/{rel_path}"
|
storage_path = f"{base_path}/{rel_path}"
|
||||||
|
_enforce_audio_path_size_limit(
|
||||||
|
os.path.join(root, extracted_file),
|
||||||
|
extracted_file,
|
||||||
|
)
|
||||||
|
|
||||||
with open(
|
with open(
|
||||||
os.path.join(root, extracted_file), "rb"
|
os.path.join(root, extracted_file), "rb"
|
||||||
@@ -116,27 +142,22 @@ class UploadFile(Resource):
|
|||||||
storage.save_file(f, file_path)
|
storage.save_file(f, file_path)
|
||||||
task = ingest.delay(
|
task = ingest.delay(
|
||||||
settings.UPLOAD_FOLDER,
|
settings.UPLOAD_FOLDER,
|
||||||
[
|
list(SUPPORTED_SOURCE_EXTENSIONS),
|
||||||
".rst",
|
|
||||||
".md",
|
|
||||||
".pdf",
|
|
||||||
".txt",
|
|
||||||
".docx",
|
|
||||||
".csv",
|
|
||||||
".epub",
|
|
||||||
".html",
|
|
||||||
".mdx",
|
|
||||||
".json",
|
|
||||||
".xlsx",
|
|
||||||
".pptx",
|
|
||||||
".png",
|
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
],
|
|
||||||
job_name,
|
job_name,
|
||||||
user,
|
user,
|
||||||
file_path=base_path,
|
file_path=base_path,
|
||||||
filename=dir_name,
|
filename=dir_name,
|
||||||
|
file_name_map=file_name_map,
|
||||||
|
)
|
||||||
|
except AudioFileTooLargeError:
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": build_stt_file_size_limit_message(),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
413,
|
||||||
)
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
current_app.logger.error(f"Error uploading file: {err}", exc_info=True)
|
||||||
@@ -182,6 +203,8 @@ class UploadRemote(Resource):
|
|||||||
source_data = config.get("url")
|
source_data = config.get("url")
|
||||||
elif data["source"] == "reddit":
|
elif data["source"] == "reddit":
|
||||||
source_data = config
|
source_data = config
|
||||||
|
elif data["source"] == "s3":
|
||||||
|
source_data = config
|
||||||
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
elif data["source"] in ConnectorCreator.get_supported_connectors():
|
||||||
session_token = config.get("session_token")
|
session_token = config.get("session_token")
|
||||||
if not session_token:
|
if not session_token:
|
||||||
@@ -306,15 +329,8 @@ class ManageSourceFiles(Resource):
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
ObjectId(source_id)
|
with db_readonly() as conn:
|
||||||
except Exception:
|
source = SourcesRepository(conn).get_any(source_id, user)
|
||||||
return make_response(
|
|
||||||
jsonify({"success": False, "message": "Invalid source ID format"}), 400
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
source = sources_collection.find_one(
|
|
||||||
{"_id": ObjectId(source_id), "user": user}
|
|
||||||
)
|
|
||||||
if not source:
|
if not source:
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
@@ -330,10 +346,19 @@ class ManageSourceFiles(Resource):
|
|||||||
return make_response(
|
return make_response(
|
||||||
jsonify({"success": False, "message": "Database error"}), 500
|
jsonify({"success": False, "message": "Database error"}), 500
|
||||||
)
|
)
|
||||||
|
resolved_source_id = str(source["id"])
|
||||||
try:
|
try:
|
||||||
storage = StorageCreator.get_storage()
|
storage = StorageCreator.get_storage()
|
||||||
source_file_path = source.get("file_path", "")
|
source_file_path = source.get("file_path", "")
|
||||||
parent_dir = request.form.get("parent_dir", "")
|
parent_dir = request.form.get("parent_dir", "")
|
||||||
|
file_name_map = source.get("file_name_map") or {}
|
||||||
|
if isinstance(file_name_map, str):
|
||||||
|
try:
|
||||||
|
file_name_map = json.loads(file_name_map)
|
||||||
|
except Exception:
|
||||||
|
file_name_map = {}
|
||||||
|
if not isinstance(file_name_map, dict):
|
||||||
|
file_name_map = {}
|
||||||
|
|
||||||
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
|
if parent_dir and (parent_dir.startswith("/") or ".." in parent_dir):
|
||||||
return make_response(
|
return make_response(
|
||||||
@@ -355,24 +380,43 @@ class ManageSourceFiles(Resource):
|
|||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
added_files = []
|
added_files = []
|
||||||
|
map_updated = False
|
||||||
|
|
||||||
target_dir = source_file_path
|
target_dir = source_file_path
|
||||||
if parent_dir:
|
if parent_dir:
|
||||||
target_dir = f"{source_file_path}/{parent_dir}"
|
target_dir = f"{source_file_path}/{parent_dir}"
|
||||||
for file in files:
|
for file in files:
|
||||||
if file.filename:
|
if file.filename:
|
||||||
safe_filename_str = safe_filename(file.filename)
|
original_filename = os.path.basename(file.filename)
|
||||||
|
safe_filename_str = safe_filename(original_filename)
|
||||||
file_path = f"{target_dir}/{safe_filename_str}"
|
file_path = f"{target_dir}/{safe_filename_str}"
|
||||||
|
|
||||||
# Save file to storage
|
# Save file to storage
|
||||||
|
|
||||||
storage.save_file(file, file_path)
|
storage.save_file(file, file_path)
|
||||||
added_files.append(safe_filename_str)
|
added_files.append(safe_filename_str)
|
||||||
|
if original_filename:
|
||||||
|
relative_key = (
|
||||||
|
f"{parent_dir}/{safe_filename_str}"
|
||||||
|
if parent_dir
|
||||||
|
else safe_filename_str
|
||||||
|
)
|
||||||
|
file_name_map[relative_key] = original_filename
|
||||||
|
map_updated = True
|
||||||
|
|
||||||
|
if map_updated:
|
||||||
|
with db_session() as conn:
|
||||||
|
SourcesRepository(conn).update(
|
||||||
|
resolved_source_id, user,
|
||||||
|
{"file_name_map": dict(file_name_map)},
|
||||||
|
)
|
||||||
# Trigger re-ingestion pipeline
|
# Trigger re-ingestion pipeline
|
||||||
|
|
||||||
from application.api.user.tasks import reingest_source_task
|
from application.api.user.tasks import reingest_source_task
|
||||||
|
|
||||||
task = reingest_source_task.delay(source_id=source_id, user=user)
|
task = reingest_source_task.delay(
|
||||||
|
source_id=resolved_source_id, user=user
|
||||||
|
)
|
||||||
|
|
||||||
return make_response(
|
return make_response(
|
||||||
jsonify(
|
jsonify(
|
||||||
@@ -414,7 +458,18 @@ class ManageSourceFiles(Resource):
|
|||||||
# Remove files from storage and directory structure
|
# Remove files from storage and directory structure
|
||||||
|
|
||||||
removed_files = []
|
removed_files = []
|
||||||
|
map_updated = False
|
||||||
for file_path in file_paths:
|
for file_path in file_paths:
|
||||||
|
if ".." in str(file_path) or str(file_path).startswith("/"):
|
||||||
|
return make_response(
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"message": "Invalid file path",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
400,
|
||||||
|
)
|
||||||
full_path = f"{source_file_path}/{file_path}"
|
full_path = f"{source_file_path}/{file_path}"
|
||||||
|
|
||||||
# Remove from storage
|
# Remove from storage
|
||||||
@@ -422,11 +477,23 @@ class ManageSourceFiles(Resource):
                 if storage.file_exists(full_path):
                     storage.delete_file(full_path)
                     removed_files.append(file_path)
+                    if file_path in file_name_map:
+                        file_name_map.pop(file_path, None)
+                        map_updated = True
+
+            if map_updated and isinstance(file_name_map, dict):
+                with db_session() as conn:
+                    SourcesRepository(conn).update(
+                        resolved_source_id, user,
+                        {"file_name_map": dict(file_name_map)},
+                    )
             # Trigger re-ingestion pipeline

             from application.api.user.tasks import reingest_source_task

-            task = reingest_source_task.delay(source_id=source_id, user=user)
+            task = reingest_source_task.delay(
+                source_id=resolved_source_id, user=user
+            )

             return make_response(
                 jsonify(
@@ -504,12 +571,29 @@ class ManageSourceFiles(Resource):
                 f"User: {user}, Source ID: {source_id}, Directory path: {directory_path}, "
                 f"Full path: {full_directory_path}"
             )
+            if directory_path and file_name_map:
+                prefix = f"{directory_path.rstrip('/')}/"
+                keys_to_remove = [
+                    key
+                    for key in file_name_map.keys()
+                    if key == directory_path or key.startswith(prefix)
+                ]
+                if keys_to_remove:
+                    for key in keys_to_remove:
+                        file_name_map.pop(key, None)
+                    with db_session() as conn:
+                        SourcesRepository(conn).update(
+                            resolved_source_id, user,
+                            {"file_name_map": dict(file_name_map)},
+                        )
+
             # Trigger re-ingestion pipeline

             from application.api.user.tasks import reingest_source_task

-            task = reingest_source_task.delay(source_id=source_id, user=user)
+            task = reingest_source_task.delay(
+                source_id=resolved_source_id, user=user
+            )

             return make_response(
                 jsonify(
@@ -574,8 +658,9 @@ class TaskStatus(Resource):
         ):
             task_meta = str(task_meta)  # Convert to a string representation
         except ConnectionError as err:
+            current_app.logger.error(f"Connection error getting task status: {err}")
             return make_response(
-                jsonify({"success": False, "message": str(err)}), 503
+                jsonify({"success": False, "message": "Service unavailable"}), 503
             )
         except Exception as err:
             current_app.logger.error(f"Error getting task status: {err}", exc_info=True)
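
Note: the hunks above maintain a per-source file_name_map so sanitized storage keys can be traced back to the user's original filenames, and they prune that map whenever files or whole directories are removed before re-ingestion is triggered. A minimal standalone sketch of that bookkeeping, using illustrative helper names that are not part of the DocsGPT codebase:

def add_mapping(file_name_map, parent_dir, safe_filename, original_filename):
    # Same key shape the route builds: "<parent_dir>/<safe_name>" or just "<safe_name>".
    relative_key = f"{parent_dir}/{safe_filename}" if parent_dir else safe_filename
    file_name_map[relative_key] = original_filename
    return relative_key

def prune_directory(file_name_map, directory_path):
    # Mirrors the directory-delete branch: drop the directory key and everything under it.
    prefix = f"{directory_path.rstrip('/')}/"
    for key in [k for k in file_name_map if k == directory_path or k.startswith(prefix)]:
        file_name_map.pop(key, None)

name_map = {}
add_mapping(name_map, "reports", "q3_2024.pdf", "Q3 2024 (final).pdf")
add_mapping(name_map, "", "notes.txt", "Meeting Notes.txt")
prune_directory(name_map, "reports/")
print(name_map)  # {'notes.txt': 'Meeting Notes.txt'}
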
@@ -8,13 +8,25 @@ from application.worker import (
     mcp_oauth,
     mcp_oauth_status,
     remote_worker,
+    sync,
     sync_worker,
 )


 @celery.task(bind=True)
-def ingest(self, directory, formats, job_name, user, file_path, filename):
-    resp = ingest_worker(self, directory, formats, job_name, file_path, filename, user)
+def ingest(
+    self, directory, formats, job_name, user, file_path, filename, file_name_map=None
+):
+    resp = ingest_worker(
+        self,
+        directory,
+        formats,
+        job_name,
+        file_path,
+        filename,
+        user,
+        file_name_map=file_name_map,
+    )
     return resp


@@ -38,6 +50,30 @@ def schedule_syncs(self, frequency):
     return resp


+@celery.task(bind=True)
+def sync_source(
+    self,
+    source_data,
+    job_name,
+    user,
+    loader,
+    sync_frequency,
+    retriever,
+    doc_id,
+):
+    resp = sync(
+        self,
+        source_data,
+        job_name,
+        user,
+        loader,
+        sync_frequency,
+        retriever,
+        doc_id,
+    )
+    return resp
+
+
 @celery.task(bind=True)
 def store_attachment(self, file_info, user):
     resp = attachment_worker(self, file_info, user)
@@ -98,6 +134,12 @@ def setup_periodic_tasks(sender, **kwargs):
         timedelta(days=30),
         schedule_syncs.s("monthly"),
     )
+    # Replaces Mongo's TTL index on pending_tool_state.expires_at.
+    sender.add_periodic_task(
+        timedelta(seconds=60),
+        cleanup_pending_tool_state.s(),
+        name="cleanup-pending-tool-state",
+    )


 @celery.task(bind=True)
@@ -110,3 +152,27 @@ def mcp_oauth_task(self, config, user):
 def mcp_oauth_status_task(self, task_id):
     resp = mcp_oauth_status(self, task_id)
     return resp
+
+
+@celery.task(bind=True)
+def cleanup_pending_tool_state(self):
+    """Delete pending_tool_state rows past their TTL.
+
+    Replaces Mongo's ``expireAfterSeconds=0`` TTL index — Postgres has
+    no native TTL, so this task runs every 60 seconds to keep
+    ``pending_tool_state`` bounded. No-ops if ``POSTGRES_URI`` isn't
+    configured (keeps the task runnable in Mongo-only environments).
+    """
+    from application.core.settings import settings
+
+    if not settings.POSTGRES_URI:
+        return {"deleted": 0, "skipped": "POSTGRES_URI not set"}
+
+    from application.storage.db.engine import get_engine
+    from application.storage.db.repositories.pending_tool_state import (
+        PendingToolStateRepository,
+    )
+
+    engine = get_engine()
+    with engine.begin() as conn:
+        deleted = PendingToolStateRepository(conn).cleanup_expired()
+    return {"deleted": deleted}
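
Note: the new cleanup task delegates to PendingToolStateRepository.cleanup_expired(), which is not shown in this diff. A hypothetical sketch of what such a query could look like with SQLAlchemy Core; the table and column names (pending_tool_state, expires_at) are assumptions taken from the task's docstring, not confirmed by the patch:

from datetime import datetime, timezone

import sqlalchemy as sa

metadata = sa.MetaData()
pending_tool_state = sa.Table(
    "pending_tool_state",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("expires_at", sa.DateTime(timezone=True)),
)

def cleanup_expired(conn) -> int:
    """Delete rows whose TTL has elapsed and return how many were removed."""
    result = conn.execute(
        sa.delete(pending_tool_state).where(
            pending_tool_state.c.expires_at <= datetime.now(timezone.utc)
        )
    )
    return result.rowcount
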
@@ -1,21 +1,80 @@
 """Tool management MCP server integration."""

 import json
-from email.quoprimime import unquote
+from urllib.parse import urlencode, urlparse

-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, redirect, request
-from flask_restx import fields, Namespace, Resource
+from flask_restx import Namespace, Resource, fields

 from application.agents.tools.mcp_tool import MCPOAuthManager, MCPTool
 from application.api import api
-from application.api.user.base import user_tools_collection
+from application.api.user.tools.routes import transform_actions
 from application.cache import get_redis_instance
-from application.security.encryption import encrypt_credentials
+from application.core.url_validation import SSRFError, validate_url
+from application.security.encryption import decrypt_credentials, encrypt_credentials
+from application.storage.db.repositories.connector_sessions import (
+    ConnectorSessionsRepository,
+)
+from application.storage.db.repositories.user_tools import UserToolsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields

 tools_mcp_ns = Namespace("tools", description="Tool management operations", path="/api")

+_ALLOWED_TRANSPORTS = {"auto", "sse", "http"}
+
+
+def _sanitize_mcp_transport(config):
+    """Normalise and validate the transport_type field.
+
+    Strips ``command`` / ``args`` keys that are only valid for local STDIO
+    transports and returns the cleaned transport type string.
+    """
+    transport_type = (config.get("transport_type") or "auto").lower()
+    if transport_type not in _ALLOWED_TRANSPORTS:
+        raise ValueError(f"Unsupported transport_type: {transport_type}")
+    config.pop("command", None)
+    config.pop("args", None)
+    config["transport_type"] = transport_type
+    return transport_type
+
+
+def _extract_auth_credentials(config):
+    """Build an ``auth_credentials`` dict from the raw MCP config."""
+    auth_credentials = {}
+    auth_type = config.get("auth_type", "none")
+
+    if auth_type == "api_key":
+        if config.get("api_key"):
+            auth_credentials["api_key"] = config["api_key"]
+        if config.get("api_key_header"):
+            auth_credentials["api_key_header"] = config["api_key_header"]
+    elif auth_type == "bearer":
+        if config.get("bearer_token"):
+            auth_credentials["bearer_token"] = config["bearer_token"]
+    elif auth_type == "basic":
+        if config.get("username"):
+            auth_credentials["username"] = config["username"]
+        if config.get("password"):
+            auth_credentials["password"] = config["password"]
+
+    return auth_credentials
+
+
+def _validate_mcp_server_url(config: dict) -> None:
+    """Validate the server_url in an MCP config to prevent SSRF.
+
+    Raises:
+        ValueError: If the URL is missing or points to a blocked address.
+    """
+    server_url = (config.get("server_url") or "").strip()
+    if not server_url:
+        raise ValueError("server_url is required")
+    try:
+        validate_url(server_url)
+    except SSRFError as exc:
+        raise ValueError(f"Invalid server URL: {exc}") from exc
+
+
 @tools_mcp_ns.route("/mcp_server/test")
 class TestMCPServerConfig(Resource):
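
Note: taken together, the helpers introduced above give a single place to normalise and vet an incoming MCP config before it touches the network. A short illustrative usage, assuming the three helpers above are in scope; all values are example data:

config = {
    "server_url": "https://mcp.example.com/sse",
    "transport_type": "SSE",
    "auth_type": "bearer",
    "bearer_token": "secret-token",
    "command": "npx",          # stripped: only meaningful for local STDIO servers
    "args": ["some-package"],  # stripped for the same reason
}

transport = _sanitize_mcp_transport(config)   # -> "sse"; also pops "command"/"args"
creds = _extract_auth_credentials(config)     # -> {"bearer_token": "secret-token"}
_validate_mcp_server_url(config)              # raises ValueError on a missing or blocked URL
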
@@ -43,34 +102,63 @@ class TestMCPServerConfig(Resource):
             return missing_fields
         try:
             config = data["config"]
+            try:
+                _sanitize_mcp_transport(config)
+            except ValueError:
+                return make_response(
+                    jsonify({"success": False, "error": "Unsupported transport_type"}),
+                    400,
+                )
+
-            auth_credentials = {}
-            auth_type = config.get("auth_type", "none")
-
-            if auth_type == "api_key" and "api_key" in config:
-                auth_credentials["api_key"] = config["api_key"]
-                if "api_key_header" in config:
-                    auth_credentials["api_key_header"] = config["api_key_header"]
-            elif auth_type == "bearer" and "bearer_token" in config:
-                auth_credentials["bearer_token"] = config["bearer_token"]
-            elif auth_type == "basic":
-                if "username" in config:
-                    auth_credentials["username"] = config["username"]
-                if "password" in config:
-                    auth_credentials["password"] = config["password"]
+            _validate_mcp_server_url(config)
+
+            auth_credentials = _extract_auth_credentials(config)
+
             test_config = config.copy()
             test_config["auth_credentials"] = auth_credentials

             mcp_tool = MCPTool(config=test_config, user_id=user)
             result = mcp_tool.test_connection()

-            return make_response(jsonify(result), 200)
+            if result.get("requires_oauth"):
+                safe_result = {
+                    k: v
+                    for k, v in result.items()
+                    if k in ("success", "requires_oauth", "auth_url")
+                }
+                return make_response(jsonify(safe_result), 200)
+
+            if not result.get("success"):
+                current_app.logger.error(
+                    f"MCP connection test failed: {result.get('message')}"
+                )
+                return make_response(
+                    jsonify(
+                        {
+                            "success": False,
+                            "message": "Connection test failed",
+                            "tools_count": 0,
+                        }
+                    ),
+                    200,
+                )
+
+            safe_result = {
+                "success": True,
+                "message": result.get("message", "Connection successful"),
+                "tools_count": result.get("tools_count", 0),
+                "tools": result.get("tools", []),
+            }
+            return make_response(jsonify(safe_result), 200)
+        except ValueError as e:
+            current_app.logger.warning(f"Invalid MCP server test request: {e}")
+            return make_response(
+                jsonify({"success": False, "error": "Invalid MCP server configuration"}),
+                400,
+            )
         except Exception as e:
             current_app.logger.error(f"Error testing MCP server: {e}", exc_info=True)
             return make_response(
-                jsonify(
-                    {"success": False, "error": f"Connection test failed: {str(e)}"}
-                ),
+                jsonify({"success": False, "error": "Connection test failed"}),
                 500,
             )

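
Note: after this change the test endpoint no longer echoes the raw test_connection() result or exception text; it returns one of three sanitised shapes. Illustrative payloads only, field values are examples:

requires_oauth_response = {"success": True, "requires_oauth": True, "auth_url": "https://..."}
failed_response = {"success": False, "message": "Connection test failed", "tools_count": 0}
ok_response = {
    "success": True,
    "message": "Connection successful",
    "tools_count": 2,
    "tools": [...],
}
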
@@ -110,22 +198,18 @@ class MCPServerSave(Resource):
             return missing_fields
         try:
             config = data["config"]
+            try:
+                _sanitize_mcp_transport(config)
+            except ValueError:
+                return make_response(
+                    jsonify({"success": False, "error": "Unsupported transport_type"}),
+                    400,
+                )
+
-            auth_credentials = {}
+            _validate_mcp_server_url(config)
+
+            auth_credentials = _extract_auth_credentials(config)
             auth_type = config.get("auth_type", "none")
-            if auth_type == "api_key":
-                if "api_key" in config and config["api_key"]:
-                    auth_credentials["api_key"] = config["api_key"]
-                if "api_key_header" in config:
-                    auth_credentials["api_key_header"] = config["api_key_header"]
-            elif auth_type == "bearer":
-                if "bearer_token" in config and config["bearer_token"]:
-                    auth_credentials["bearer_token"] = config["bearer_token"]
-            elif auth_type == "basic":
-                if "username" in config and config["username"]:
-                    auth_credentials["username"] = config["username"]
-                if "password" in config and config["password"]:
-                    auth_credentials["password"] = config["password"]
             mcp_config = config.copy()
             mcp_config["auth_credentials"] = auth_credentials

@@ -163,79 +247,135 @@ class MCPServerSave(Resource):
                     "No valid credentials provided for the selected authentication type"
                 )
             storage_config = config.copy()
+
+            tool_id = data.get("id")
+            existing_doc = None
+            existing_encrypted = None
+            if tool_id:
+                with db_readonly() as conn:
+                    repo = UserToolsRepository(conn)
+                    existing_doc = repo.get_any(tool_id, user)
+                if existing_doc and existing_doc.get("name") == "mcp_tool":
+                    existing_encrypted = (existing_doc.get("config") or {}).get(
+                        "encrypted_credentials"
+                    )
+                else:
+                    existing_doc = None
+
             if auth_credentials:
-                encrypted_credentials_string = encrypt_credentials(
+                if existing_encrypted:
+                    existing_secrets = decrypt_credentials(existing_encrypted, user)
+                    existing_secrets.update(auth_credentials)
+                    auth_credentials = existing_secrets
+                storage_config["encrypted_credentials"] = encrypt_credentials(
                     auth_credentials, user
                 )
-                storage_config["encrypted_credentials"] = encrypted_credentials_string
+            elif existing_encrypted:
+                storage_config["encrypted_credentials"] = existing_encrypted
+
             for field in [
                 "api_key",
                 "bearer_token",
                 "username",
                 "password",
                 "api_key_header",
+                "redirect_uri",
             ]:
                 storage_config.pop(field, None)
-            transformed_actions = []
-            for action in actions_metadata:
-                action["active"] = True
-                if "parameters" in action:
-                    if "properties" in action["parameters"]:
-                        for param_name, param_details in action["parameters"][
-                            "properties"
-                        ].items():
-                            param_details["filled_by_llm"] = True
-                            param_details["value"] = ""
-                transformed_actions.append(action)
-            tool_data = {
-                "name": "mcp_tool",
-                "displayName": data["displayName"],
-                "customName": data["displayName"],
-                "description": f"MCP Server: {storage_config.get('server_url', 'Unknown')}",
-                "config": storage_config,
-                "actions": transformed_actions,
-                "status": data.get("status", True),
-                "user": user,
-            }
-
-            tool_id = data.get("id")
-            if tool_id:
-                result = user_tools_collection.update_one(
-                    {"_id": ObjectId(tool_id), "user": user, "name": "mcp_tool"},
-                    {"$set": {k: v for k, v in tool_data.items() if k != "user"}},
-                )
-                if result.matched_count == 0:
-                    return make_response(
-                        jsonify(
-                            {
-                                "success": False,
-                                "error": "Tool not found or access denied",
-                            }
-                        ),
-                        404,
-                    )
-                response_data = {
-                    "success": True,
-                    "id": tool_id,
-                    "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
-                    "tools_count": len(transformed_actions),
-                }
-            else:
-                result = user_tools_collection.insert_one(tool_data)
-                tool_id = str(result.inserted_id)
-                response_data = {
-                    "success": True,
-                    "id": tool_id,
-                    "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
-                    "tools_count": len(transformed_actions),
-                }
+            transformed_actions = transform_actions(actions_metadata)
+
+            display_name = data["displayName"]
+            description = f"MCP Server: {storage_config.get('server_url', 'Unknown')}"
+            status_bool = bool(data.get("status", True))
+
+            with db_session() as conn:
+                repo = UserToolsRepository(conn)
+                if existing_doc:
+                    repo.update(
+                        str(existing_doc["id"]), user,
+                        {
+                            "display_name": display_name,
+                            "custom_name": display_name,
+                            "description": description,
+                            "config": storage_config,
+                            "actions": transformed_actions,
+                            "status": status_bool,
+                        },
+                    )
+                    saved_id = str(existing_doc["id"])
+                    response_data = {
+                        "success": True,
+                        "id": saved_id,
+                        "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
+                        "tools_count": len(transformed_actions),
+                    }
+                else:
+                    # Fall back to find_by_user_and_name — the original
+                    # dual-write path also ran an existence check before
+                    # deciding between insert and update.
+                    existing_by_name = repo.find_by_user_and_name(user, "mcp_tool")
+                    if tool_id is None and existing_by_name and (
+                        (existing_by_name.get("config") or {}).get("server_url")
+                        == storage_config.get("server_url")
+                    ):
+                        repo.update(
+                            str(existing_by_name["id"]), user,
+                            {
+                                "display_name": display_name,
+                                "custom_name": display_name,
+                                "description": description,
+                                "config": storage_config,
+                                "actions": transformed_actions,
+                                "status": status_bool,
+                            },
+                        )
+                        saved_id = str(existing_by_name["id"])
+                        response_data = {
+                            "success": True,
+                            "id": saved_id,
+                            "message": f"MCP server updated successfully! Discovered {len(transformed_actions)} tools.",
+                            "tools_count": len(transformed_actions),
+                        }
+                    else:
+                        created = repo.create(
+                            user, "mcp_tool",
+                            config=storage_config,
+                            custom_name=display_name,
+                            display_name=display_name,
+                            description=description,
+                            config_requirements={},
+                            actions=transformed_actions,
+                            status=status_bool,
+                        )
+                        saved_id = str(created["id"])
+                        response_data = {
+                            "success": True,
+                            "id": saved_id,
+                            "message": f"MCP server created successfully! Discovered {len(transformed_actions)} tools.",
+                            "tools_count": len(transformed_actions),
+                        }
+            if tool_id and existing_doc is None:
+                # Client requested update on a non-existent tool id.
+                return make_response(
+                    jsonify(
+                        {
+                            "success": False,
+                            "error": "Tool not found or access denied",
+                        }
+                    ),
+                    404,
+                )
             return make_response(jsonify(response_data), 200)
+        except ValueError as e:
+            current_app.logger.warning(f"Invalid MCP server save request: {e}")
+            return make_response(
+                jsonify({"success": False, "error": "Invalid MCP server configuration"}),
+                400,
+            )
         except Exception as e:
             current_app.logger.error(f"Error saving MCP server: {e}", exc_info=True)
             return make_response(
-                jsonify(
-                    {"success": False, "error": f"Failed to save MCP server: {str(e)}"}
-                ),
+                jsonify({"success": False, "error": "Failed to save MCP server"}),
                 500,
             )

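
Note: on update, the save path above keeps any previously stored secrets that the client omitted, overlays only what was re-submitted, and then re-encrypts the merged set. A plain-dict illustration of that merge, with no encryption involved and example values:

existing_secrets = {"api_key": "old-key", "api_key_header": "X-API-Key"}
incoming = {"api_key": "new-key"}          # the user only re-entered the key

merged = {**existing_secrets, **incoming}  # header survives, key is replaced
assert merged == {"api_key": "new-key", "api_key_header": "X-API-Key"}
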
@@ -263,9 +403,12 @@ class MCPOAuthCallback(Resource):
         error = request.args.get("error")

         if error:
-            return redirect(
-                f"/api/connectors/callback-status?status=error&message=OAuth+error:+{error}.+Please+try+again+and+make+sure+to+grant+all+requested+permissions,+including+offline+access.&provider=mcp_tool"
-            )
+            params = {
+                "status": "error",
+                "message": f"OAuth error: {error}. Please try again and make sure to grant all requested permissions, including offline access.",
+                "provider": "mcp_tool",
+            }
+            return redirect(f"/api/connectors/callback-status?{urlencode(params)}")
         if not code or not state:
             return redirect(
                 "/api/connectors/callback-status?status=error&message=Authorization+code+or+state+not+provided.+Please+complete+the+authorization+process+and+make+sure+to+grant+offline+access.&provider=mcp_tool"
@@ -276,7 +419,6 @@ class MCPOAuthCallback(Resource):
             return redirect(
                 "/api/connectors/callback-status?status=error&message=Internal+server+error:+Redis+not+available.&provider=mcp_tool"
             )
-        code = unquote(code)
         manager = MCPOAuthManager(redis_client)
         success = manager.handle_oauth_callback(state, code, error)
         if success:
@@ -292,17 +434,13 @@ class MCPOAuthCallback(Resource):
                 f"Error handling MCP OAuth callback: {str(e)}", exc_info=True
             )
             return redirect(
-                f"/api/connectors/callback-status?status=error&message=Internal+server+error:+{str(e)}.&provider=mcp_tool"
+                "/api/connectors/callback-status?status=error&message=Internal+server+error.&provider=mcp_tool"
             )


 @tools_mcp_ns.route("/mcp_server/oauth_status/<string:task_id>")
 class MCPOAuthStatus(Resource):
     def get(self, task_id):
-        """
-        Get current status of OAuth flow.
-        Frontend should poll this endpoint periodically.
-        """
         try:
             redis_client = get_redis_instance()
             status_key = f"mcp_oauth_status:{task_id}"
@@ -310,6 +448,14 @@ class MCPOAuthStatus(Resource):

             if status_data:
                 status = json.loads(status_data)
+                if "tools" in status and isinstance(status["tools"], list):
+                    status["tools"] = [
+                        {
+                            "name": t.get("name", "unknown"),
+                            "description": t.get("description", ""),
+                        }
+                        for t in status["tools"]
+                    ]
                 return make_response(
                     jsonify({"success": True, "task_id": task_id, **status})
                 )
@@ -317,17 +463,103 @@ class MCPOAuthStatus(Resource):
             return make_response(
                 jsonify(
                     {
-                        "success": False,
-                        "error": "Task not found or expired",
+                        "success": True,
                         "task_id": task_id,
+                        "status": "pending",
+                        "message": "Waiting for OAuth to start...",
                     }
                 ),
-                404,
+                200,
             )
         except Exception as e:
             current_app.logger.error(
-                f"Error getting OAuth status for task {task_id}: {str(e)}"
+                f"Error getting OAuth status for task {task_id}: {str(e)}",
+                exc_info=True,
             )
             return make_response(
-                jsonify({"success": False, "error": str(e), "task_id": task_id}), 500
+                jsonify(
+                    {
+                        "success": False,
+                        "error": "Failed to get OAuth status",
+                        "task_id": task_id,
+                    }
+                ),
+                500,
             )
+
+
+@tools_mcp_ns.route("/mcp_server/auth_status")
+class MCPAuthStatus(Resource):
+    @api.doc(
+        description="Batch check auth status for all MCP tools. "
+        "Lightweight DB-only check — no network calls to MCP servers."
+    )
+    def get(self):
+        decoded_token = request.decoded_token
+        if not decoded_token:
+            return make_response(jsonify({"success": False}), 401)
+        user = decoded_token.get("sub")
+        try:
+            with db_readonly() as conn:
+                tools_repo = UserToolsRepository(conn)
+                sessions_repo = ConnectorSessionsRepository(conn)
+                all_tools = tools_repo.list_for_user(user)
+                mcp_tools = [t for t in all_tools if t.get("name") == "mcp_tool"]
+                if not mcp_tools:
+                    return make_response(
+                        jsonify({"success": True, "statuses": {}}), 200
+                    )
+
+                oauth_server_urls: dict = {}
+                statuses: dict = {}
+                for tool in mcp_tools:
+                    tool_id = str(tool["id"])
+                    config = tool.get("config") or {}
+                    auth_type = config.get("auth_type", "none")
+                    if auth_type == "oauth":
+                        server_url = config.get("server_url", "")
+                        if server_url:
+                            parsed = urlparse(server_url)
+                            base_url = f"{parsed.scheme}://{parsed.netloc}"
+                            oauth_server_urls[tool_id] = base_url
+                        else:
+                            statuses[tool_id] = "needs_auth"
+                    else:
+                        statuses[tool_id] = "configured"
+
+                if oauth_server_urls:
+                    # Look up a session per distinct base URL. MCP sessions
+                    # are stored with ``provider = "mcp:<server_url>"``
+                    # and the URL in ``server_url``; reuse the repo's
+                    # per-URL accessor rather than an ad-hoc $in query.
+                    url_has_tokens: dict = {}
+                    for base_url in set(oauth_server_urls.values()):
+                        session = sessions_repo.get_by_user_and_server_url(
+                            user, base_url,
+                        )
+                        tokens = (
+                            (session or {}).get("session_data", {}) or {}
+                        ).get("tokens", {}) or {}
+                        # MCP code also stashes tokens into token_info on
+                        # the row; consider either present as "connected".
+                        token_info = (session or {}).get("token_info") or {}
+                        url_has_tokens[base_url] = bool(
+                            tokens.get("access_token")
+                            or token_info.get("access_token")
+                        )
+
+                    for tool_id, base_url in oauth_server_urls.items():
+                        if url_has_tokens.get(base_url):
+                            statuses[tool_id] = "connected"
+                        else:
+                            statuses[tool_id] = "needs_auth"
+
+            return make_response(jsonify({"success": True, "statuses": statuses}), 200)
+        except Exception as e:
+            current_app.logger.error(
+                "Error checking MCP auth status: %s", e, exc_info=True
+            )
+            return make_response(
+                jsonify({"success": False, "error": "Failed to check auth status"}),
+                500,
+            )
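
Note: the new batch endpoint above lets the frontend poll one route instead of testing each MCP server. An illustrative response from GET /api/mcp_server/auth_status; the status values come from the handler above, the tool ids are made up:

example_response = {
    "success": True,
    "statuses": {
        "7f0c2f9a-...": "configured",   # non-OAuth auth types
        "b41d8e10-...": "connected",    # OAuth with a stored access token
        "90aa31c7-...": "needs_auth",   # OAuth but no token yet
    },
}
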
@@ -1,19 +1,167 @@
 """Tool management routes."""

-from bson.objectid import ObjectId
 from flask import current_app, jsonify, make_response, request
 from flask_restx import fields, Namespace, Resource

+from application.agents.tools.spec_parser import parse_spec
 from application.agents.tools.tool_manager import ToolManager
 from application.api import api
-from application.api.user.base import user_tools_collection
+from application.core.url_validation import SSRFError, validate_url
 from application.security.encryption import decrypt_credentials, encrypt_credentials
+from application.storage.db.repositories.notes import NotesRepository
+from application.storage.db.repositories.todos import TodosRepository
+from application.storage.db.repositories.user_tools import UserToolsRepository
+from application.storage.db.session import db_readonly, db_session
 from application.utils import check_required_fields, validate_function_name

 tool_config = {}
 tool_manager = ToolManager(config=tool_config)
+
+
+# ---------------------------------------------------------------------------
+# Shape translation helpers
+# ---------------------------------------------------------------------------
+# The frontend speaks camelCase (``displayName`` / ``customName`` /
+# ``configRequirements``). The PG ``user_tools`` table stores snake_case
+# (``display_name`` / ``custom_name`` / ``config_requirements``). Keep the
+# translation localized to this module so repositories stay pure.
+
+_CAMEL_TO_SNAKE = {
+    "displayName": "display_name",
+    "customName": "custom_name",
+    "configRequirements": "config_requirements",
+}
+_SNAKE_TO_CAMEL = {v: k for k, v in _CAMEL_TO_SNAKE.items()}
+
+
+def _row_to_api(row: dict) -> dict:
+    """Rename DB-native snake_case keys to the camelCase shape the frontend expects."""
+    out = dict(row)
+    for snake, camel in _SNAKE_TO_CAMEL.items():
+        if snake in out:
+            out[camel] = out.pop(snake)
+    # ``user_id`` is exposed as ``user`` in the legacy API shape.
+    if "user_id" in out:
+        out["user"] = out.pop("user_id")
+    return out
+
+
+def _api_to_update_fields(data: dict) -> dict:
+    """Rename incoming camelCase update keys to the repo's snake_case columns."""
+    fields_out: dict = {}
+    for key, value in data.items():
+        fields_out[_CAMEL_TO_SNAKE.get(key, key)] = value
+    return fields_out
+
+
+def _encrypt_secret_fields(config, config_requirements, user_id):
+    secret_keys = [
+        key for key, spec in config_requirements.items()
+        if spec.get("secret") and key in config and config[key]
+    ]
+    if not secret_keys:
+        return config
+
+    storage_config = config.copy()
+    secret_values = {k: config[k] for k in secret_keys}
+    storage_config["encrypted_credentials"] = encrypt_credentials(secret_values, user_id)
+    for key in secret_keys:
+        storage_config.pop(key, None)
+    return storage_config
+
+
+def _validate_config(config, config_requirements, has_existing_secrets=False):
+    errors = {}
+    for key, spec in config_requirements.items():
+        depends_on = spec.get("depends_on")
+        if depends_on:
+            if not all(config.get(dk) == dv for dk, dv in depends_on.items()):
+                continue
+        if spec.get("required") and not config.get(key):
+            if has_existing_secrets and spec.get("secret"):
+                continue
+            errors[key] = f"{spec.get('label', key)} is required"
+        value = config.get(key)
+        if value is not None and value != "":
+            if spec.get("type") == "number":
+                try:
+                    num = float(value)
+                    if key == "timeout" and (num < 1 or num > 300):
+                        errors[key] = "Timeout must be between 1 and 300"
+                except (ValueError, TypeError):
+                    errors[key] = f"{spec.get('label', key)} must be a number"
+            if spec.get("enum") and value not in spec["enum"]:
+                errors[key] = f"Invalid value for {spec.get('label', key)}"
+    return errors
+
+
+def _merge_secrets_on_update(new_config, existing_config, config_requirements, user_id):
+    """Merge incoming config with existing encrypted secrets and re-encrypt.
+
+    For updates, the client may omit unchanged secret values. This helper
+    decrypts any previously stored secrets, overlays whatever the client *did*
+    send, strips plain-text secrets from the stored config, and re-encrypts
+    the merged result.
+
+    Returns the final ``config`` dict ready for persistence.
+    """
+    secret_keys = [
+        key for key, spec in config_requirements.items()
+        if spec.get("secret")
+    ]
+
+    if not secret_keys:
+        return new_config
+
+    existing_secrets = {}
+    if "encrypted_credentials" in existing_config:
+        existing_secrets = decrypt_credentials(
+            existing_config["encrypted_credentials"], user_id
+        )
+
+    merged_secrets = existing_secrets.copy()
+    for key in secret_keys:
+        if key in new_config and new_config[key]:
+            merged_secrets[key] = new_config[key]
+
+    # Start from existing non-secret values, then overlay incoming non-secrets
+    storage_config = {
+        k: v for k, v in existing_config.items()
+        if k not in secret_keys and k != "encrypted_credentials"
+    }
+    storage_config.update(
+        {k: v for k, v in new_config.items() if k not in secret_keys}
+    )
+
+    if merged_secrets:
+        storage_config["encrypted_credentials"] = encrypt_credentials(
+            merged_secrets, user_id
+        )
+    else:
+        storage_config.pop("encrypted_credentials", None)
+
+    storage_config.pop("has_encrypted_credentials", None)
+    return storage_config
+
+
+def transform_actions(actions_metadata):
+    """Set default flags on action metadata for storage.
+
+    Marks each action as active, sets ``filled_by_llm`` and ``value`` on every
+    parameter property. Used by both the generic create_tool and MCP save routes.
+    """
+    transformed = []
+    for action in actions_metadata:
+        action["active"] = True
+        if "parameters" in action:
+            props = action["parameters"].get("properties", {})
+            for param_details in props.values():
+                param_details["filled_by_llm"] = True
+                param_details["value"] = ""
+        transformed.append(action)
+    return transformed
+
+
 tools_ns = Namespace("tools", description="Tool management operations", path="/api")

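
Note: a quick round-trip example for the shape-translation helpers above, assuming they are in scope; the row values are illustrative, not real records:

row = {"id": "2b6f...", "user_id": "local", "display_name": "API Tool",
       "custom_name": "", "config_requirements": {}, "status": True}

api_shape = _row_to_api(row)
# -> {"id": "2b6f...", "user": "local", "displayName": "API Tool",
#     "customName": "", "configRequirements": {}, "status": True}

update_fields = _api_to_update_fields({"displayName": "Renamed", "status": False})
# -> {"display_name": "Renamed", "status": False}
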
@@ -21,6 +169,8 @@ tools_ns = Namespace("tools", description="Tool management operations", path="/a
 class AvailableTools(Resource):
     @api.doc(description="Get available tools for a user")
     def get(self):
+        if not request.decoded_token:
+            return make_response(jsonify({"success": False}), 401)
         try:
             tools_metadata = []
             for tool_name, tool_instance in tool_manager.tools.items():
@@ -28,12 +178,15 @@ class AvailableTools(Resource):
                 lines = doc.split("\n", 1)
                 name = lines[0].strip()
                 description = lines[1].strip() if len(lines) > 1 else ""
+                config_req = tool_instance.get_config_requirements()
+                actions = tool_instance.get_actions_metadata()
                 tools_metadata.append(
                     {
                         "name": tool_name,
                         "displayName": name,
                         "description": description,
-                        "configRequirements": tool_instance.get_config_requirements(),
+                        "configRequirements": config_req,
+                        "actions": actions,
                     }
                 )
         except Exception as err:
@@ -53,12 +206,26 @@ class GetTools(Resource):
         if not decoded_token:
             return make_response(jsonify({"success": False}), 401)
         user = decoded_token.get("sub")
-            tools = user_tools_collection.find({"user": user})
+            with db_readonly() as conn:
+                rows = UserToolsRepository(conn).list_for_user(user)
             user_tools = []
-            for tool in tools:
-                tool_copy = {**tool}
-                tool_copy["id"] = str(tool["_id"])
-                tool_copy.pop("_id", None)
+            for row in rows:
+                tool_copy = _row_to_api(row)
+
+                config_req = tool_copy.get("configRequirements", {})
+                if not config_req:
+                    tool_instance = tool_manager.tools.get(tool_copy.get("name"))
+                    if tool_instance:
+                        config_req = tool_instance.get_config_requirements()
+                        tool_copy["configRequirements"] = config_req
+
+                has_secrets = any(
+                    spec.get("secret") for spec in config_req.values()
+                ) if config_req else False
+                if has_secrets and "encrypted_credentials" in tool_copy.get("config", {}):
+                    tool_copy["config"]["has_encrypted_credentials"] = True
+                    tool_copy["config"].pop("encrypted_credentials", None)
+
                 user_tools.append(tool_copy)
         except Exception as err:
             current_app.logger.error(f"Error getting user tools: {err}", exc_info=True)
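
Note: the masking added to /api/get_tools above replaces the ciphertext with a boolean flag before the config is returned. A plain-dict illustration with example values:

stored = {"base_url": "https://api.example.com", "encrypted_credentials": "gAAAAB..."}

returned = dict(stored)
returned["has_encrypted_credentials"] = True
returned.pop("encrypted_credentials", None)
# -> {"base_url": "https://api.example.com", "has_encrypted_credentials": True}
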
@@ -109,41 +276,61 @@ class CreateTool(Resource):
         if missing_fields:
             return missing_fields
         try:
+            if data["name"] == "mcp_tool":
+                server_url = (data.get("config", {}).get("server_url") or "").strip()
+                if server_url:
+                    try:
+                        validate_url(server_url)
+                    except SSRFError:
+                        return make_response(
+                            jsonify({"success": False, "message": "Invalid server URL"}),
+                            400,
+                        )
             tool_instance = tool_manager.tools.get(data["name"])
             if not tool_instance:
                 return make_response(
                     jsonify({"success": False, "message": "Tool not found"}), 404
                 )
             actions_metadata = tool_instance.get_actions_metadata()
-            transformed_actions = []
-            for action in actions_metadata:
-                action["active"] = True
-                if "parameters" in action:
-                    if "properties" in action["parameters"]:
-                        for param_name, param_details in action["parameters"][
-                            "properties"
-                        ].items():
-                            param_details["filled_by_llm"] = True
-                            param_details["value"] = ""
-                transformed_actions.append(action)
+            transformed_actions = transform_actions(actions_metadata)
         except Exception as err:
             current_app.logger.error(
                 f"Error getting tool actions: {err}", exc_info=True
             )
             return make_response(jsonify({"success": False}), 400)
         try:
-            new_tool = {
-                "user": user,
-                "name": data["name"],
-                "displayName": data["displayName"],
-                "description": data["description"],
-                "customName": data.get("customName", ""),
-                "actions": transformed_actions,
-                "config": data["config"],
-                "status": data["status"],
-            }
-            resp = user_tools_collection.insert_one(new_tool)
-            new_id = str(resp.inserted_id)
+            config_requirements = tool_instance.get_config_requirements()
+            if config_requirements:
+                validation_errors = _validate_config(
+                    data["config"], config_requirements
+                )
+                if validation_errors:
+                    return make_response(
+                        jsonify(
+                            {
+                                "success": False,
+                                "message": "Validation failed",
+                                "errors": validation_errors,
+                            }
+                        ),
+                        400,
+                    )
+            storage_config = _encrypt_secret_fields(
+                data["config"], config_requirements, user
+            )
+            with db_session() as conn:
+                created = UserToolsRepository(conn).create(
+                    user,
+                    data["name"],
+                    config=storage_config,
+                    custom_name=data.get("customName", ""),
+                    display_name=data["displayName"],
+                    description=data["description"],
+                    config_requirements=config_requirements,
+                    actions=transformed_actions,
+                    status=bool(data.get("status", True)),
+                )
+                new_id = str(created["id"])
         except Exception as err:
             current_app.logger.error(f"Error creating tool: {err}", exc_info=True)
             return make_response(jsonify({"success": False}), 400)
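
Note: a worked example of the validation the new create path runs against a tool's config_requirements, assuming the _validate_config helper above is in scope; the requirements spec and values are illustrative:

config_requirements = {
    "api_key": {"label": "API key", "required": True, "secret": True},
    "timeout": {"label": "Timeout", "type": "number"},
    "region": {"label": "Region", "enum": ["us", "eu"]},
}

errors = _validate_config({"timeout": "900", "region": "apac"}, config_requirements)
# -> {"api_key": "API key is required",
#     "timeout": "Timeout must be between 1 and 300",
#     "region": "Invalid value for Region"}
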
@@ -181,17 +368,10 @@ class UpdateTool(Resource):
|
|||||||
if missing_fields:
|
if missing_fields:
|
||||||
return missing_fields
|
return missing_fields
|
||||||
try:
|
try:
|
||||||
update_data = {}
|
update_data: dict = {}
|
||||||
if "name" in data:
|
for key in ("name", "displayName", "customName", "description", "actions"):
|
||||||
update_data["name"] = data["name"]
|
if key in data:
|
||||||
if "displayName" in data:
|
update_data[key] = data[key]
|
||||||
update_data["displayName"] = data["displayName"]
|
|
||||||
if "customName" in data:
|
|
||||||
update_data["customName"] = data["customName"]
|
|
||||||
if "description" in data:
|
|
||||||
update_data["description"] = data["description"]
|
|
||||||
if "actions" in data:
|
|
||||||
update_data["actions"] = data["actions"]
|
|
||||||
if "config" in data:
|
if "config" in data:
|
||||||
if "actions" in data["config"]:
|
if "actions" in data["config"]:
|
||||||
for action_name in list(data["config"]["actions"].keys()):
|
for action_name in list(data["config"]["actions"].keys()):
|
||||||
@@ -206,66 +386,61 @@ class UpdateTool(Resource):
|
|||||||
),
|
),
|
||||||
400,
|
400,
|
||||||
)
|
)
|
||||||
tool_doc = user_tools_collection.find_one(
|
with db_session() as conn:
|
||||||
{"_id": ObjectId(data["id"]), "user": user}
|
repo = UserToolsRepository(conn)
|
||||||
)
|
tool_doc = repo.get_any(data["id"], user)
|
||||||
if tool_doc and tool_doc.get("name") == "mcp_tool":
|
if not tool_doc:
|
||||||
config = data["config"]
|
return make_response(
|
||||||
existing_config = tool_doc.get("config", {})
|
jsonify({"success": False, "message": "Tool not found"}),
|
||||||
storage_config = existing_config.copy()
|
404,
|
||||||
|
)
|
||||||
|
tool_name = tool_doc.get("name", data.get("name"))
|
||||||
|
tool_instance = tool_manager.tools.get(tool_name)
|
||||||
|
config_requirements = (
|
||||||
|
tool_instance.get_config_requirements()
|
||||||
|
if tool_instance
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
existing_config = tool_doc.get("config", {}) or {}
|
||||||
|
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||||
|
|
||||||
storage_config.update(config)
|
if config_requirements:
|
||||||
existing_credentials = {}
|
validation_errors = _validate_config(
|
||||||
if "encrypted_credentials" in existing_config:
|
data["config"], config_requirements,
|
||||||
existing_credentials = decrypt_credentials(
|
has_existing_secrets=has_existing_secrets,
|
||||||
existing_config["encrypted_credentials"], user
|
|
||||||
)
|
)
|
||||||
auth_credentials = existing_credentials.copy()
|
if validation_errors:
|
||||||
auth_type = storage_config.get("auth_type", "none")
|
return make_response(
|
||||||
if auth_type == "api_key":
|
jsonify({
|
||||||
if "api_key" in config and config["api_key"]:
|
"success": False,
|
||||||
auth_credentials["api_key"] = config["api_key"]
|
"message": "Validation failed",
|
||||||
if "api_key_header" in config:
|
"errors": validation_errors,
|
||||||
auth_credentials["api_key_header"] = config[
|
}),
|
||||||
"api_key_header"
|
400,
|
||||||
]
|
)
|
||||||
elif auth_type == "bearer":
|
|
||||||
if "bearer_token" in config and config["bearer_token"]:
|
update_data["config"] = _merge_secrets_on_update(
|
||||||
auth_credentials["bearer_token"] = config["bearer_token"]
|
data["config"], existing_config, config_requirements, user
|
||||||
elif "encrypted_token" in config and config["encrypted_token"]:
|
)
|
||||||
auth_credentials["bearer_token"] = config["encrypted_token"]
|
if "status" in data:
|
||||||
elif auth_type == "basic":
|
update_data["status"] = bool(data["status"])
|
||||||
if "username" in config and config["username"]:
|
repo.update(
|
||||||
auth_credentials["username"] = config["username"]
|
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
|
||||||
if "password" in config and config["password"]:
|
)
|
||||||
auth_credentials["password"] = config["password"]
|
else:
|
||||||
if auth_type != "none" and auth_credentials:
|
if "status" in data:
|
||||||
encrypted_credentials_string = encrypt_credentials(
|
update_data["status"] = bool(data["status"])
|
||||||
auth_credentials, user
|
with db_session() as conn:
|
||||||
|
repo = UserToolsRepository(conn)
|
||||||
|
tool_doc = repo.get_any(data["id"], user)
|
||||||
|
if not tool_doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Tool not found"}),
|
||||||
|
404,
|
||||||
)
|
)
|
||||||
storage_config["encrypted_credentials"] = (
|
repo.update(
|
||||||
encrypted_credentials_string
|
str(tool_doc["id"]), user, _api_to_update_fields(update_data),
|
||||||
)
|
)
|
||||||
elif auth_type == "none":
|
|
||||||
storage_config.pop("encrypted_credentials", None)
|
|
||||||
for field in [
|
|
||||||
"api_key",
|
|
||||||
"bearer_token",
|
|
||||||
"encrypted_token",
|
|
||||||
"username",
|
|
||||||
"password",
|
|
||||||
"api_key_header",
|
|
||||||
]:
|
|
||||||
storage_config.pop(field, None)
|
|
||||||
update_data["config"] = storage_config
|
|
||||||
else:
|
|
||||||
update_data["config"] = data["config"]
|
|
||||||
if "status" in data:
|
|
||||||
update_data["status"] = data["status"]
|
|
||||||
user_tools_collection.update_one(
|
|
||||||
{"_id": ObjectId(data["id"]), "user": user},
|
|
||||||
{"$set": update_data},
|
|
||||||
)
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
current_app.logger.error(f"Error updating tool: {err}", exc_info=True)
|
||||||
return make_response(jsonify({"success": False}), 400)
|
return make_response(jsonify({"success": False}), 400)
|
||||||
@@ -297,10 +472,50 @@ class UpdateToolConfig(Resource):
|
|||||||
if missing_fields:
|
if missing_fields:
|
||||||
return missing_fields
|
return missing_fields
|
||||||
try:
|
try:
|
||||||
user_tools_collection.update_one(
|
with db_session() as conn:
|
||||||
{"_id": ObjectId(data["id"]), "user": user},
|
repo = UserToolsRepository(conn)
|
||||||
{"$set": {"config": data["config"]}},
|
tool_doc = repo.get_any(data["id"], user)
|
||||||
)
|
if not tool_doc:
|
||||||
|
return make_response(jsonify({"success": False}), 404)
|
||||||
|
|
||||||
|
tool_name = tool_doc.get("name")
|
||||||
|
if tool_name == "mcp_tool":
|
||||||
|
server_url = (data["config"].get("server_url") or "").strip()
|
||||||
|
if server_url:
|
||||||
|
try:
|
||||||
|
validate_url(server_url)
|
||||||
|
except SSRFError:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Invalid server URL"}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
tool_instance = tool_manager.tools.get(tool_name)
|
||||||
|
config_requirements = (
|
||||||
|
tool_instance.get_config_requirements() if tool_instance else {}
|
||||||
|
)
|
||||||
|
existing_config = tool_doc.get("config", {}) or {}
|
||||||
|
has_existing_secrets = "encrypted_credentials" in existing_config
|
||||||
|
|
||||||
|
if config_requirements:
|
||||||
|
validation_errors = _validate_config(
|
||||||
|
data["config"], config_requirements,
|
||||||
|
has_existing_secrets=has_existing_secrets,
|
||||||
|
)
|
||||||
|
if validation_errors:
|
||||||
|
return make_response(
|
||||||
|
jsonify({
|
||||||
|
"success": False,
|
||||||
|
"message": "Validation failed",
|
||||||
|
"errors": validation_errors,
|
||||||
|
}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
final_config = _merge_secrets_on_update(
|
||||||
|
data["config"], existing_config, config_requirements, user
|
||||||
|
)
|
||||||
|
|
||||||
|
repo.update(str(tool_doc["id"]), user, {"config": final_config})
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error updating tool config: {err}", exc_info=True
|
f"Error updating tool config: {err}", exc_info=True
|
||||||
@@ -336,10 +551,17 @@ class UpdateToolActions(Resource):
|
|||||||
if missing_fields:
|
if missing_fields:
|
||||||
return missing_fields
|
return missing_fields
|
||||||
try:
|
try:
|
||||||
user_tools_collection.update_one(
|
with db_session() as conn:
|
||||||
{"_id": ObjectId(data["id"]), "user": user},
|
repo = UserToolsRepository(conn)
|
||||||
{"$set": {"actions": data["actions"]}},
|
tool_doc = repo.get_any(data["id"], user)
|
||||||
)
|
if not tool_doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Tool not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
repo.update(
|
||||||
|
str(tool_doc["id"]), user, {"actions": data["actions"]},
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error updating tool actions: {err}", exc_info=True
|
f"Error updating tool actions: {err}", exc_info=True
|
||||||
@@ -373,10 +595,17 @@ class UpdateToolStatus(Resource):
|
|||||||
if missing_fields:
|
if missing_fields:
|
||||||
return missing_fields
|
return missing_fields
|
||||||
try:
|
try:
|
||||||
user_tools_collection.update_one(
|
with db_session() as conn:
|
||||||
{"_id": ObjectId(data["id"]), "user": user},
|
repo = UserToolsRepository(conn)
|
||||||
{"$set": {"status": data["status"]}},
|
tool_doc = repo.get_any(data["id"], user)
|
||||||
)
|
if not tool_doc:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"success": False, "message": "Tool not found"}),
|
||||||
|
404,
|
||||||
|
)
|
||||||
|
repo.update(
|
||||||
|
str(tool_doc["id"]), user, {"status": bool(data["status"])},
|
||||||
|
)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
current_app.logger.error(
|
current_app.logger.error(
|
||||||
f"Error updating tool status: {err}", exc_info=True
|
f"Error updating tool status: {err}", exc_info=True
|
||||||
@@ -405,12 +634,165 @@ class DeleteTool(Resource):
         if missing_fields:
             return missing_fields
         try:
-            result = user_tools_collection.delete_one(
-                {"_id": ObjectId(data["id"]), "user": user}
-            )
-            if result.deleted_count == 0:
-                return {"success": False, "message": "Tool not found"}, 404
+            with db_session() as conn:
+                repo = UserToolsRepository(conn)
+                tool_doc = repo.get_any(data["id"], user)
+                if not tool_doc:
+                    return make_response(
+                        jsonify({"success": False, "message": "Tool not found"}), 404
+                    )
+                repo.delete(str(tool_doc["id"]), user)
         except Exception as err:
             current_app.logger.error(f"Error deleting tool: {err}", exc_info=True)
-            return {"success": False}, 400
-        return {"success": True}, 200
+            return make_response(jsonify({"success": False}), 400)
+        return make_response(jsonify({"success": True}), 200)
+
+
+@tools_ns.route("/parse_spec")
+class ParseSpec(Resource):
+    @api.doc(
+        description="Parse an API specification (OpenAPI 3.x or Swagger 2.0) and return actions"
+    )
+    def post(self):
+        decoded_token = request.decoded_token
+        if not decoded_token:
+            return make_response(jsonify({"success": False}), 401)
+        if "file" in request.files:
+            file = request.files["file"]
+            if not file.filename:
+                return make_response(
+                    jsonify({"success": False, "message": "No file selected"}), 400
+                )
+            try:
+                spec_content = file.read().decode("utf-8")
+            except UnicodeDecodeError:
+                return make_response(
+                    jsonify({"success": False, "message": "Invalid file encoding"}), 400
+                )
+        elif request.is_json:
+            data = request.get_json()
+            spec_content = data.get("spec_content", "")
+        else:
+            return make_response(
+                jsonify({"success": False, "message": "No spec provided"}), 400
+            )
+        if not spec_content or not spec_content.strip():
+            return make_response(
+                jsonify({"success": False, "message": "Empty spec content"}), 400
+            )
+        try:
+            metadata, actions = parse_spec(spec_content)
+            return make_response(
+                jsonify(
+                    {
+                        "success": True,
+                        "metadata": metadata,
+                        "actions": actions,
+                    }
+                ),
+                200,
+            )
+        except ValueError as e:
+            current_app.logger.error(f"Spec validation error: {e}")
+            return make_response(jsonify({"success": False, "error": "Invalid specification format"}), 400)
+        except Exception as err:
+            current_app.logger.error(f"Error parsing spec: {err}", exc_info=True)
+            return make_response(jsonify({"success": False, "error": "Failed to parse specification"}), 500)
+
+
+@tools_ns.route("/artifact/<artifact_id>")
+class GetArtifact(Resource):
+    @api.doc(description="Get artifact data by artifact ID. Returns all todos for the tool when fetching a todo artifact.")
+    def get(self, artifact_id: str):
+        decoded_token = request.decoded_token
+        if not decoded_token:
+            return make_response(jsonify({"success": False}), 401)
+        user_id = decoded_token.get("sub")
+
+        try:
+            with db_readonly() as conn:
+                notes_repo = NotesRepository(conn)
+                todos_repo = TodosRepository(conn)
+
+                # Artifact IDs may be PG UUIDs (post-cutover) or legacy
+                # Mongo ObjectIds embedded in older conversation history.
+                # Both repos' ``get_any`` handles the id-shape branching
+                # internally so a non-UUID input never reaches
+                # ``CAST(:id AS uuid)`` (which would poison the readonly
+                # transaction and break the fallback below).
+                note_doc = notes_repo.get_any(artifact_id, user_id)
+
+                if note_doc:
+                    content = note_doc.get("note", "") or note_doc.get("content", "")
+                    line_count = len(content.split("\n")) if content else 0
+                    updated = note_doc.get("updated_at")
+                    artifact = {
+                        "artifact_type": "note",
+                        "data": {
+                            "content": content,
+                            "line_count": line_count,
+                            "updated_at": (
+                                updated.isoformat()
+                                if hasattr(updated, "isoformat")
+                                else updated
+                            ),
+                        },
+                    }
+                    return make_response(
+                        jsonify({"success": True, "artifact": artifact}), 200
+                    )
+
+                todo_doc = todos_repo.get_any(artifact_id, user_id)
+                if todo_doc:
+                    tool_id = todo_doc.get("tool_id")
+                    all_todos = todos_repo.list_for_tool(user_id, tool_id) if tool_id else []
+                    items = []
+                    open_count = 0
+                    completed_count = 0
+                    for t in all_todos:
+                        # PG ``todos`` stores a ``completed BOOLEAN`` column;
+                        # the legacy Mongo shape used a ``status`` string.
+                        # Keep the response shape stable by translating here.
+                        status = "completed" if t.get("completed") else "open"
+                        if status == "open":
+                            open_count += 1
+                        else:
+                            completed_count += 1
+                        created = t.get("created_at")
+                        updated = t.get("updated_at")
+                        items.append({
+                            "todo_id": t.get("todo_id"),
+                            "title": t.get("title", ""),
+                            "status": status,
+                            "created_at": (
+                                created.isoformat()
+                                if hasattr(created, "isoformat")
+                                else created
+                            ),
+                            "updated_at": (
+                                updated.isoformat()
+                                if hasattr(updated, "isoformat")
+                                else updated
+                            ),
+                        })
+                    artifact = {
+                        "artifact_type": "todo_list",
+                        "data": {
+                            "items": items,
+                            "total_count": len(items),
+                            "open_count": open_count,
+                            "completed_count": completed_count,
+                        },
+                    }
+                    return make_response(
+                        jsonify({"success": True, "artifact": artifact}), 200
+                    )
+        except Exception as err:
+            current_app.logger.error(
+                f"Error retrieving artifact: {err}", exc_info=True
+            )
+            return make_response(jsonify({"success": False}), 400)
+
+        return make_response(
+            jsonify({"success": False, "message": "Artifact not found"}), 404
+        )
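For illustration, a successful lookup of a todo artifact through the endpoint above returns a payload shaped roughly like the sketch below; field values are placeholders, and the /api prefix assumes the tools namespace keeps its existing mount path.

# Hypothetical response body for GET /api/artifact/<artifact_id> when the id
# resolves to a todo artifact (values are illustrative, not from the diff).
example_todo_artifact = {
    "success": True,
    "artifact": {
        "artifact_type": "todo_list",
        "data": {
            "items": [
                {
                    "todo_id": "t-1",
                    "title": "Write docs",
                    "status": "open",
                    "created_at": "2025-01-01T00:00:00",
                    "updated_at": "2025-01-01T00:00:00",
                },
            ],
            "total_count": 1,
            "open_count": 1,
            "completed_count": 0,
        },
    },
}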
application/api/user/utils.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Centralized utilities for API routes.

Post-Mongo-cutover slim: the old Mongo-shaped helpers (``validate_object_id``,
``check_resource_ownership``, ``paginated_response``, ``serialize_object_id``,
``safe_db_operation``, ``validate_enum``, ``extract_sort_params``) have been
removed — they carried ``bson`` / ``pymongo`` imports and had zero callers.
"""

from functools import wraps
from typing import Callable, Optional

from flask import (
    Response,
    jsonify,
    make_response,
    request,
)


def get_user_id() -> Optional[str]:
    """Extract user ID from decoded JWT token, or None if unauthenticated."""
    decoded_token = getattr(request, "decoded_token", None)
    return decoded_token.get("sub") if decoded_token else None


def require_auth(func: Callable) -> Callable:
    """Decorator to require authentication. Returns 401 when absent."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        user_id = get_user_id()
        if not user_id:
            return make_response(jsonify({"success": False, "error": "Unauthorized"}), 401)
        return func(*args, **kwargs)

    return wrapper


def success_response(
    data=None, message: Optional[str] = None, status: int = 200
) -> Response:
    """Shape a successful JSON response."""
    body = {"success": True}
    if data is not None:
        body["data"] = data
    if message is not None:
        body["message"] = message
    return make_response(jsonify(body), status)


def error_response(message: str, status: int = 400, **kwargs) -> Response:
    """Shape an error JSON response; any kwargs are merged into the body."""
    body = {"success": False, "error": message, **kwargs}
    return make_response(jsonify(body), status)


def require_fields(required: list) -> Callable:
    """Decorator: return 400 if any listed field is missing/falsy in the JSON body."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            data = request.get_json()
            if not data:
                return error_response("Request body required")
            missing = [field for field in required if not data.get(field)]
            if missing:
                return error_response(
                    f"Missing required fields: {', '.join(missing)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
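To show how these helpers compose, here is a minimal sketch of a route written against them; the namespace, path, and persistence step are placeholders rather than code from this change.

# Hypothetical usage sketch of the helpers above (not part of this change).
from flask import request
from flask_restx import Namespace, Resource

from application.api.user.utils import (
    get_user_id,
    require_auth,
    require_fields,
    success_response,
)

example_ns = Namespace("example", path="/api")  # placeholder namespace


@example_ns.route("/example/rename")
class RenameExample(Resource):
    @require_auth                      # returns 401 before the body runs if unauthenticated
    @require_fields(["id", "name"])    # returns 400 listing any missing or empty field
    def post(self):
        data = request.get_json()
        user_id = get_user_id()        # non-None here thanks to @require_auth
        # Persist the rename for (user_id, data["id"]) with the appropriate repository.
        return success_response({"id": data["id"], "name": data["name"]})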
application/api/user/workflows/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .routes import workflows_ns

__all__ = ["workflows_ns"]
application/api/user/workflows/routes.py (new file, 509 lines)
@@ -0,0 +1,509 @@
|
|||||||
|
"""Workflow management routes."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from flask import current_app, request
|
||||||
|
from flask_restx import Namespace, Resource
|
||||||
|
|
||||||
|
from application.storage.db.base_repository import looks_like_uuid
|
||||||
|
from application.storage.db.repositories.workflow_edges import WorkflowEdgesRepository
|
||||||
|
from application.storage.db.repositories.workflow_nodes import WorkflowNodesRepository
|
||||||
|
from application.storage.db.repositories.workflows import WorkflowsRepository
|
||||||
|
from application.storage.db.session import db_readonly, db_session
|
||||||
|
from application.core.json_schema_utils import (
|
||||||
|
JsonSchemaValidationError,
|
||||||
|
normalize_json_schema_payload,
|
||||||
|
)
|
||||||
|
from application.core.model_utils import get_model_capabilities
|
||||||
|
from application.api.user.utils import (
|
||||||
|
error_response,
|
||||||
|
get_user_id,
|
||||||
|
require_auth,
|
||||||
|
require_fields,
|
||||||
|
success_response,
|
||||||
|
)
|
||||||
|
|
||||||
|
workflows_ns = Namespace("workflows", path="/api")
|
||||||
|
|
||||||
|
|
||||||
|
def _workflow_error_response(message: str, err: Exception):
|
||||||
|
current_app.logger.error(f"{message}: {err}", exc_info=True)
|
||||||
|
return error_response(message)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_workflow(repo: WorkflowsRepository, workflow_id: str, user_id: str):
|
||||||
|
"""Resolve a workflow by UUID or legacy Mongo id, scoped to user."""
|
||||||
|
if not workflow_id:
|
||||||
|
return None
|
||||||
|
if looks_like_uuid(workflow_id):
|
||||||
|
row = repo.get(workflow_id, user_id)
|
||||||
|
if row is not None:
|
||||||
|
return row
|
||||||
|
return repo.get_by_legacy_id(workflow_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_graph(
|
||||||
|
conn,
|
||||||
|
pg_workflow_id: str,
|
||||||
|
graph_version: int,
|
||||||
|
nodes_data: List[Dict],
|
||||||
|
edges_data: List[Dict],
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""Bulk-create nodes + edges for one graph version. Uses ON CONFLICT upsert.
|
||||||
|
|
||||||
|
Edges arrive with source/target as user-provided node-id strings. We
|
||||||
|
insert nodes first, capture their ``node_id → UUID`` map, then
|
||||||
|
translate edges before insertion. Edges referencing missing nodes are
|
||||||
|
dropped with a warning.
|
||||||
|
"""
|
||||||
|
nodes_repo = WorkflowNodesRepository(conn)
|
||||||
|
edges_repo = WorkflowEdgesRepository(conn)
|
||||||
|
|
||||||
|
if nodes_data:
|
||||||
|
created_nodes = nodes_repo.bulk_create(
|
||||||
|
pg_workflow_id, graph_version,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"node_id": n["id"],
|
||||||
|
"node_type": n["type"],
|
||||||
|
"title": n.get("title", ""),
|
||||||
|
"description": n.get("description", ""),
|
||||||
|
"position": n.get("position", {"x": 0, "y": 0}),
|
||||||
|
"config": n.get("data", {}),
|
||||||
|
}
|
||||||
|
for n in nodes_data
|
||||||
|
],
|
||||||
|
)
|
||||||
|
node_uuid_by_str = {n["node_id"]: n["id"] for n in created_nodes}
|
||||||
|
else:
|
||||||
|
created_nodes = []
|
||||||
|
node_uuid_by_str = {}
|
||||||
|
|
||||||
|
if edges_data:
|
||||||
|
translated_edges: List[Dict] = []
|
||||||
|
for e in edges_data:
|
||||||
|
src = e.get("source")
|
||||||
|
tgt = e.get("target")
|
||||||
|
from_uuid = node_uuid_by_str.get(src)
|
||||||
|
to_uuid = node_uuid_by_str.get(tgt)
|
||||||
|
if not from_uuid or not to_uuid:
|
||||||
|
current_app.logger.warning(
|
||||||
|
"Workflow graph write: dropping edge %s; node refs unresolved "
|
||||||
|
"(source=%s, target=%s)",
|
||||||
|
e.get("id"), src, tgt,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
translated_edges.append({
|
||||||
|
"edge_id": e["id"],
|
||||||
|
"from_node_id": from_uuid,
|
||||||
|
"to_node_id": to_uuid,
|
||||||
|
"source_handle": e.get("sourceHandle"),
|
||||||
|
"target_handle": e.get("targetHandle"),
|
||||||
|
})
|
||||||
|
if translated_edges:
|
||||||
|
edges_repo.bulk_create(
|
||||||
|
pg_workflow_id, graph_version, translated_edges,
|
||||||
|
)
|
||||||
|
|
||||||
|
return created_nodes
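As a rough sketch of the shapes _write_graph consumes; node and edge ids, titles, and the prompt string here are arbitrary placeholders, not values from this change.

# Illustrative input for _write_graph.
nodes_data = [
    {"id": "start-1", "type": "start", "title": "Start",
     "position": {"x": 0, "y": 0}, "data": {}},
    {"id": "agent-1", "type": "agent", "title": "Answer",
     "position": {"x": 200, "y": 0}, "data": {"prompt": "Summarize the input"}},
]
edges_data = [
    {"id": "e1", "source": "start-1", "target": "agent-1",
     "sourceHandle": None, "targetHandle": None},
]
# Within a db_session() block, once WorkflowsRepository.create(...) has yielded
# pg_workflow_id, version 1 of the graph would be written with:
#     _write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
# Nodes are inserted first; the returned node_id -> UUID map rewrites each
# edge's source/target, and edges whose endpoints cannot be resolved are
# logged and dropped.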
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_workflow(w: Dict) -> Dict:
|
||||||
|
"""Serialize workflow row to API response format."""
|
||||||
|
created_at = w.get("created_at")
|
||||||
|
updated_at = w.get("updated_at")
|
||||||
|
return {
|
||||||
|
"id": str(w["id"]),
|
||||||
|
"name": w.get("name"),
|
||||||
|
"description": w.get("description"),
|
||||||
|
"created_at": created_at.isoformat() if hasattr(created_at, "isoformat") else created_at,
|
||||||
|
"updated_at": updated_at.isoformat() if hasattr(updated_at, "isoformat") else updated_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_node(n: Dict) -> Dict:
|
||||||
|
"""Serialize workflow node row to API response format."""
|
||||||
|
return {
|
||||||
|
"id": n["node_id"],
|
||||||
|
"type": n["node_type"],
|
||||||
|
"title": n.get("title"),
|
||||||
|
"description": n.get("description"),
|
||||||
|
"position": n.get("position"),
|
||||||
|
"data": n.get("config", {}) or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_edge(e: Dict) -> Dict:
|
||||||
|
"""Serialize workflow edge row to API response format."""
|
||||||
|
return {
|
||||||
|
"id": e["edge_id"],
|
||||||
|
"source": e.get("source_id"),
|
||||||
|
"target": e.get("target_id"),
|
||||||
|
"sourceHandle": e.get("source_handle"),
|
||||||
|
"targetHandle": e.get("target_handle"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_workflow_graph_version(workflow: Dict) -> int:
|
||||||
|
"""Get current graph version with fallback."""
|
||||||
|
raw_version = workflow.get("current_graph_version", 1)
|
||||||
|
try:
|
||||||
|
version = int(raw_version)
|
||||||
|
return version if version > 0 else 1
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def validate_json_schema_payload(
|
||||||
|
json_schema: Any,
|
||||||
|
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||||
|
"""Validate and normalize optional JSON schema payload for structured output."""
|
||||||
|
if json_schema is None:
|
||||||
|
return None, None
|
||||||
|
try:
|
||||||
|
return normalize_json_schema_payload(json_schema), None
|
||||||
|
except JsonSchemaValidationError as exc:
|
||||||
|
return None, str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_agent_node_json_schemas(nodes: List[Dict]) -> List[Dict]:
|
||||||
|
"""Normalize agent-node JSON schema payloads before persistence."""
|
||||||
|
normalized_nodes: List[Dict] = []
|
||||||
|
for node in nodes:
|
||||||
|
if not isinstance(node, dict):
|
||||||
|
normalized_nodes.append(node)
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_node = dict(node)
|
||||||
|
if normalized_node.get("type") != "agent":
|
||||||
|
normalized_nodes.append(normalized_node)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_config = normalized_node.get("data")
|
||||||
|
if not isinstance(raw_config, dict) or "json_schema" not in raw_config:
|
||||||
|
normalized_nodes.append(normalized_node)
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_config = dict(raw_config)
|
||||||
|
try:
|
||||||
|
normalized_config["json_schema"] = normalize_json_schema_payload(
|
||||||
|
raw_config.get("json_schema")
|
||||||
|
)
|
||||||
|
except JsonSchemaValidationError:
|
||||||
|
# Validation runs before normalization; keep original on unexpected shape.
|
||||||
|
normalized_config["json_schema"] = raw_config.get("json_schema")
|
||||||
|
normalized_node["data"] = normalized_config
|
||||||
|
normalized_nodes.append(normalized_node)
|
||||||
|
|
||||||
|
return normalized_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def validate_workflow_structure(nodes: List[Dict], edges: List[Dict]) -> List[str]:
|
||||||
|
"""Validate workflow graph structure."""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
if not nodes:
|
||||||
|
errors.append("Workflow must have at least one node")
|
||||||
|
return errors
|
||||||
|
|
||||||
|
start_nodes = [n for n in nodes if n.get("type") == "start"]
|
||||||
|
if len(start_nodes) != 1:
|
||||||
|
errors.append("Workflow must have exactly one start node")
|
||||||
|
|
||||||
|
end_nodes = [n for n in nodes if n.get("type") == "end"]
|
||||||
|
if not end_nodes:
|
||||||
|
errors.append("Workflow must have at least one end node")
|
||||||
|
|
||||||
|
node_ids = {n.get("id") for n in nodes}
|
||||||
|
node_map = {n.get("id"): n for n in nodes}
|
||||||
|
end_ids = {n.get("id") for n in end_nodes}
|
||||||
|
|
||||||
|
for edge in edges:
|
||||||
|
source_id = edge.get("source")
|
||||||
|
target_id = edge.get("target")
|
||||||
|
if source_id not in node_ids:
|
||||||
|
errors.append(f"Edge references non-existent source: {source_id}")
|
||||||
|
if target_id not in node_ids:
|
||||||
|
errors.append(f"Edge references non-existent target: {target_id}")
|
||||||
|
|
||||||
|
if start_nodes:
|
||||||
|
start_id = start_nodes[0].get("id")
|
||||||
|
if not any(e.get("source") == start_id for e in edges):
|
||||||
|
errors.append("Start node must have at least one outgoing edge")
|
||||||
|
|
||||||
|
condition_nodes = [n for n in nodes if n.get("type") == "condition"]
|
||||||
|
for cnode in condition_nodes:
|
||||||
|
cnode_id = cnode.get("id")
|
||||||
|
cnode_title = cnode.get("title", cnode_id)
|
||||||
|
outgoing = [e for e in edges if e.get("source") == cnode_id]
|
||||||
|
if len(outgoing) < 2:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' must have at least 2 outgoing edges"
|
||||||
|
)
|
||||||
|
node_data = cnode.get("data", {}) or {}
|
||||||
|
cases = node_data.get("cases", [])
|
||||||
|
if not isinstance(cases, list):
|
||||||
|
cases = []
|
||||||
|
if not cases or not any(
|
||||||
|
isinstance(c, dict) and str(c.get("expression", "")).strip() for c in cases
|
||||||
|
):
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' must have at least one case with an expression"
|
||||||
|
)
|
||||||
|
|
||||||
|
case_handles: Set[str] = set()
|
||||||
|
duplicate_case_handles: Set[str] = set()
|
||||||
|
for case in cases:
|
||||||
|
if not isinstance(case, dict):
|
||||||
|
continue
|
||||||
|
raw_handle = case.get("sourceHandle", "")
|
||||||
|
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||||
|
if not handle:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' has a case without a branch handle"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if handle in case_handles:
|
||||||
|
duplicate_case_handles.add(handle)
|
||||||
|
case_handles.add(handle)
|
||||||
|
|
||||||
|
for handle in duplicate_case_handles:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' has duplicate case handle '{handle}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
outgoing_by_handle: Dict[str, List[Dict]] = {}
|
||||||
|
for out_edge in outgoing:
|
||||||
|
raw_handle = out_edge.get("sourceHandle", "")
|
||||||
|
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||||
|
outgoing_by_handle.setdefault(handle, []).append(out_edge)
|
||||||
|
|
||||||
|
for handle, handle_edges in outgoing_by_handle.items():
|
||||||
|
if not handle:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' has an outgoing edge without sourceHandle"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
if handle != "else" and handle not in case_handles:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' has a connection from unknown branch '{handle}'"
|
||||||
|
)
|
||||||
|
if len(handle_edges) > 1:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' has multiple outgoing edges from branch '{handle}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "else" not in outgoing_by_handle:
|
||||||
|
errors.append(f"Condition node '{cnode_title}' must have an 'else' branch")
|
||||||
|
|
||||||
|
for case in cases:
|
||||||
|
if not isinstance(case, dict):
|
||||||
|
continue
|
||||||
|
raw_handle = case.get("sourceHandle", "")
|
||||||
|
handle = raw_handle.strip() if isinstance(raw_handle, str) else ""
|
||||||
|
if not handle:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_expression = case.get("expression", "")
|
||||||
|
has_expression = isinstance(raw_expression, str) and bool(
|
||||||
|
raw_expression.strip()
|
||||||
|
)
|
||||||
|
has_outgoing = bool(outgoing_by_handle.get(handle))
|
||||||
|
if has_expression and not has_outgoing:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' case '{handle}' has an expression but no outgoing edge"
|
||||||
|
)
|
||||||
|
if not has_expression and has_outgoing:
|
||||||
|
errors.append(
|
||||||
|
f"Condition node '{cnode_title}' case '{handle}' has an outgoing edge but no expression"
|
||||||
|
)
|
||||||
|
|
||||||
|
for handle, handle_edges in outgoing_by_handle.items():
|
||||||
|
if not handle:
|
||||||
|
continue
|
||||||
|
for out_edge in handle_edges:
|
||||||
|
target = out_edge.get("target")
|
||||||
|
if target and not _can_reach_end(target, edges, node_map, end_ids):
|
||||||
|
errors.append(
|
||||||
|
f"Branch '{handle}' of condition '{cnode_title}' "
|
||||||
|
f"must eventually reach an end node"
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_nodes = [n for n in nodes if n.get("type") == "agent"]
|
||||||
|
for agent_node in agent_nodes:
|
||||||
|
agent_title = agent_node.get("title", agent_node.get("id", "unknown"))
|
||||||
|
raw_config = agent_node.get("data", {}) or {}
|
||||||
|
if not isinstance(raw_config, dict):
|
||||||
|
errors.append(f"Agent node '{agent_title}' has invalid configuration")
|
||||||
|
continue
|
||||||
|
normalized_schema, schema_error = validate_json_schema_payload(
|
||||||
|
raw_config.get("json_schema")
|
||||||
|
)
|
||||||
|
has_json_schema = normalized_schema is not None
|
||||||
|
|
||||||
|
model_id = raw_config.get("model_id")
|
||||||
|
if has_json_schema and isinstance(model_id, str) and model_id.strip():
|
||||||
|
capabilities = get_model_capabilities(model_id.strip())
|
||||||
|
if capabilities and not capabilities.get("supports_structured_output", False):
|
||||||
|
errors.append(
|
||||||
|
f"Agent node '{agent_title}' selected model does not support structured output"
|
||||||
|
)
|
||||||
|
if schema_error:
|
||||||
|
errors.append(f"Agent node '{agent_title}' JSON schema {schema_error}")
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if not node.get("id"):
|
||||||
|
errors.append("All nodes must have an id")
|
||||||
|
if not node.get("type"):
|
||||||
|
errors.append(f"Node {node.get('id', 'unknown')} must have a type")
|
||||||
|
|
||||||
|
return errors
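A minimal graph that passes this validator, for reference (ids are placeholders):

# Smallest structure validate_workflow_structure accepts: exactly one start
# node, at least one end node, and an outgoing edge from the start node.
nodes = [
    {"id": "n1", "type": "start", "title": "Start"},
    {"id": "n2", "type": "end", "title": "End"},
]
edges = [{"id": "e1", "source": "n1", "target": "n2"}]
assert validate_workflow_structure(nodes, edges) == []
# Condition nodes add further requirements: at least two outgoing edges, a
# non-empty case expression per branch handle, a dedicated "else" branch,
# and every branch must eventually reach an end node.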
|
||||||
|
|
||||||
|
|
||||||
|
def _can_reach_end(
|
||||||
|
node_id: str, edges: List[Dict], node_map: Dict, end_ids: set, visited: set = None
|
||||||
|
) -> bool:
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
if node_id in end_ids:
|
||||||
|
return True
|
||||||
|
if node_id in visited or node_id not in node_map:
|
||||||
|
return False
|
||||||
|
visited.add(node_id)
|
||||||
|
outgoing = [e.get("target") for e in edges if e.get("source") == node_id]
|
||||||
|
return any(_can_reach_end(t, edges, node_map, end_ids, visited) for t in outgoing if t)
|
||||||
|
|
||||||
|
|
||||||
|
@workflows_ns.route("/workflows")
|
||||||
|
class WorkflowList(Resource):
|
||||||
|
|
||||||
|
@require_auth
|
||||||
|
@require_fields(["name"])
|
||||||
|
def post(self):
|
||||||
|
"""Create a new workflow with nodes and edges."""
|
||||||
|
user_id = get_user_id()
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
name = data.get("name", "").strip()
|
||||||
|
description = data.get("description", "")
|
||||||
|
nodes_data = data.get("nodes", [])
|
||||||
|
edges_data = data.get("edges", [])
|
||||||
|
|
||||||
|
validation_errors = validate_workflow_structure(nodes_data, edges_data)
|
||||||
|
if validation_errors:
|
||||||
|
return error_response(
|
||||||
|
"Workflow validation failed", errors=validation_errors
|
||||||
|
)
|
||||||
|
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = WorkflowsRepository(conn)
|
||||||
|
workflow = repo.create(user_id, name, description=description)
|
||||||
|
pg_workflow_id = str(workflow["id"])
|
||||||
|
_write_graph(conn, pg_workflow_id, 1, nodes_data, edges_data)
|
||||||
|
except Exception as err:
|
||||||
|
return _workflow_error_response("Failed to create workflow", err)
|
||||||
|
|
||||||
|
return success_response({"id": pg_workflow_id}, 201)
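A request that exercises this endpoint might look like the sketch below; the host, port, and auth header are assumptions, not part of the diff.

# Illustrative workflow creation request.
import requests

payload = {
    "name": "Support triage",
    "description": "Route questions to the right agent",
    "nodes": [
        {"id": "n1", "type": "start", "title": "Start", "position": {"x": 0, "y": 0}},
        {"id": "n2", "type": "end", "title": "End", "position": {"x": 300, "y": 0}},
    ],
    "edges": [{"id": "e1", "source": "n1", "target": "n2"}],
}
resp = requests.post(
    "http://localhost:7091/api/workflows",
    headers={"Authorization": "Bearer USER_JWT"},
    json=payload,
)
# On success the body carries {"success": true, "data": {"id": "<workflow uuid>"}}.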
|
||||||
|
|
||||||
|
|
||||||
|
@workflows_ns.route("/workflows/<string:workflow_id>")
|
||||||
|
class WorkflowDetail(Resource):
|
||||||
|
|
||||||
|
@require_auth
|
||||||
|
def get(self, workflow_id: str):
|
||||||
|
"""Get workflow details with nodes and edges."""
|
||||||
|
user_id = get_user_id()
|
||||||
|
try:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
repo = WorkflowsRepository(conn)
|
||||||
|
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||||
|
if workflow is None:
|
||||||
|
return error_response("Workflow not found", 404)
|
||||||
|
pg_workflow_id = str(workflow["id"])
|
||||||
|
graph_version = get_workflow_graph_version(workflow)
|
||||||
|
nodes = WorkflowNodesRepository(conn).find_by_version(
|
||||||
|
pg_workflow_id, graph_version,
|
||||||
|
)
|
||||||
|
edges = WorkflowEdgesRepository(conn).find_by_version(
|
||||||
|
pg_workflow_id, graph_version,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
return _workflow_error_response("Failed to fetch workflow", err)
|
||||||
|
|
||||||
|
return success_response(
|
||||||
|
{
|
||||||
|
"workflow": serialize_workflow(workflow),
|
||||||
|
"nodes": [serialize_node(n) for n in nodes],
|
||||||
|
"edges": [serialize_edge(e) for e in edges],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_auth
|
||||||
|
@require_fields(["name"])
|
||||||
|
def put(self, workflow_id: str):
|
||||||
|
"""Update workflow and replace nodes/edges."""
|
||||||
|
user_id = get_user_id()
|
||||||
|
data = request.get_json()
|
||||||
|
name = data.get("name", "").strip()
|
||||||
|
description = data.get("description", "")
|
||||||
|
nodes_data = data.get("nodes", [])
|
||||||
|
edges_data = data.get("edges", [])
|
||||||
|
|
||||||
|
validation_errors = validate_workflow_structure(nodes_data, edges_data)
|
||||||
|
if validation_errors:
|
||||||
|
return error_response(
|
||||||
|
"Workflow validation failed", errors=validation_errors
|
||||||
|
)
|
||||||
|
nodes_data = normalize_agent_node_json_schemas(nodes_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = WorkflowsRepository(conn)
|
||||||
|
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||||
|
if workflow is None:
|
||||||
|
return error_response("Workflow not found", 404)
|
||||||
|
pg_workflow_id = str(workflow["id"])
|
||||||
|
current_graph_version = get_workflow_graph_version(workflow)
|
||||||
|
next_graph_version = current_graph_version + 1
|
||||||
|
|
||||||
|
_write_graph(
|
||||||
|
conn, pg_workflow_id, next_graph_version,
|
||||||
|
nodes_data, edges_data,
|
||||||
|
)
|
||||||
|
repo.update(
|
||||||
|
pg_workflow_id, user_id,
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"current_graph_version": next_graph_version,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
WorkflowNodesRepository(conn).delete_other_versions(
|
||||||
|
pg_workflow_id, next_graph_version,
|
||||||
|
)
|
||||||
|
WorkflowEdgesRepository(conn).delete_other_versions(
|
||||||
|
pg_workflow_id, next_graph_version,
|
||||||
|
)
|
||||||
|
except Exception as err:
|
||||||
|
return _workflow_error_response("Failed to update workflow", err)
|
||||||
|
|
||||||
|
return success_response()
|
||||||
|
|
||||||
|
@require_auth
|
||||||
|
def delete(self, workflow_id: str):
|
||||||
|
"""Delete workflow and its graph."""
|
||||||
|
user_id = get_user_id()
|
||||||
|
try:
|
||||||
|
with db_session() as conn:
|
||||||
|
repo = WorkflowsRepository(conn)
|
||||||
|
workflow = _resolve_workflow(repo, workflow_id, user_id)
|
||||||
|
if workflow is None:
|
||||||
|
return error_response("Workflow not found", 404)
|
||||||
|
# ON DELETE CASCADE on workflow_nodes/edges cleans children.
|
||||||
|
repo.delete(str(workflow["id"]), user_id)
|
||||||
|
except Exception as err:
|
||||||
|
return _workflow_error_response("Failed to delete workflow", err)
|
||||||
|
|
||||||
|
return success_response()
|
||||||
application/api/v1/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from application.api.v1.routes import v1_bp

__all__ = ["v1_bp"]
application/api/v1/routes.py (new file, 331 lines)
@@ -0,0 +1,331 @@
|
|||||||
|
"""Standard chat completions API routes.
|
||||||
|
|
||||||
|
Exposes ``/v1/chat/completions`` and ``/v1/models`` endpoints that
|
||||||
|
follow the widely-adopted chat completions protocol so external tools
|
||||||
|
(opencode, continue, etc.) can connect to DocsGPT agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Any, Dict, Generator, Optional
|
||||||
|
|
||||||
|
from flask import Blueprint, jsonify, make_response, request, Response
|
||||||
|
|
||||||
|
from application.api.answer.routes.base import BaseAnswerResource
|
||||||
|
from application.api.answer.services.stream_processor import StreamProcessor
|
||||||
|
from application.api.v1.translator import (
|
||||||
|
translate_request,
|
||||||
|
translate_response,
|
||||||
|
translate_stream_event,
|
||||||
|
)
|
||||||
|
from application.storage.db.repositories.agents import AgentsRepository
|
||||||
|
from application.storage.db.session import db_readonly
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
v1_bp = Blueprint("v1", __name__, url_prefix="/v1")
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_bearer_token() -> Optional[str]:
|
||||||
|
"""Extract API key from Authorization: Bearer header."""
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
if auth.startswith("Bearer "):
|
||||||
|
return auth[7:].strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _lookup_agent(api_key: str) -> Optional[Dict]:
|
||||||
|
"""Look up the agent document for this API key."""
|
||||||
|
try:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
return AgentsRepository(conn).find_by_key(api_key)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to look up agent for API key", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_name(agent: Optional[Dict], api_key: str) -> str:
|
||||||
|
"""Return agent name for display as model name."""
|
||||||
|
if agent:
|
||||||
|
return agent.get("name", api_key)
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
class _V1AnswerHelper(BaseAnswerResource):
|
||||||
|
"""Thin wrapper to access complete_stream / process_response_stream."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@v1_bp.route("/chat/completions", methods=["POST"])
|
||||||
|
def chat_completions():
|
||||||
|
"""Handle POST /v1/chat/completions."""
|
||||||
|
api_key = _extract_bearer_token()
|
||||||
|
if not api_key:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = request.get_json()
|
||||||
|
if not data or not data.get("messages"):
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "messages field is required", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
is_stream = data.get("stream", False)
|
||||||
|
agent_doc = _lookup_agent(api_key)
|
||||||
|
model_name = _get_model_name(agent_doc, api_key)
|
||||||
|
|
||||||
|
try:
|
||||||
|
internal_data = translate_request(data, api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"/v1/chat/completions translate error: {e}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Link decoded_token to the agent's owner so continuation state,
|
||||||
|
# logs, and tool execution use the correct user identity. The PG
|
||||||
|
# ``agents`` row exposes the owner via ``user_id`` (``user`` is the
|
||||||
|
# legacy Mongo field name kept in ``row_to_dict`` only for the
|
||||||
|
# mapping ``id``/``_id``).
|
||||||
|
agent_user = (
|
||||||
|
(agent_doc.get("user_id") or agent_doc.get("user"))
|
||||||
|
if agent_doc else None
|
||||||
|
)
|
||||||
|
decoded_token = {"sub": agent_user or "api_key_user"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = StreamProcessor(internal_data, decoded_token)
|
||||||
|
|
||||||
|
if internal_data.get("tool_actions"):
|
||||||
|
# Continuation mode
|
||||||
|
conversation_id = internal_data.get("conversation_id")
|
||||||
|
if not conversation_id:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "conversation_id required for tool continuation", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
(
|
||||||
|
agent,
|
||||||
|
messages,
|
||||||
|
tools_dict,
|
||||||
|
pending_tool_calls,
|
||||||
|
tool_actions,
|
||||||
|
) = processor.resume_from_tool_actions(
|
||||||
|
internal_data["tool_actions"], conversation_id
|
||||||
|
)
|
||||||
|
continuation = {
|
||||||
|
"messages": messages,
|
||||||
|
"tools_dict": tools_dict,
|
||||||
|
"pending_tool_calls": pending_tool_calls,
|
||||||
|
"tool_actions": tool_actions,
|
||||||
|
}
|
||||||
|
question = ""
|
||||||
|
else:
|
||||||
|
# Normal mode
|
||||||
|
question = internal_data.get("question", "")
|
||||||
|
agent = processor.build_agent(question)
|
||||||
|
continuation = None
|
||||||
|
|
||||||
|
if not processor.decoded_token:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Unauthorized", "type": "auth_error"}}),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
|
helper = _V1AnswerHelper()
|
||||||
|
usage_error = helper.check_usage(processor.agent_config)
|
||||||
|
if usage_error:
|
||||||
|
return usage_error
|
||||||
|
|
||||||
|
should_save_conversation = bool(internal_data.get("save_conversation", False))
|
||||||
|
|
||||||
|
if is_stream:
|
||||||
|
return Response(
|
||||||
|
_stream_response(
|
||||||
|
helper,
|
||||||
|
question,
|
||||||
|
agent,
|
||||||
|
processor,
|
||||||
|
model_name,
|
||||||
|
continuation,
|
||||||
|
should_save_conversation,
|
||||||
|
),
|
||||||
|
mimetype="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return _non_stream_response(
|
||||||
|
helper,
|
||||||
|
question,
|
||||||
|
agent,
|
||||||
|
processor,
|
||||||
|
model_name,
|
||||||
|
continuation,
|
||||||
|
should_save_conversation,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(
|
||||||
|
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e)},
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Failed to process request", "type": "invalid_request"}}),
|
||||||
|
400,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"/v1/chat/completions error: {e} - {traceback.format_exc()}",
|
||||||
|
extra={"error": str(e)},
|
||||||
|
)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_response(
|
||||||
|
helper: _V1AnswerHelper,
|
||||||
|
question: str,
|
||||||
|
agent: Any,
|
||||||
|
processor: StreamProcessor,
|
||||||
|
model_name: str,
|
||||||
|
continuation: Optional[Dict],
|
||||||
|
should_save_conversation: bool,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
"""Generate translated SSE chunks for streaming response."""
|
||||||
|
completion_id = f"chatcmpl-{int(time.time())}"
|
||||||
|
|
||||||
|
internal_stream = helper.complete_stream(
|
||||||
|
question=question,
|
||||||
|
agent=agent,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
agent_id=processor.agent_id,
|
||||||
|
model_id=processor.model_id,
|
||||||
|
should_save_conversation=should_save_conversation,
|
||||||
|
_continuation=continuation,
|
||||||
|
)
|
||||||
|
|
||||||
|
for line in internal_stream:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
# Parse the internal SSE event
|
||||||
|
event_str = line.replace("data: ", "").strip()
|
||||||
|
try:
|
||||||
|
event_data = json.loads(event_str)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update completion_id when we get the conversation id
|
||||||
|
if event_data.get("type") == "id":
|
||||||
|
conv_id = event_data.get("id", "")
|
||||||
|
if conv_id:
|
||||||
|
completion_id = f"chatcmpl-{conv_id}"
|
||||||
|
|
||||||
|
# Translate to standard format
|
||||||
|
translated = translate_stream_event(event_data, completion_id, model_name)
|
||||||
|
for chunk in translated:
|
||||||
|
yield chunk
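For completeness, a client consuming this stream might look like the following sketch; the endpoint address and key are the same assumptions as above.

# Illustrative SSE consumption of the streaming endpoint.
import json
import requests

with requests.post(
    "http://localhost:7091/v1/chat/completions",
    headers={"Authorization": "Bearer YOUR_AGENT_API_KEY"},
    json={"messages": [{"role": "user", "content": "Hi"}], "stream": True},
    stream=True,
) as resp:
    for raw in resp.iter_lines():
        if not raw or not raw.startswith(b"data: "):
            continue
        payload = raw[len(b"data: "):]
        if payload == b"[DONE]":
            break
        chunk = json.loads(payload)
        if "choices" in chunk:
            # Standard chunks carry a delta with content / reasoning_content.
            print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
        # Chunks with a top-level "docsgpt" key carry sources, tool-call
        # progress, and the conversation id.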
|
||||||
|
|
||||||
|
|
||||||
|
def _non_stream_response(
|
||||||
|
helper: _V1AnswerHelper,
|
||||||
|
question: str,
|
||||||
|
agent: Any,
|
||||||
|
processor: StreamProcessor,
|
||||||
|
model_name: str,
|
||||||
|
continuation: Optional[Dict],
|
||||||
|
should_save_conversation: bool,
|
||||||
|
) -> Response:
|
||||||
|
"""Collect full response and return as single JSON."""
|
||||||
|
stream = helper.complete_stream(
|
||||||
|
question=question,
|
||||||
|
agent=agent,
|
||||||
|
conversation_id=processor.conversation_id,
|
||||||
|
user_api_key=processor.agent_config.get("user_api_key"),
|
||||||
|
decoded_token=processor.decoded_token,
|
||||||
|
agent_id=processor.agent_id,
|
||||||
|
model_id=processor.model_id,
|
||||||
|
should_save_conversation=should_save_conversation,
|
||||||
|
_continuation=continuation,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = helper.process_response_stream(stream)
|
||||||
|
|
||||||
|
if result["error"]:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": result["error"], "type": "server_error"}}),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
|
||||||
|
extra = result.get("extra")
|
||||||
|
pending = extra.get("pending_tool_calls") if isinstance(extra, dict) else None
|
||||||
|
|
||||||
|
response = translate_response(
|
||||||
|
conversation_id=result["conversation_id"],
|
||||||
|
answer=result["answer"] or "",
|
||||||
|
sources=result["sources"],
|
||||||
|
tool_calls=result["tool_calls"],
|
||||||
|
thought=result["thought"] or "",
|
||||||
|
model_name=model_name,
|
||||||
|
pending_tool_calls=pending,
|
||||||
|
)
|
||||||
|
return make_response(jsonify(response), 200)
|
||||||
|
|
||||||
|
|
||||||
|
@v1_bp.route("/models", methods=["GET"])
|
||||||
|
def list_models():
|
||||||
|
"""Handle GET /v1/models — return agents as models."""
|
||||||
|
api_key = _extract_bearer_token()
|
||||||
|
if not api_key:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Missing Authorization header", "type": "auth_error"}}),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with db_readonly() as conn:
|
||||||
|
agents_repo = AgentsRepository(conn)
|
||||||
|
agent = agents_repo.find_by_key(api_key)
|
||||||
|
if not agent:
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Invalid API key", "type": "auth_error"}}),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
|
created = agent.get("created_at") or agent.get("createdAt")
|
||||||
|
created_ts = (
|
||||||
|
int(created.timestamp()) if hasattr(created, "timestamp")
|
||||||
|
else int(time.time())
|
||||||
|
)
|
||||||
|
model_id = str(agent.get("id") or agent.get("_id") or "")
|
||||||
|
model = {
|
||||||
|
"id": model_id,
|
||||||
|
"object": "model",
|
||||||
|
"created": created_ts,
|
||||||
|
"owned_by": "docsgpt",
|
||||||
|
"name": agent.get("name", ""),
|
||||||
|
"description": agent.get("description", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
return make_response(
|
||||||
|
jsonify({"object": "list", "data": [model]}),
|
||||||
|
200,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"/v1/models error: {e}", exc_info=True)
|
||||||
|
return make_response(
|
||||||
|
jsonify({"error": {"message": "Internal server error", "type": "server_error"}}),
|
||||||
|
500,
|
||||||
|
)
|
||||||
application/api/v1/translator.py (new file, 433 lines)
@@ -0,0 +1,433 @@
|
|||||||
|
"""Translate between standard chat completions format and DocsGPT internals.
|
||||||
|
|
||||||
|
This module handles:
|
||||||
|
- Request translation (chat completions -> DocsGPT internal format)
|
||||||
|
- Response translation (DocsGPT response -> chat completions format)
|
||||||
|
- Streaming event translation (DocsGPT SSE -> standard SSE chunks)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
def _get_client_tool_name(tc: Dict) -> str:
|
||||||
|
"""Return the original tool name for client-facing responses.
|
||||||
|
|
||||||
|
For client-side tools the ``tool_name`` field carries the name the
|
||||||
|
client originally registered. Fall back to ``action_name`` (which
|
||||||
|
is now the clean LLM-visible name) or ``name``.
|
||||||
|
"""
|
||||||
|
return tc.get("tool_name", tc.get("action_name", tc.get("name", "")))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request translation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def is_continuation(messages: List[Dict]) -> bool:
|
||||||
|
"""Check if messages represent a tool-call continuation.
|
||||||
|
|
||||||
|
A continuation is detected when the last message(s) have ``role: "tool"``
|
||||||
|
immediately after an assistant message with ``tool_calls``.
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return False
|
||||||
|
# Walk backwards: if we see tool messages before hitting a non-tool, non-assistant message
|
||||||
|
# and there's an assistant message with tool_calls, it's a continuation.
|
||||||
|
i = len(messages) - 1
|
||||||
|
while i >= 0 and messages[i].get("role") == "tool":
|
||||||
|
i -= 1
|
||||||
|
if i < 0:
|
||||||
|
return False
|
||||||
|
return (
|
||||||
|
messages[i].get("role") == "assistant"
|
||||||
|
and bool(messages[i].get("tool_calls"))
|
||||||
|
)
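Concretely, a messages array like the sketch below is treated as a continuation; the tool name, ids, and results are placeholders.

# Illustrative continuation input for is_continuation / extract_tool_results.
messages = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"id": "call_1", "type": "function",
             "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": '{"temp_c": 18}'},
]
assert is_continuation(messages) is True
# extract_tool_results(messages) would then yield
# [{"call_id": "call_1", "result": {"temp_c": 18}}].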
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tool_results(messages: List[Dict]) -> List[Dict]:
|
||||||
|
"""Extract tool results from trailing tool messages for continuation.
|
||||||
|
|
||||||
|
Returns a list of ``tool_actions`` dicts with ``call_id`` and ``result``.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if msg.get("role") != "tool":
|
||||||
|
break
|
||||||
|
call_id = msg.get("tool_call_id", "")
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
try:
|
||||||
|
content = json.loads(content)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
results.append({"call_id": call_id, "result": content})
|
||||||
|
results.reverse()
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def extract_conversation_id(messages: List[Dict]) -> Optional[str]:
|
||||||
|
"""Try to extract conversation_id from the assistant message before tool results.
|
||||||
|
|
||||||
|
The conversation_id may be stored in a custom field on the assistant message
|
||||||
|
from a previous response cycle.
|
||||||
|
"""
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if msg.get("role") == "assistant":
|
||||||
|
# Check docsgpt extension
|
||||||
|
return msg.get("docsgpt", {}).get("conversation_id")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_system_prompt(messages: List[Dict]) -> Optional[str]:
|
||||||
|
"""Extract the first system message content from the messages array.
|
||||||
|
|
||||||
|
Returns None if no system message is present.
|
||||||
|
"""
|
||||||
|
for msg in messages:
|
||||||
|
if msg.get("role") == "system":
|
||||||
|
return msg.get("content", "")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def convert_history(messages: List[Dict]) -> List[Dict]:
|
||||||
|
"""Convert chat completions messages array to DocsGPT history format.
|
||||||
|
|
||||||
|
DocsGPT history is a list of ``{prompt, response}`` dicts.
|
||||||
|
Excludes the last user message (that becomes the ``question``).
|
||||||
|
"""
|
||||||
|
history = []
|
||||||
|
i = 0
|
||||||
|
while i < len(messages):
|
||||||
|
msg = messages[i]
|
||||||
|
if msg.get("role") == "system":
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
if msg.get("role") == "user":
|
||||||
|
# Look ahead for assistant response
|
||||||
|
if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
|
||||||
|
content = messages[i + 1].get("content") or ""
|
||||||
|
history.append({
|
||||||
|
"prompt": msg.get("content", ""),
|
||||||
|
"response": content,
|
||||||
|
})
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
# Last user message without response — skip (it's the question)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
i += 1
|
||||||
|
return history
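For example, under this conversion a short exchange collapses to a single history entry, and the trailing user message is left out because it is sent separately as the question.

# Illustrative conversion of a chat completions messages array.
msgs = [
    {"role": "system", "content": "Be terse."},
    {"role": "user", "content": "What is DocsGPT?"},
    {"role": "assistant", "content": "A documentation assistant."},
    {"role": "user", "content": "Does it support agents?"},
]
assert convert_history(msgs) == [
    {"prompt": "What is DocsGPT?", "response": "A documentation assistant."}
]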
|
||||||
|
|
||||||
|
|
||||||
|
def translate_request(
|
||||||
|
data: Dict[str, Any], api_key: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Translate a chat completions request to DocsGPT internal format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: The incoming request body.
|
||||||
|
api_key: Agent API key from the Authorization header.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict suitable for passing to ``StreamProcessor``.
|
||||||
|
"""
|
||||||
|
messages = data.get("messages", [])
|
||||||
|
|
||||||
|
# Check for continuation (tool results after assistant tool_calls)
|
||||||
|
if is_continuation(messages):
|
||||||
|
tool_actions = extract_tool_results(messages)
|
||||||
|
conversation_id = extract_conversation_id(messages)
|
||||||
|
if not conversation_id:
|
||||||
|
conversation_id = data.get("conversation_id")
|
||||||
|
result = {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"tool_actions": tool_actions,
|
||||||
|
"api_key": api_key,
|
||||||
|
}
|
||||||
|
# Carry tools forward for next iteration
|
||||||
|
if data.get("tools"):
|
||||||
|
result["client_tools"] = data["tools"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Normal request — extract question from last user message
|
||||||
|
question = ""
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if msg.get("role") == "user":
|
||||||
|
question = msg.get("content", "")
|
||||||
|
break
|
||||||
|
|
||||||
|
history = convert_history(messages)
|
||||||
|
system_prompt_override = extract_system_prompt(messages)
|
||||||
|
|
||||||
|
docsgpt = data.get("docsgpt", {})
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"question": question,
|
||||||
|
"api_key": api_key,
|
||||||
|
"history": json.dumps(history),
|
||||||
|
# Conversations are NOT persisted by default on the v1 endpoint.
|
||||||
|
# Callers opt in via ``docsgpt.save_conversation: true``.
|
||||||
|
"save_conversation": bool(docsgpt.get("save_conversation", False)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_prompt_override is not None:
|
||||||
|
result["system_prompt_override"] = system_prompt_override
|
||||||
|
|
||||||
|
# Client tools
|
||||||
|
if data.get("tools"):
|
||||||
|
result["client_tools"] = data["tools"]
|
||||||
|
|
||||||
|
# DocsGPT extensions
|
||||||
|
if docsgpt.get("attachments"):
|
||||||
|
result["attachments"] = docsgpt["attachments"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Response translation (non-streaming)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def translate_response(
|
||||||
|
conversation_id: str,
|
||||||
|
answer: str,
|
||||||
|
sources: Optional[List[Dict]],
|
||||||
|
tool_calls: Optional[List[Dict]],
|
||||||
|
thought: str,
|
||||||
|
model_name: str,
|
||||||
|
pending_tool_calls: Optional[List[Dict]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Translate DocsGPT response to chat completions format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: The DocsGPT conversation ID.
|
||||||
|
answer: The assistant's text response.
|
||||||
|
sources: RAG retrieval sources.
|
||||||
|
tool_calls: Completed tool call results.
|
||||||
|
thought: Reasoning/thinking tokens.
|
||||||
|
model_name: Model/agent identifier.
|
||||||
|
pending_tool_calls: Pending client-side tool calls (if paused).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict in the standard chat completions response format.
|
||||||
|
"""
|
||||||
|
created = int(time.time())
|
||||||
|
completion_id = f"chatcmpl-{conversation_id}" if conversation_id else f"chatcmpl-{created}"
|
||||||
|
|
||||||
|
# Build message
|
||||||
|
message: Dict[str, Any] = {"role": "assistant"}
|
||||||
|
|
||||||
|
if pending_tool_calls:
|
||||||
|
# Tool calls pending — return them for client execution
|
||||||
|
message["content"] = None
|
||||||
|
message["tool_calls"] = [
|
||||||
|
{
|
||||||
|
"id": tc.get("call_id", ""),
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": _get_client_tool_name(tc),
|
||||||
|
"arguments": (
|
||||||
|
json.dumps(tc["arguments"])
|
||||||
|
if isinstance(tc.get("arguments"), dict)
|
||||||
|
else tc.get("arguments", "{}")
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for tc in pending_tool_calls
|
||||||
|
]
|
||||||
|
finish_reason = "tool_calls"
|
||||||
|
else:
|
||||||
|
message["content"] = answer
|
||||||
|
if thought:
|
||||||
|
message["reasoning_content"] = thought
|
||||||
|
finish_reason = "stop"
|
||||||
|
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": created,
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": message,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# DocsGPT extensions
|
||||||
|
docsgpt: Dict[str, Any] = {}
|
||||||
|
if conversation_id:
|
||||||
|
docsgpt["conversation_id"] = conversation_id
|
||||||
|
if sources:
|
||||||
|
docsgpt["sources"] = sources
|
||||||
|
if tool_calls:
|
||||||
|
docsgpt["tool_calls"] = tool_calls
|
||||||
|
if docsgpt:
|
||||||
|
result["docsgpt"] = docsgpt
|
||||||
|
|
||||||
|
return result
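A translated non-streaming response, when no client tool calls are pending, comes out shaped roughly like this; the id, timestamp, and model name are placeholders.

# Illustrative output of translate_response for a plain answer.
example_completion = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "created": 1700000000,
    "model": "My Agent",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    "docsgpt": {"conversation_id": "abc123"},
}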
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Streaming event translation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunk(
|
||||||
|
completion_id: str,
|
||||||
|
model_name: str,
|
||||||
|
delta: Dict[str, Any],
|
||||||
|
finish_reason: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a single SSE chunk in the standard streaming format."""
|
||||||
|
chunk = {
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": delta,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_docsgpt_chunk(data: Dict[str, Any]) -> str:
|
||||||
|
"""Build a DocsGPT extension SSE chunk."""
|
||||||
|
return f"data: {json.dumps({'docsgpt': data})}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def translate_stream_event(
|
||||||
|
event_data: Dict[str, Any],
|
||||||
|
completion_id: str,
|
||||||
|
model_name: str,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Translate a DocsGPT SSE event dict to standard streaming chunks.
|
||||||
|
|
||||||
|
May return 0, 1, or 2 chunks per input event. For example, a completed
|
||||||
|
tool call produces both a docsgpt extension chunk and nothing on the
|
||||||
|
standard side (since server-side tool calls aren't surfaced in standard
|
||||||
|
format).

    Args:
        event_data: Parsed DocsGPT event dict.
        completion_id: The completion ID for this response.
        model_name: Model/agent identifier.

    Returns:
        List of SSE-formatted strings to send to the client.
    """
    event_type = event_data.get("type")
    chunks: List[str] = []

    if event_type == "answer":
        chunks.append(
            _make_chunk(completion_id, model_name, {"content": event_data.get("answer", "")})
        )

    elif event_type == "thought":
        chunks.append(
            _make_chunk(
                completion_id, model_name,
                {"reasoning_content": event_data.get("thought", "")},
            )
        )

    elif event_type == "source":
        chunks.append(
            _make_docsgpt_chunk({
                "type": "source",
                "sources": event_data.get("source", []),
            })
        )

    elif event_type == "tool_call":
        tc_data = event_data.get("data", {})
        status = tc_data.get("status")

        if status == "requires_client_execution":
            # Standard: stream as tool_calls delta
            args = tc_data.get("arguments", {})
            args_str = json.dumps(args) if isinstance(args, dict) else str(args)
            chunks.append(
                _make_chunk(completion_id, model_name, {
                    "tool_calls": [{
                        "index": 0,
                        "id": tc_data.get("call_id", ""),
                        "type": "function",
                        "function": {
                            "name": _get_client_tool_name(tc_data),
                            "arguments": args_str,
                        },
                    }],
                })
            )
        elif status == "awaiting_approval":
            # Extension: approval needed
            chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))
        elif status in ("completed", "pending", "error", "denied", "skipped"):
            # Extension: tool call progress
            chunks.append(_make_docsgpt_chunk({"type": "tool_call", "data": tc_data}))

    elif event_type == "tool_calls_pending":
        # Standard: finish_reason = tool_calls
        chunks.append(
            _make_chunk(completion_id, model_name, {}, finish_reason="tool_calls")
        )
        # Also emit as docsgpt extension
        chunks.append(
            _make_docsgpt_chunk({
                "type": "tool_calls_pending",
                "pending_tool_calls": event_data.get("data", {}).get("pending_tool_calls", []),
            })
        )

    elif event_type == "end":
        chunks.append(
            _make_chunk(completion_id, model_name, {}, finish_reason="stop")
        )
        chunks.append("data: [DONE]\n\n")

    elif event_type == "id":
        chunks.append(
            _make_docsgpt_chunk({
                "type": "id",
                "conversation_id": event_data.get("id", ""),
            })
        )

    elif event_type == "error":
        # Emit as standard error (non-standard but widely supported)
        error_data = {
            "error": {
                "message": event_data.get("error", "An error occurred"),
                "type": "server_error",
            }
        }
        chunks.append(f"data: {json.dumps(error_data)}\n\n")

    elif event_type == "structured_answer":
        chunks.append(
            _make_chunk(
                completion_id, model_name,
                {"content": event_data.get("answer", "")},
            )
        )

    # Skip: tool_calls (redundant), research_plan, research_progress

    return chunks
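For orientation, here is a minimal sketch (not the PR's actual helper) of what a _make_chunk-style function compatible with the calls above might emit, assuming the usual OpenAI "chat.completion.chunk" SSE framing. The real _make_chunk and _make_docsgpt_chunk helpers are presumably defined elsewhere in this changeset and are not reproduced here.

# Hypothetical sketch only; illustrates the assumed OpenAI-style SSE framing.
import json
import time


def _make_chunk_sketch(completion_id, model_name, delta, finish_reason=None):
    payload = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_name,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    # Each chunk is one SSE "data:" line terminated by a blank line.
    return f"data: {json.dumps(payload)}\n\n"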
@@ -1,3 +1,4 @@
+import logging
 import os
 import platform
 import uuid
@@ -17,8 +18,14 @@ from application.api.answer import answer  # noqa: E402
 from application.api.internal.routes import internal  # noqa: E402
 from application.api.user.routes import user  # noqa: E402
 from application.api.connector.routes import connector  # noqa: E402
+from application.api.v1 import v1_bp  # noqa: E402
 from application.celery_init import celery  # noqa: E402
 from application.core.settings import settings  # noqa: E402
+from application.storage.db.bootstrap import ensure_database_ready  # noqa: E402
+from application.stt.upload_limits import (  # noqa: E402
+    build_stt_file_size_limit_message,
+    should_reject_stt_request,
+)


 if platform.system() == "Windows":
@@ -27,11 +34,23 @@ if platform.system() == "Windows":
     pathlib.PosixPath = pathlib.WindowsPath
 dotenv.load_dotenv()

+# Self-bootstrap the user-data Postgres DB. Runs before any blueprint or
+# repository touches the engine, so the first request can't race the
+# schema being created. Gated by AUTO_CREATE_DB / AUTO_MIGRATE settings
+# (default ON for dev; disable in prod if schema is managed out-of-band).
+ensure_database_ready(
+    settings.POSTGRES_URI,
+    create_db=settings.AUTO_CREATE_DB,
+    migrate=settings.AUTO_MIGRATE,
+    logger=logging.getLogger("application.app"),
+)
+
 app = Flask(__name__)
 app.register_blueprint(user)
 app.register_blueprint(answer)
 app.register_blueprint(internal)
 app.register_blueprint(connector)
+app.register_blueprint(v1_bp)
 app.config.update(
     UPLOAD_FOLDER="inputs",
     CELERY_BROKER_URL=settings.CELERY_BROKER_URL,
@@ -68,6 +87,11 @@ def home():
     return "Welcome to DocsGPT Backend!"


+@app.route("/api/health")
+def health():
+    return jsonify({"status": "ok"})
+
+
 @app.route("/api/config")
 def get_config():
     response = {
@@ -88,10 +112,33 @@ def generate_token():
     return jsonify({"error": "Token generation not allowed in current auth mode"}), 400


+@app.before_request
+def enforce_stt_request_size_limits():
+    if request.method == "OPTIONS":
+        return None
+    if should_reject_stt_request(request.path, request.content_length):
+        return (
+            jsonify(
+                {
+                    "success": False,
+                    "message": build_stt_file_size_limit_message(),
+                }
+            ),
+            413,
+        )
+    return None
+
+
 @app.before_request
 def authenticate_request():
     if request.method == "OPTIONS":
         return "", 200
+    # OpenAI-compatible routes authenticate via opaque agent API keys in the
+    # Authorization header, which the JWT decoder below would reject. Defer
+    # auth to the route handlers (see application/api/v1/routes.py).
+    if request.path.startswith("/v1/"):
+        request.decoded_token = None
+        return None
     decoded_token = handle_auth(request)
     if not decoded_token:
         request.decoded_token = None
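The should_reject_stt_request and build_stt_file_size_limit_message helpers are not part of this excerpt; the sketch below only illustrates the kind of check the new before_request hook relies on. The route prefix and size limit are made-up placeholders, not the project's real values.

# Illustrative sketch with assumed values; the real helpers live in
# application/stt/upload_limits.py and are not shown in this diff.
MAX_STT_UPLOAD_BYTES = 25 * 1024 * 1024  # hypothetical limit


def should_reject_stt_request(path, content_length):
    # Only guard STT upload routes, and only when the declared body size
    # exceeds the limit (content_length may be None for chunked requests).
    if not path.startswith("/api/stt"):  # hypothetical route prefix
        return False
    return content_length is not None and content_length > MAX_STT_UPLOAD_BYTES


def build_stt_file_size_limit_message():
    return f"Audio file exceeds the {MAX_STT_UPLOAD_BYTES // (1024 * 1024)} MB limit."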
@@ -19,9 +19,9 @@ def handle_auth(request, data={}):
             options={"verify_exp": False},
         )
         return decoded_token
-    except Exception as e:
+    except Exception:
         return {
-            "message": f"Authentication error: {str(e)}",
+            "message": "Authentication error: invalid token",
            "error": "invalid_token",
         }
     else:
@@ -1,6 +1,6 @@
 from celery import Celery
 from application.core.settings import settings
-from celery.signals import setup_logging
+from celery.signals import setup_logging, worker_process_init


 def make_celery(app_name=__name__):
@@ -20,4 +20,24 @@ def config_loggers(*args, **kwargs):
     setup_logging()


+@worker_process_init.connect
+def _dispose_db_engine_on_fork(*args, **kwargs):
+    """Dispose the SQLAlchemy engine pool in each forked Celery worker.
+
+    SQLAlchemy connection pools are not fork-safe: file descriptors shared
+    between the parent and a forked worker will corrupt the pool. Disposing
+    on ``worker_process_init`` gives every worker its own fresh pool on
+    first use.
+
+    Imported lazily so Celery workers that don't touch Postgres (or where
+    ``POSTGRES_URI`` is unset) don't fail at startup.
+    """
+    try:
+        from application.storage.db.engine import dispose_engine
+    except Exception:
+        return
+    dispose_engine()
+
+
 celery = make_celery()
+celery.config_from_object("application.celeryconfig")
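dispose_engine itself is not shown in this diff. As a rough, assumed sketch, it likely wraps SQLAlchemy's Engine.dispose() on a module-level engine, which is all the fork-safety note above requires:

# Assumed sketch of application/storage/db/engine.dispose_engine, not the real module.
from typing import Optional

from sqlalchemy.engine import Engine

_engine: Optional[Engine] = None  # populated lazily by the real module


def dispose_engine() -> None:
    # Engine.dispose() closes all pooled connections; the next use opens fresh ones.
    if _engine is not None:
        _engine.dispose()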
@@ -6,3 +6,6 @@ result_backend = os.getenv("CELERY_RESULT_BACKEND")
 task_serializer = 'json'
 result_serializer = 'json'
 accept_content = ['json']
+
+# Autodiscover tasks
+imports = ('application.api.user.tasks',)
application/core/db_uri.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""Normalize user-supplied Postgres URIs for different drivers.

DocsGPT has two Postgres connection strings pointing at potentially
different databases:

* ``POSTGRES_URI`` feeds SQLAlchemy, which needs the
  ``postgresql+psycopg://`` dialect prefix to pick the psycopg v3 driver.
* ``PGVECTOR_CONNECTION_STRING`` feeds ``psycopg.connect()`` directly
  (via libpq) in ``application/vectorstore/pgvector.py``. libpq only
  understands ``postgres://`` and ``postgresql://`` — the SQLAlchemy
  dialect prefix is an invalid URI from its point of view.

The two fields therefore need opposite normalization so operators don't
have to know which driver a given field feeds. Each normalizer also
silently upgrades the legacy ``postgresql+psycopg2://`` prefix since
psycopg2 is no longer in the project.

This module is deliberately separate from ``application/core/settings.py``
so the Settings class stays focused on field declarations, and the
URI-rewriting logic can be unit-tested without triggering ``.env``
file loading from importing Settings.
"""

from __future__ import annotations


def _rewrite_uri_prefixes(v, rewrites):
    """Shared URI prefix rewriter used by both normalizers below.

    Strips whitespace, returns ``None`` for empty / ``"none"`` values,
    applies the first matching rewrite, and passes unrecognised input
    through so downstream consumers (SQLAlchemy, libpq) can produce
    their own error messages rather than us silently eating a
    misconfiguration.
    """
    if v is None:
        return None
    if not isinstance(v, str):
        return v
    v = v.strip()
    if not v or v.lower() == "none":
        return None
    for prefix, target in rewrites:
        if v.startswith(prefix):
            return target + v[len(prefix):]
    return v


# POSTGRES_URI feeds SQLAlchemy, which needs a ``postgresql+psycopg://``
# dialect prefix to select the psycopg v3 driver. Normalize the
# operator-friendly forms TOWARD that dialect.
_POSTGRES_URI_REWRITES = (
    ("postgresql+psycopg2://", "postgresql+psycopg://"),
    ("postgresql://", "postgresql+psycopg://"),
    ("postgres://", "postgresql+psycopg://"),
)


# PGVECTOR_CONNECTION_STRING feeds ``psycopg.connect()`` directly in
# application/vectorstore/pgvector.py — NOT SQLAlchemy. libpq only
# understands ``postgres://`` and ``postgresql://``; the SQLAlchemy
# dialect prefix is an invalid URI from libpq's point of view. Strip it
# if the operator accidentally copied their POSTGRES_URI value here.
_PGVECTOR_CONNECTION_STRING_REWRITES = (
    ("postgresql+psycopg2://", "postgresql://"),
    ("postgresql+psycopg://", "postgresql://"),
)


def normalize_postgres_uri(v):
    """Normalize a user-supplied POSTGRES_URI to the SQLAlchemy psycopg3 form.

    Accepts the forms operators naturally write (``postgres://``,
    ``postgresql://``) and rewrites them to ``postgresql+psycopg://``.
    Unknown schemes pass through unchanged so SQLAlchemy can produce its
    own dialect-not-found error.
    """
    return _rewrite_uri_prefixes(v, _POSTGRES_URI_REWRITES)


def normalize_pgvector_connection_string(v):
    """Normalize a user-supplied PGVECTOR_CONNECTION_STRING for libpq.

    Strips the SQLAlchemy dialect prefix if the operator accidentally
    copied their POSTGRES_URI value here — libpq can't parse it.
    User-friendly forms (``postgres://``, ``postgresql://``) pass
    through unchanged since libpq accepts them natively.
    """
    return _rewrite_uri_prefixes(v, _PGVECTOR_CONNECTION_STRING_REWRITES)
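A short usage sketch of the two normalizers above; the connection values are made up, but the rewrites follow directly from the tables in the module:

from application.core.db_uri import (
    normalize_pgvector_connection_string,
    normalize_postgres_uri,
)

# SQLAlchemy side: operator-friendly scheme gains the psycopg v3 dialect prefix.
assert (
    normalize_postgres_uri("postgres://docsgpt:pwd@db:5432/docsgpt")
    == "postgresql+psycopg://docsgpt:pwd@db:5432/docsgpt"
)
# Empty / "none" values collapse to None.
assert normalize_postgres_uri(" none ") is None
# libpq side: an accidentally copied SQLAlchemy URI loses the dialect prefix.
assert (
    normalize_pgvector_connection_string("postgresql+psycopg://docsgpt:pwd@db:5432/docsgpt")
    == "postgresql://docsgpt:pwd@db:5432/docsgpt"
)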
application/core/json_schema_utils.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from typing import Any, Dict, Optional


class JsonSchemaValidationError(ValueError):
    """Raised when a JSON schema payload is invalid."""


def normalize_json_schema_payload(json_schema: Any) -> Optional[Dict[str, Any]]:
    """
    Normalize accepted JSON schema payload shapes to a plain schema object.

    Accepted inputs:
    - None
    - A raw schema object with a top-level "type"
    - A wrapped payload with a top-level "schema" object
    """
    if json_schema is None:
        return None

    if not isinstance(json_schema, dict):
        raise JsonSchemaValidationError("must be a valid JSON object")

    wrapped_schema = json_schema.get("schema")
    if wrapped_schema is not None:
        if not isinstance(wrapped_schema, dict):
            raise JsonSchemaValidationError('field "schema" must be a valid JSON object')
        return wrapped_schema

    if "type" not in json_schema:
        raise JsonSchemaValidationError(
            'must include either a "type" or "schema" field'
        )

    return json_schema
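The accepted payload shapes are easiest to see with a small usage example; the schema content is illustrative:

from application.core.json_schema_utils import (
    JsonSchemaValidationError,
    normalize_json_schema_payload,
)

raw = {"type": "object", "properties": {"title": {"type": "string"}}}
wrapped = {"name": "doc_schema", "schema": raw}

assert normalize_json_schema_payload(None) is None
assert normalize_json_schema_payload(raw) == raw        # raw schema passes through
assert normalize_json_schema_payload(wrapped) == raw    # wrapper is unwrapped

try:
    normalize_json_schema_payload({"title": "no type here"})
except JsonSchemaValidationError as exc:
    print(exc)  # must include either a "type" or "schema" field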
@@ -8,8 +8,8 @@ from application.core.model_settings import (
     ModelProvider,
 )

-OPENAI_ATTACHMENTS = [
-    "application/pdf",
+# Base image attachment types supported by most vision-capable LLMs
+IMAGE_ATTACHMENTS = [
     "image/png",
     "image/jpeg",
     "image/jpg",
@@ -17,14 +17,17 @@ OPENAI_ATTACHMENTS = [
     "image/gif",
 ]

-GOOGLE_ATTACHMENTS = [
-    "application/pdf",
-    "image/png",
-    "image/jpeg",
-    "image/jpg",
-    "image/webp",
-    "image/gif",
-]
+# PDF excluded: most OpenAI-compatible endpoints don't support native PDF uploads.
+# When excluded, PDFs are synthetically processed by converting pages to images.
+OPENAI_ATTACHMENTS = IMAGE_ATTACHMENTS
+
+GOOGLE_ATTACHMENTS = ["application/pdf"] + IMAGE_ATTACHMENTS
+
+ANTHROPIC_ATTACHMENTS = IMAGE_ATTACHMENTS
+
+OPENROUTER_ATTACHMENTS = IMAGE_ATTACHMENTS
+
+NOVITA_ATTACHMENTS = IMAGE_ATTACHMENTS


 OPENAI_MODELS = [
@@ -63,6 +66,7 @@ ANTHROPIC_MODELS = [
         description="Latest Claude 3.5 Sonnet with enhanced capabilities",
         capabilities=ModelCapabilities(
             supports_tools=True,
+            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
             context_window=200000,
         ),
     ),
@@ -73,6 +77,7 @@ ANTHROPIC_MODELS = [
         description="Balanced performance and capability",
         capabilities=ModelCapabilities(
             supports_tools=True,
+            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
             context_window=200000,
         ),
     ),
@@ -83,6 +88,7 @@ ANTHROPIC_MODELS = [
         description="Most capable Claude model",
         capabilities=ModelCapabilities(
             supports_tools=True,
+            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
             context_window=200000,
         ),
     ),
@@ -93,6 +99,7 @@ ANTHROPIC_MODELS = [
         description="Fastest Claude model",
         capabilities=ModelCapabilities(
             supports_tools=True,
+            supported_attachment_types=ANTHROPIC_ATTACHMENTS,
             context_window=200000,
         ),
     ),
@@ -151,23 +158,78 @@ GROQ_MODELS = [
         ),
     ),
     AvailableModel(
-        id="llama-3.1-8b-instant",
+        id="openai/gpt-oss-120b",
         provider=ModelProvider.GROQ,
-        display_name="Llama 3.1 8B",
-        description="Ultra-fast inference",
+        display_name="GPT-OSS 120B",
+        description="Open-source GPT model optimized for speed",
         capabilities=ModelCapabilities(
             supports_tools=True,
             context_window=128000,
         ),
     ),
+]
+
+
+OPENROUTER_MODELS = [
     AvailableModel(
-        id="mixtral-8x7b-32768",
-        provider=ModelProvider.GROQ,
-        display_name="Mixtral 8x7B",
-        description="High-speed inference with tools",
+        id="qwen/qwen3-coder:free",
+        provider=ModelProvider.OPENROUTER,
+        display_name="Qwen 3 Coder",
+        description="Latest Qwen model with high-speed inference",
         capabilities=ModelCapabilities(
             supports_tools=True,
-            context_window=32768,
+            context_window=128000,
+            supported_attachment_types=OPENROUTER_ATTACHMENTS
+        ),
+    ),
+    AvailableModel(
+        id="google/gemma-3-27b-it:free",
+        provider=ModelProvider.OPENROUTER,
+        display_name="Gemma 3 27B",
+        description="Latest Gemma model with high-speed inference",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            context_window=128000,
+            supported_attachment_types=OPENROUTER_ATTACHMENTS
+        ),
+    ),
+]
+
+
+NOVITA_MODELS = [
+    AvailableModel(
+        id="moonshotai/kimi-k2.5",
+        provider=ModelProvider.NOVITA,
+        display_name="Kimi K2.5",
+        description="MoE model with function calling, structured output, reasoning, and vision",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=NOVITA_ATTACHMENTS,
+            context_window=262144,
+        ),
+    ),
+    AvailableModel(
+        id="zai-org/glm-5",
+        provider=ModelProvider.NOVITA,
+        display_name="GLM-5",
+        description="MoE model with function calling, structured output, and reasoning",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=[],
+            context_window=202800,
+        ),
+    ),
+    AvailableModel(
+        id="minimax/minimax-m2.5",
+        provider=ModelProvider.NOVITA,
+        display_name="MiniMax M2.5",
+        description="MoE model with function calling, structured output, and reasoning",
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supports_structured_output=True,
+            supported_attachment_types=[],
+            context_window=204800,
         ),
     ),
 ]
@@ -187,3 +249,18 @@ AZURE_OPENAI_MODELS = [
         ),
     ),
 ]
+
+
+def create_custom_openai_model(model_name: str, base_url: str) -> AvailableModel:
+    """Create a custom OpenAI-compatible model (e.g., LM Studio, Ollama)."""
+    return AvailableModel(
+        id=model_name,
+        provider=ModelProvider.OPENAI,
+        display_name=model_name,
+        description=f"Custom OpenAI-compatible model at {base_url}",
+        base_url=base_url,
+        capabilities=ModelCapabilities(
+            supports_tools=True,
+            supported_attachment_types=OPENAI_ATTACHMENTS,
+        ),
+    )
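A short usage sketch of the new create_custom_openai_model helper; the endpoint URL and model name are placeholders for a local OpenAI-compatible server such as Ollama or LM Studio:

from application.core.model_configs import create_custom_openai_model

# Placeholder values; any OpenAI-compatible base URL works the same way.
model = create_custom_openai_model("gemma:2b", "http://localhost:11434/v1")
print(model.id, model.provider.value, model.base_url)
# gemma:2b openai http://localhost:11434/v1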
@@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)

 class ModelProvider(str, Enum):
     OPENAI = "openai"
+    OPENROUTER = "openrouter"
     AZURE_OPENAI = "azure_openai"
     ANTHROPIC = "anthropic"
     GROQ = "groq"
@@ -84,9 +85,13 @@ class ModelRegistry:

         self.models.clear()

-        self._add_docsgpt_models(settings)
-        if settings.OPENAI_API_KEY or (
-            settings.LLM_PROVIDER == "openai" and settings.API_KEY
+        # Skip DocsGPT model if using custom OpenAI-compatible endpoint
+        if not settings.OPENAI_BASE_URL:
+            self._add_docsgpt_models(settings)
+        if (
+            settings.OPENAI_API_KEY
+            or (settings.LLM_PROVIDER == "openai" and settings.API_KEY)
+            or settings.OPENAI_BASE_URL
         ):
             self._add_openai_models(settings)
         if settings.OPENAI_API_BASE or (
@@ -105,39 +110,73 @@ class ModelRegistry:
             settings.LLM_PROVIDER == "groq" and settings.API_KEY
         ):
             self._add_groq_models(settings)
+        if settings.OPEN_ROUTER_API_KEY or (
+            settings.LLM_PROVIDER == "openrouter" and settings.API_KEY
+        ):
+            self._add_openrouter_models(settings)
+        if settings.NOVITA_API_KEY or (
+            settings.LLM_PROVIDER == "novita" and settings.API_KEY
+        ):
+            self._add_novita_models(settings)
         if settings.HUGGINGFACE_API_KEY or (
             settings.LLM_PROVIDER == "huggingface" and settings.API_KEY
         ):
             self._add_huggingface_models(settings)
         # Default model selection
-        if settings.LLM_NAME and settings.LLM_NAME in self.models:
-            self.default_model_id = settings.LLM_NAME
-        elif settings.LLM_PROVIDER and settings.API_KEY:
-            for model_id, model in self.models.items():
-                if model.provider.value == settings.LLM_PROVIDER:
-                    self.default_model_id = model_id
-                    break
-        else:
+        if settings.LLM_NAME:
+            # Parse LLM_NAME (may be comma-separated)
+            model_names = self._parse_model_names(settings.LLM_NAME)
+            # First model in the list becomes default
+            for model_name in model_names:
+                if model_name in self.models:
+                    self.default_model_id = model_name
+                    break
+            # Backward compat: try exact match if no parsed model found
+            if not self.default_model_id and settings.LLM_NAME in self.models:
+                self.default_model_id = settings.LLM_NAME
+
+        if not self.default_model_id:
+            if settings.LLM_PROVIDER and settings.API_KEY:
+                for model_id, model in self.models.items():
+                    if model.provider.value == settings.LLM_PROVIDER:
+                        self.default_model_id = model_id
+                        break
+
+        if not self.default_model_id and self.models:
             self.default_model_id = next(iter(self.models.keys()))
         logger.info(
             f"ModelRegistry loaded {len(self.models)} models, default: {self.default_model_id}"
         )

     def _add_openai_models(self, settings):
-        from application.core.model_configs import OPENAI_MODELS
-
-        if settings.OPENAI_API_KEY:
-            for model in OPENAI_MODELS:
-                self.models[model.id] = model
-            return
-        if settings.LLM_PROVIDER == "openai" and settings.LLM_NAME:
-            for model in OPENAI_MODELS:
-                if model.id == settings.LLM_NAME:
-                    self.models[model.id] = model
-                    return
-        for model in OPENAI_MODELS:
-            self.models[model.id] = model
+        from application.core.model_configs import (
+            OPENAI_MODELS,
+            create_custom_openai_model,
+        )
+
+        # Check if using local OpenAI-compatible endpoint (Ollama, LM Studio, etc.)
+        using_local_endpoint = bool(
+            settings.OPENAI_BASE_URL and settings.OPENAI_BASE_URL.strip()
+        )
+
+        if using_local_endpoint:
+            # When OPENAI_BASE_URL is set, ONLY register custom models from LLM_NAME
+            # Do NOT add standard OpenAI models (gpt-5.1, etc.)
+            if settings.LLM_NAME:
+                model_names = self._parse_model_names(settings.LLM_NAME)
+                for model_name in model_names:
+                    custom_model = create_custom_openai_model(
+                        model_name, settings.OPENAI_BASE_URL
+                    )
+                    self.models[model_name] = custom_model
+                    logger.info(
+                        f"Registered custom OpenAI model: {model_name} at {settings.OPENAI_BASE_URL}"
+                    )
+        else:
+            # Standard OpenAI API usage - add standard models if API key is valid
+            if settings.OPENAI_API_KEY:
+                for model in OPENAI_MODELS:
+                    self.models[model.id] = model

     def _add_azure_openai_models(self, settings):
         from application.core.model_configs import AZURE_OPENAI_MODELS
@@ -194,6 +233,36 @@ class ModelRegistry:
             return
         for model in GROQ_MODELS:
             self.models[model.id] = model

+    def _add_openrouter_models(self, settings):
+        from application.core.model_configs import OPENROUTER_MODELS
+
+        if settings.OPEN_ROUTER_API_KEY:
+            for model in OPENROUTER_MODELS:
+                self.models[model.id] = model
+            return
+        if settings.LLM_PROVIDER == "openrouter" and settings.LLM_NAME:
+            for model in OPENROUTER_MODELS:
+                if model.id == settings.LLM_NAME:
+                    self.models[model.id] = model
+                    return
+        for model in OPENROUTER_MODELS:
+            self.models[model.id] = model
+
+    def _add_novita_models(self, settings):
+        from application.core.model_configs import NOVITA_MODELS
+
+        if settings.NOVITA_API_KEY:
+            for model in NOVITA_MODELS:
+                self.models[model.id] = model
+            return
+        if settings.LLM_PROVIDER == "novita" and settings.LLM_NAME:
+            for model in NOVITA_MODELS:
+                if model.id == settings.LLM_NAME:
+                    self.models[model.id] = model
+                    return
+        for model in NOVITA_MODELS:
+            self.models[model.id] = model
+
     def _add_docsgpt_models(self, settings):
         model_id = "docsgpt-local"
@@ -223,6 +292,15 @@ class ModelRegistry:
         )
         self.models[model_id] = model

+    def _parse_model_names(self, llm_name: str) -> List[str]:
+        """
+        Parse LLM_NAME which may contain comma-separated model names.
+        E.g., 'deepseek-r1:1.5b,gemma:2b' -> ['deepseek-r1:1.5b', 'gemma:2b']
+        """
+        if not llm_name:
+            return []
+        return [name.strip() for name in llm_name.split(",") if name.strip()]
+
     def get_model(self, model_id: str) -> Optional[AvailableModel]:
         return self.models.get(model_id)
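The new default-model selection is easiest to follow in isolation; the snippet below mirrors the registry's comma-separated LLM_NAME handling without importing the real class (names are illustrative):

def pick_default(llm_name, registered):
    # First registered name from the comma-separated list wins;
    # fall back to an exact match on the raw value.
    names = [n.strip() for n in llm_name.split(",") if n.strip()]
    for name in names:
        if name in registered:
            return name
    return llm_name if llm_name in registered else None


assert pick_default("deepseek-r1:1.5b,gemma:2b", {"gemma:2b"}) == "gemma:2b"
assert pick_default("deepseek-r1:1.5b,gemma:2b", set()) is None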
@@ -9,6 +9,8 @@ def get_api_key_for_provider(provider: str) -> Optional[str]:

     provider_key_map = {
         "openai": settings.OPENAI_API_KEY,
+        "openrouter": settings.OPEN_ROUTER_API_KEY,
+        "novita": settings.NOVITA_API_KEY,
         "anthropic": settings.ANTHROPIC_API_KEY,
         "google": settings.GOOGLE_API_KEY,
         "groq": settings.GROQ_API_KEY,
Some files were not shown because too many files have changed in this diff.